gpt4 book ai didi

python - Tensorflow:如何在使用 tf.where() 时保持批量维度?

转载 作者:行者123 更新时间:2023-12-03 21:00:45 34 4
gpt4 key购买 nike

我正在尝试选择与零不同的元素,然后再使用它们。我的输入张量有批量维度,所以我想保留它并且不要在批量中混合数据。我想 tf.gather_nd()对我有用,但首先我必须获得所需数据的索引,我发现 tf.where() .我尝试了以下方法:

img = tf.constant([[[1., 0., 0.], 
[0., 0., 2.],
[0., 3, 0.]],
[[1., 2., 3.],
[0., 0., 1.],
[0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))

我希望 indexes保持批次尺寸,但它具有形状 [7, 2] .我怀疑问题来自不同批次中满足条件的不同点数。

有没有办法让索引保持批量维度?提前致谢。

编辑: indexes有形状 [7, 3]其中第一个昏暗是指点的数量,第二个昏暗是指点的位置(包括它属于哪个批次)。但我需要 indexes具有特定的批次维度,因为稍后我想用它来收集来自 img 的数据:

Y = tf.gather_nd(img, indexes)

我要 Y具有批量维度,但为 indexes没有,我得到了一个扁平张量,其中混合了来自不同批次的数据。

最佳答案

实际上,您可能做错了什么:当我运行您的代码时,indexes是尺寸 (7,3)而不是 (7,2) . 3对应于您的 3 个维度,而 7对应于 img 中非零元素的数量.
sess.run(indexes) 的完整结果:

array([[0, 0, 0],
[0, 1, 2],
[0, 2, 1],
[1, 0, 0],
[1, 0, 1],
[1, 0, 2],
[1, 1, 2]])

关于python - Tensorflow:如何在使用 tf.where() 时保持批量维度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57921044/

34 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com