gpt4 book ai didi

tensorflow - 在 tf.where() 中使用 == 条件的问题

转载 作者:行者123 更新时间:2023-12-04 01:31:52 24 4
gpt4 key购买 nike

我最近开始使用 tensorflow 并尝试使用 tf.where() 函数。我注意到每当我使用“==”条件时它都会抛出错误。例如,当我尝试以下操作时:

t = tf.constant([[1, 2, 3], 
[4, 5, 6]])

t2 = tf.where(t==2)
t3 = tf.gather_nd(t,t2)

t3_shape= tf.shape(t)[0]

with tf.Session() as sess:
print(sess.run([t3]))

它抛出以下错误:

InvalidArgumentError: WhereOp : Unhandled input dimensions: 0



谁能解释一下这里可能有什么错误?
提前致谢!

最佳答案

您需要 tf.equal 进行元素比较:

t2 = tf.where(tf.equal(t, 2))

t = tf.constant([[1, 2, 3],
[4, 5, 6]])

t2 = tf.where(tf.equal(t, 2))
t3 = tf.gather_nd(t,t2)
t3_shape= tf.shape(t)[0]

with tf.Session() as sess:
print(sess.run([t3]))

# [array([2])]

关于tensorflow - 在 tf.where() 中使用 == 条件的问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51488866/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com