gpt4 book ai didi

python - Pytorch 张量索引

转载 作者:行者123 更新时间:2023-12-02 16:40:05 25 4
gpt4 key购买 nike

我目前正在努力将一些代码从tensorflow转换为pytorch,我遇到了tf.gather的问题func,pytorch中没有直接函数可以转换它。

我想做的基本上是索引,我有两个张量,特征张量形状为[minibatch, 60, 2]和索引张量[minibatch, 8],假设第一个张量是张量 A,第二个张量是 B

在Tensorflow中,直接用tf.gather(A, B, batch_dims=1)转换

如何在 pytorch 中实现这一目标?

我尝试过A[B]索引。这似乎行不通

A[0]B[0] 有效,但形状的输出为 [8, 2]

我需要[minibatch, 8, 2]的形状

如果我像 [stack, 8, 2] 这样堆叠张量,它可能会起作用,但我不知道该怎么做

tensorflow
out = tf.gather(logits, indices, batch_dims=1)
pytorch
out = A[B] -> something like this will be great

[minibatch, 8, 2]的输出形状

最佳答案

我认为您正在寻找 torch.gather

out = torch.gather(A, 1, B[..., None].expand(*B.shape, A.shape[-1]))

关于python - Pytorch 张量索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57071002/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com