gpt4 book ai didi

python - tf.while_loop 中的副作用

转载 作者:太空宇宙 更新时间:2023-11-04 04:45:44 24 4
gpt4 key购买 nike

我目前很难理解 tensorflow 的工作原理,而且我觉得 python 界面有点晦涩难懂。

我最近尝试在 tf.while_loop 中运行一个简单的 print 语句,但有很多事情我仍然不清楚:

import tensorflow as tf

nb_iter = tf.constant(value=10)
#This solution does not work at all
#nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)
i = tf.get_variable('i', shape=(), trainable=False,
initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
tf.Print(i, [i], message='Another iteration')
return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
sess.run(initializer_op)
res = sess.run(i)
print('res is now {}'.format(res))

请注意,如果我用

初始化 nb_iter
nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)

我收到以下错误:

ValueError: Shape must be rank 0 but is rank 1 for 'while/LoopCond' (op: 'LoopCond') with input shapes: [1].

当我尝试使用 'i' 索引来索引张量时,情况变得更糟(示例未在此处显示),然后出现以下错误

alueError: Operation 'while/strided_slice' has been marked as not fetchable.

有人能给我指点一份文档,解释 tf.while_loop 在与 tf.Variables 一起使用时如何工作,以及是否可以在循环内使用 side_effects(如打印),以及使用循环变量索引张量?

预先感谢您的帮助

最佳答案

我的第一个例子实际上有很多问题:

如果运算符没有副作用(即 i = tf.Print()),则不执行 tf.Print

如果 bool 值是标量,则它是 0 阶张量,而不是 1 阶张量。 ...

这是有效的代码:

import tensorflow as tf

#nb_iter = tf.constant(value=10)
#This solution does not work at all
nb_iter = tf.get_variable('nb_iter', shape=(), dtype=tf.int32, trainable=False,
initializer=tf.zeros_initializer())
nb_iter = tf.add(nb_iter,10)
i = tf.get_variable('i', shape=(), trainable=False,
initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)
v = tf.get_variable('v', shape=(10), trainable=False,
initializer=tf.random_uniform_initializer, dtype=tf.float32)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
i = tf.Print(i, [v[i]], message='Another vector element: ')
return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
sess.run(initializer_op)
res = sess.run(i)
print('res is now {}'.format(res))

输出:

Another vector element: [0.203766704]
Another vector element: [0.692927241]
Another vector element: [0.732221603]
Another vector element: [0.0556482077]
Another vector element: [0.422092319]
Another vector element: [0.597698212]
Another vector element: [0.92387116]
Another vector element: [0.590101123]
Another vector element: [0.741415381]
Another vector element: [0.514917374]
res is now 10

关于python - tf.while_loop 中的副作用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49686860/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com