gpt4 book ai didi

python - 如何在 TensorFlow while_loop 中赋值

转载 作者:太空宇宙 更新时间:2023-11-03 14:57:15 24 4
gpt4 key购买 nike

目标是在 TensorFlow 中实现一个循环函数,以随时间过滤信号。

input随后呈现为 ​​[batch, in_depth, in_height, in_width, in_channels] 形式的 5 维张量。我想使用 tf.while_loop 迭代in_depth并根据先前时间步的值重新分配值。但是,我无法在循环内重新分配变量值。

为了简化问题,我创建了问题的一维版本:

def condition(i, signal):
return tf.less(i, signal.shape[0])

def operation(i, signal):
signal = tf.get_variable("signal")
signal = signal[i].assign(signal[i-1]*2)
i = tf.add(i, 1)
return (i, signal)

with tf.variable_scope("scope"):
i = tf.constant(1)
init = tf.constant_initializer(0)
signal = tf.get_variable("scope", [4], tf.float32, init, trainable = False)
signal = tf.assign(signal[0], 1.2)

with tf.variable_scope("scope", reuse = True):
loops_vars = [i, signal]
i, signal = tf.while_loop(condition, operation, loop_vars, back_prop = False)

with tf.Session() as session:
session.run(tf.global_variables_initializer())
i, signal = session.run([i, signal])

tf.assign 返回一个操作,该操作必须在 session 中运行才能进行评估 ( see here for further details )。

我预计,TensorFlow 会在循环内链接操作,因此一旦我运行 session 并请求 signal 就会执行分配。但是,当我执行给定的代码并打印结果时,signal包含[1.2, 0, 0, 0]i包含(如预期)4 .

我的误解是什么以及如何更改代码以使 signal 的值被重新分配?

最佳答案

虽然循环变量仅通过主体函数的返回值进行更新,但您不应使用自己的赋值操作。相反,您需要返回希望 signal 在循环后具有的值,就像 i 一样。

此外,您不应在正文或条件中使用 tf.get_variable,而应使用您收到的参数。

# ...
def operation(i, signal):
shape = signal.shape
signal = tf.concat([signal[:i], [signal[i - 1] * 2], signal[i + 1:]], axis=0)
signal.set_shape(shape) # Shapes have to be invariant for the loop
i = tf.add(i, 1)
return (i, signal)

with tf.variable_scope("scope"):
i = tf.constant(1)
init = tf.constant_initializer(1.2) # init signal here and avoid tf.assign
signal = tf.get_variable("scope", [4], tf.float32, init, trainable = False)
# signal = tf.assign(signal[0], 1.2)

# ...

关于python - 如何在 TensorFlow while_loop 中赋值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45437680/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com