gpt4 book ai didi

python - Tensorflow:为什么 tf.case 给我错误的结果?

转载 作者:行者123 更新时间:2023-11-28 18:24:08 25 4
gpt4 key购买 nike

我正在尝试使用 tf.case ( https://www.tensorflow.org/api_docs/python/tf/case ) 有条件地更新张量。如图所示,当 global_step == 2 时,我尝试将 learning_rate 更新为 0.01,当 global_step == 2 时更新为 0.001 global_step == 4

但是,当 global_step == 2 时,我已经得到 learning_rate = 0.001。经过进一步检查,当 global_step == 2 时,tf.case 似乎给了我错误的结果(我得到的是 0.001 而不是 0.01)。即使 0.01 的谓词评估为 True,而 0.001 的谓词评估为 False,也会发生这种情况。

我做错了什么,还是这是一个错误?

TF 版本:1.0.0

代码:

import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)
train_op = tf.assign(global_step, global_step + 1)
learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')

# Update the learning_rate tensor conditionally
# When global_step == 2, update to 0.01
# When global_step == 4, update to 0.001
cases = []
case_tensors = []
for step, new_rate in [(2, 0.01), (4, 0.001)]:
pred = tf.equal(global_step, step)
fn_tensor = tf.constant(new_rate, dtype=tf.float32)
cases.append((pred, lambda: fn_tensor))
case_tensors.append((pred, fn_tensor))
update = tf.case(cases, default=lambda: learning_rate)
updated_learning_rate = tf.assign(learning_rate, update)

print tf.__version__
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in xrange(6):
print sess.run([global_step, case_tensors, update, updated_learning_rate])
sess.run(train_op)

结果:

1.0.0
[0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1]
[2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]
[4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001]
[5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001]

最佳答案

这在 https://github.com/tensorflow/tensorflow/issues/8776 中得到了回答

事实证明,如果在 fn_tensors 中,lambda 返回在 lambda 外部创建的张量,则 tf.case 行为未定义。正确的用法是定义 lambda,使它们返回新创建的张量。

根据链接的 Github 问题,这种用法是必需的,因为 tf.case 必须自己创建张量,以便将张量的输入连接到谓词的正确分支。

关于python - Tensorflow:为什么 tf.case 给我错误的结果?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42728235/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com