gpt4 book ai didi

tensorflow - tf.scatter_update() 在 while_loop() 中如何工作

转载 作者:行者123 更新时间:2023-12-03 00:37:03 29 4
gpt4 key购买 nike

我正在尝试使用 tf.scatter_update() 更新 tf.while_loop() 内的 tf.Variable。但是,结果是初始值而不是更新后的值。这是我想要做的示例代码:

from __future__ import print_function

import tensorflow as tf

def cond(sequence_len, step):
return tf.less(step,sequence_len)

def body(sequence_len, step):

begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin,1,step,use_locking=None)

tf.get_variable_scope().reuse_variables()
return (sequence_len, step+1)

with tf.Graph().as_default():

sess = tf.Session()
step = tf.constant(0)
sequence_len = tf.constant(10)
_,step, = tf.while_loop(cond,
body,
[sequence_len, step],
parallel_iterations=10,
back_prop=True,
swap_memory=False,
name=None)

begin = tf.get_variable("begin",[3],dtype=tf.int32)

init = tf.initialize_all_variables()
sess.run(init)

print(sess.run([begin,step]))

结果为:[array([0, 0, 0], dtype=int32), 10]。不过,我认为结果应该是[0, 0, 10]。我在这里做错了什么吗?

最佳答案

这里的问题是循环体中没有任何内容依赖于您的 tf.scatter_update() 操作,因此它永远不会被执行。使其工作的最简单方法是添加对返回值更新的控制依赖:

def body(sequence_len, step): 
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin, 1, step, use_locking=None)
tf.get_variable_scope().reuse_variables()

with tf.control_dependencies([begin]):
return (sequence_len, step+1)

请注意,这个问题并非 TensorFlow 中的循环所独有。如果您刚刚定义了一个名为 begintf.scatter_update() 操作,但对其调用 sess.run() ,或者依赖于它,那么更新就不会发生。当您使用tf.while_loop()时无法直接运行循环体中定义的操作,因此产生副作用的最简单方法是添加控件依赖项。

请注意,最终结果是[0, 9, 0]:每次迭代将当前步骤分配给begin[1],并在最后一次迭代中将值分配给begin[1]当前步骤的值为 9(当 step == 10 时条件为 false)。

关于tensorflow - tf.scatter_update() 在 while_loop() 中如何工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37063081/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com