gpt4 book ai didi

python - Pytorch 张量索引 : How to gather rows by tensor containing indices

转载 作者:行者123 更新时间:2023-12-04 02:48:56 27 4
gpt4 key购买 nike

我有张量:
ID :形状 (7000,1) 包含索引,如 [[1],[0],[2],...] x : 形状(7000, 3 ,255)ids张量编码x的粗体标记维度的索引应该选择哪个。
我想在结果向量中收集选定的切片:
结果:形状 (7000,255)
背景:
对于 3 个元素中的每一个,我都有一些分数(形状 = (7000,3)),并且只想选择得分最高的那个。因此,我使用了该功能

ids = torch.argmax(scores,1,True)
给我最大的ID。我已经尝试使用收集功能来做到这一点:
result = x.gather(1,ids)
但这没有用。

最佳答案

这是您可能正在寻找的解决方案

ids = ids.repeat(1, 255).view(-1, 1, 255)

一个例子如下:

x = torch.arange(24).view(4, 3, 2) 
"""
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5]],

[[ 6, 7],
[ 8, 9],
[10, 11]],

[[12, 13],
[14, 15],
[16, 17]],

[[18, 19],
[20, 21],
[22, 23]]])
"""
ids = torch.randint(0, 3, size=(4, 1))
"""
tensor([[0],
[2],
[0],
[2]])
"""
idx = ids.repeat(1, 2).view(4, 1, 2)
"""
tensor([[[0, 0]],

[[2, 2]],

[[0, 0]],

[[2, 2]]])
"""

torch.gather(x, 1, idx)
"""
tensor([[[ 0, 1]],

[[10, 11]],

[[12, 13]],

[[22, 23]]])
"""

关于python - Pytorch 张量索引 : How to gather rows by tensor containing indices,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55881002/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com