gpt4 book ai didi

python - Tensorflow:张量上的 while 循环

转载 作者:行者123 更新时间:2023-11-28 22:14:14 25 4
gpt4 key购买 nike

我正在尝试对张量值应用 while 循环。例如,对于变量“a”,我试图逐渐增加张量的值,直到满足某个条件。但是,我不断收到此错误:

ValueError: Shape must be rank 0 but is rank 3 for 'while_12/LoopCond' (op: 'LoopCond') with input shapes: [3,1,1].

a = array([[[0.76393723]],
[[0.93270312]],
[[0.08361106]]])

a = np.random.random((3,1,1))
a1 = tf.constant(np.float64(a))
i = tf.constant(np.float64(6.14))

c = lambda i: tf.less(i, a1)
b = lambda x: tf.add(x, 0.1)
r = tf.while_loop(c, b, [a1])

最佳答案

tf.while_loop() 的第一个参数应该返回标量(等级 0 的张量实际上是一个标量 - 这就是错误消息的内容)。在您的示例中,如果 a1 张量中的所有数字都小于 6.14,您可能希望条件返回 true。这可以通过 tf.reduce_all() 来实现(逻辑与)和 tf.reduce_any() (逻辑或)。

该片段对我有用:

tf.reset_default_graph()

a = np.random.random_integers(3, size=(3,2))
print(a)
# [[1 1]
# [2 3]
# [1 1]]

a1 = tf.constant(a)
i = 6

# condition returns True till any number in `x` is less than 6
condition = lambda x : tf.reduce_any(tf.less(x, i))
body = lambda x : tf.add(x, 1)
loop = tf.while_loop(
condition,
body,
[a1],
)

with tf.Session() as sess:
result = sess.run(loop)
print(result)
# [[6 6]
# [7 8]
# [6 6]]
# All numbers now are greater than 6

关于python - Tensorflow:张量上的 while 循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53483153/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com