gpt4 book ai didi

python - 使用相同大小的索引张量拆分 torch 张量

转载 作者:行者123 更新时间:2023-12-05 02:01:15 40 4
gpt4 key购买 nike

假设我有张量

t = torch.tensor([1,2,3,4,5])

我想使用相同大小的索引张量来拆分它,该张量告诉我每个元素应该进行哪个拆分。

indices = torch.tensor([0,1,1,0,2])

这样最后的结果就是

splits
[tensor([1,4]), tensor([2,3]), tensor([5])]

在 Pytorch 中有没有一种巧妙的方法来做到这一点?

编辑:通常会有超过 2 或 3 次拆分。

最佳答案

对于一般情况,可以使用 argsort 来完成:

def mask_split(tensor, indices):
sorter = torch.argsort(indices)
_, counts = torch.unique(indices, return_counts=True)
return torch.split(t[sorter], counts.tolist())


mask_split(t, indices)

虽然如果这是您的真实用例,使用@flawr answer 可能会更好(list comprehension 也可能更快,因为它不需要排序),像这样:

def mask_split(tensor, indices):
unique = torch.unique(indices)
return [tensor[indices == i] for i in unique]

关于python - 使用相同大小的索引张量拆分 torch 张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66736492/

40 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com