gpt4 book ai didi

python - tf.while_loop 并行运行时给出错误结果

转载 作者:太空宇宙 更新时间:2023-11-03 21:44:11 25 4
gpt4 key购买 nike

我想按行更新 tensorflow 中tf.while_loop内的二维tf.variable。因此,我使用 tf.assign 方法。问题是我的实现和 parallel_iterations>1 结果是错误的。使用 parallel_iterations=1 结果是正确的。代码是这样的:

a = tf.Variable(tf.zeros([100, 100]), dtype=tf.int64)

i = tf.constant(0)
def condition(i, var):
return tf.less(i, 100)

def body(i, var):
updated_row = method() # This method returns a [1, 100] tensor which is the updated row for the variable
temp = tf.assign(a[i], updated_row)
return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)

迭代是完全独立的,我不知道问题是什么。

奇怪的是,如果我像这样更改代码:

a = tf.Variable(tf.zeros([100, 100]), dtype=tf.int64)

i = tf.constant(0)
def condition(i, var):
return tf.less(i, 100)

def body(i, var):
zeros = lambda: tf.zeros([100, 100], dtype=tf.int64)
temp = tf.Variable(initial_value=zeros, dtype=tf.int64)
updated_row = method() # This method returns a [1, 100] tensor which is the updated row for the variable
temp = tf.assign(temp[i], updated_row)
return [tf.add(i, 1), temp]

z = tf.while_loop(condition, body, [i, a], back_prop=False, parallel_iterations=10)

该代码给出了 parallel_iterations>1 的正确结果。有人可以解释一下这里发生了什么,并给我一个有效的解决方案来更新变量,因为我要更新的原始变量很大,而我找到的解决方案效率非常低。

最佳答案

您不需要为此使用变量,您只需在循环体上生成行更新张量即可:

import tensorflow as tf

def method(i):
# Placeholder logic
return tf.cast(tf.range(i, i + 100), tf.float32)

def condition(i, var):
return tf.less(i, 100)

def body(i, var):
# Produce new row
updated_row = method(i)
# Index vector that is 1 only on the row to update
idx = tf.equal(tf.range(tf.shape(a)[0]), i)
idx = tf.cast(idx[:, tf.newaxis], var.dtype)
# Compose the new tensor with the old one and the new row
var_updated = (1 - idx) * var + idx * updated_row
return [tf.add(i, 1), var_updated]

# Start with zeros
a = tf.zeros([100, 100], tf.float32)
i = tf.constant(0)
i_end, a_updated = tf.while_loop(condition, body, [i, a], parallel_iterations=10)

with tf.Session() as sess:
print(sess.run(a_updated))

输出:

[[  0.   1.   2. ...  97.  98.  99.]
[ 1. 2. 3. ... 98. 99. 100.]
[ 2. 3. 4. ... 99. 100. 101.]
...
[ 97. 98. 99. ... 194. 195. 196.]
[ 98. 99. 100. ... 195. 196. 197.]
[ 99. 100. 101. ... 196. 197. 198.]]

关于python - tf.while_loop 并行运行时给出错误结果,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52611575/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com