gpt4 book ai didi

python - 在 Tensorflow 图中查找所需的占位符

转载 作者:太空宇宙 更新时间:2023-11-03 15:51:19 27 4
gpt4 key购买 nike

在 Tensorflow 中,有没有办法找到评估某个输出张量所需的所有占位符张量?也就是说,当调用 sess.run(output_tensor) 时,是否有一个函数会返回必须输入到 feed_dict 中的所有(占位符)张量?

这是我想用伪代码做的一个例子:

import tensorflow as tf

a = tf.placeholder(dtype=tf.float32,shape=())
b = tf.placeholder(dtype=tf.float32,shape=())
c = tf.placeholder(dtype=tf.float32,shape=())
d = a + b
f = b + c

# This should return [a,b] or [a.name,b.name]
d_input_tensors = get_dependencies(d)

# This should return [b,c] or [b.name,c.name]
f_input_tensors = get_dependencies(f)

编辑:澄清一下,我不是(必须)寻找图中的所有占位符,只是寻找定义特定输出张量所需的占位符。所需的占位符可能只是图中所有占位符的一个子集。

最佳答案

经过一些修补和发现 this几乎相同的 SO 问题,我想出了以下解决方案:

def get_tensor_dependencies(tensor):

# If a tensor is passed in, get its op
try:
tensor_op = tensor.op
except:
tensor_op = tensor

# Recursively analyze inputs
dependencies = []
for inp in tensor_op.inputs:
new_d = get_tensor_dependencies(inp)
non_repeated = [d for d in new_d if d not in dependencies]
dependencies = [*dependencies, *non_repeated]

# If we've reached the "end", return the op's name
if len(tensor_op.inputs) == 0:
dependencies = [tensor_op.name]

# Return a list of tensor op names
return dependencies

注意:这不仅会返回占位符,还会返回变量和常量。如果 dependencies = [tensor_op.name] 被替换为 dependencies = [tensor_op.name] if tensor_op.type == 'Placeholder' else [],那么只有占位符会被返回。

关于python - 在 Tensorflow 图中查找所需的占位符,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46657891/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com