gpt4 book ai didi

python - tensorflow 相当于 torch.gather

转载 作者:行者123 更新时间:2023-12-01 09:05:04 41 4
gpt4 key购买 nike

我有一个形状为 (16, 4096, 3) 的张量。我有另一个形状为 (16, 32768, 3) 的索引张量。我正在尝试收集 dim=1 上的值。这最初是在 pytorch 中使用 gather function 完成的。如下图-

# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)

请注意,输出b的大小与idx的大小相同。然而,当我应用tensorflow的gather函数时,我得到了完全不同的输出。发现输出尺寸不匹配,如下所示 -

b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)

我也尝试过使用tf.gather_nd但没有成功。见下文-

b = tf.gather_nd(a, idx)
# b.shape (16, 32768)

为什么我会得到不同形状的张量? 我想得到与 pytorch 计算的形状相同的张量。

换句话说,我想知道 torch.gather 的 tensorflow 等效项。

最佳答案

对于 2D 情况,有一种方法可以做到这一点:

# a.shape (16L, 10L)
# idx.shape (16L,1)
idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
b = tf.gather_nd(a,idx)

但是,对于ND情况,此方法可能非常复杂

关于python - tensorflow 相当于 torch.gather,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52129909/

41 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com