gpt4 book ai didi

tensorflow - TensorFlow 中的条件执行

转载 作者:行者123 更新时间:2023-12-04 07:35:37 30 4
gpt4 key购买 nike

如何根据条件选择执行图表的一部分?

我的网络有一部分只有在 feed_dict 中提供占位符值时才会执行.如果未提供该值,则采用备用路径。我该如何使用 tensorflow 来实现它?

以下是我的代码的相关部分:

sess.run(accuracy, feed_dict={inputs: mnist.test.images, outputs: mnist.test.labels})

N = tf.shape(outputs)
cost = 0
if N > 0:
y_N = tf.slice(h_c, [0, 0], N)
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y_N, outputs, name='xentropy')
cost = tf.reduce_mean(cross_entropy, name='xentropy_mean')

在上面的代码中,我正在寻找可以用来代替 if N > 0: 的东西。

最佳答案

人力资源管理系统。您想要的可能是 tf.control_flow_ops.cond()
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L597

但这并没有导出到 tf 命名空间中,我在回答时没有检查这个接口(interface)的保证稳定性如何,但它用于已发布的模型中,所以去吧。 :)

但是:因为您实际上在构建 feed_dict 时事先知道您想要什么路径,所以您也可以采用不同的方法通过模型调用单独的路径。执行此操作的标准方法是,例如,设置如下代码:

def model(input, n_greater_than):
... cleverness ...
if n_greater_than:
... other cleverness...
return tf.reduce_mean(input)


out1 = model(input, True)
out2 = model(input, False)

然后根据您在即将运行计算并设置 feed_dict 时所知道的内容来拉出 out1 或 out2 节点。请记住,默认情况下,如果模型引用相同的变量(在 model() 函数之外创建它们),那么您基本上将有两条单独的路径通过。

您可以在卷积 mnist 示例中看到一个示例: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/mnist/convolutional.py#L165

如果可以的话,我很喜欢在不引入控制流依赖项的情况下这样做。

关于tensorflow - TensorFlow 中的条件执行,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33686902/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com