gpt4 book ai didi

python - Tensorflow tf.scatter_update不更新浮点值

转载 作者:太空宇宙 更新时间:2023-11-03 21:38:59 24 4
gpt4 key购买 nike

如果变量的类型定义为 float,则 tf.scatter_update 似乎不会更新变量。您可以尝试以下代码:

import tensorflow as tf

def cond(size, i):
return tf.less(i,size)

def body(size, i):
b=2*7.5+c

with tf.variable_scope("a", reuse=tf.AUTO_REUSE):
a = tf.get_variable("a",[6],dtype=tf.float32)

a = tf.scatter_update(a,i,b)

with tf.control_dependencies([a]):
return (size, i+1)

with tf.Session() as sess:
c=tf.constant(4.0)
i = tf.constant(0)
size = tf.constant(6)
_,i = tf.while_loop(cond,
body,
[size, i])

a = tf.get_variable("a",[6],dtype=tf.float32)

init = tf.initialize_all_variables()
sess.run(init)
print(sess.run([a,i]))

结果是随机的!因为我没有故意初始化变量来查看它如何更新,所以它似乎永远不会更新,并且每次都会打印出随机初始化。您会看到类似这样的内容:

[数组([-0.35466522, 0.44001752, 0.21131486, -0.48532146, 0.3019274 , -0.19926369], dtype=float32), 6]

这是一个错误吗?如您所见,我仍在使用 tf.control_dependency,并且仅当变量 a 的类型设置为 float 时才会发生这种情况。

最佳答案

你的预期输出是这样的吗?

[array([19., 19., 19., 19., 19., 19.], dtype=float32), 6]

这两种模式产生了这一点。

模式 1

import tensorflow as tf

def cond(size, i):
return tf.less(i,size)

def body(size, i):
b=2*7.5+c

with tf.variable_scope("a", reuse=tf.AUTO_REUSE):
a = tf.get_variable("a",[6],dtype=tf.float32)

a = tf.scatter_update(a,i,b)
with tf.control_dependencies([a]):
return (size, i+1)

with tf.Session() as sess:
c=tf.constant(4.0)
i = tf.constant(0)
size = tf.constant(6)
_,i = tf.while_loop(cond,
body,
[size, i])

with tf.variable_scope("a", reuse=tf.AUTO_REUSE):
a = tf.get_variable("a",[6],dtype=tf.float32)

init = tf.initialize_all_variables()
sess.run(init)

print(sess.run([a,i]))

模式 2

def body(size, i):
b=2*7.5+c

a = tf.get_variable("a",[6],dtype=tf.float32)

a = tf.scatter_update(a,i,b)
#Reuse variables
tf.get_variable_scope().reuse_variables()

with tf.control_dependencies([a]):
return (size, i+1)

with tf.Session() as sess:
c=tf.constant(4.0)
i = tf.constant(0)
size = tf.constant(6)
_,i = tf.while_loop(cond,
body,
[size, i])

a = tf.get_variable("a",[6],dtype=tf.float32)

init = tf.initialize_all_variables()
sess.run(init)

print(sess.run([a,i]))

关于python - Tensorflow tf.scatter_update不更新浮点值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53029118/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com