gpt4 book ai didi

python - 批量收集/GatherND

转载 作者:行者123 更新时间:2023-12-01 09:17:08 24 4
gpt4 key购买 nike

我想知道是否有办法在 TensorFlow 中使用 gather_nd 或类似的方法执行以下操作。

我有两个张量:

  • ,形状为 [128, 100],
  • 索引,形状为 [128, 3],

其中每行索引都包含沿第二个维度的索引(对于同一行)。我想使用索引来索引。例如,我想要执行此操作的东西(使用宽松的符号来表示张量):

values  = [[0, 0, 0, 1, 1, 0, 1], 
[1, 1, 0, 0, 1, 0, 0]]
indices = [[2, 3, 6],
[0, 2, 3]]
batched_gather(values, indices) = [[0, 1, 1], [1, 0, 0]]

此操作将遍历每一行索引,并使用索引行上执行收集 行。

在 TensorFlow 中是否有一种简单的方法可以做到这一点?

谢谢!

最佳答案

不确定这是否符合“简单”的条件,但您可以使用 gather_nd 来实现:

def batched_gather(values, indices):
row_indices = tf.range(0, tf.shape(values)[0])[:, tf.newaxis]
row_indices = tf.tile(row_indices, [1, tf.shape(indices)[-1]])
indices = tf.stack([row_indices, indices], axis=-1)
return tf.gather_nd(values, indices)

解释:这个想法是构造索引向量,例如[0, 1],意思是“第0行第1列中的值”。
列索引已在函数的 indices 参数中给出。
行索引是从 0 到例如 的简单级数。 128(在您的示例中),但根据每行的列索引数重复(平铺)(在您的示例中为 3;如果此数字是,则可以对其进行硬编码而不是使用 tf.shape已修复)。
然后将行索引和列索引堆叠起来以生成索引向量。在您的示例中,生成的索引将是

array([[[0, 2],
[0, 3],
[0, 6]],

[[1, 0],
[1, 2],
[1, 3]]])

gather_nd产生所需的结果。

关于python - 批量收集/GatherND,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51143210/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com