gpt4 book ai didi

python - 寻找两个pytorch张量的非交集

转载 作者:行者123 更新时间:2023-12-02 01:45:43 24 4
gpt4 key购买 nike

提前感谢大家的帮助!我试图在 PyTorch 中做的事情类似于 numpy 的 setdiff1d 。例如,给出以下两个张量:

t1 = torch.tensor([1, 9, 12, 5, 24]).to('cuda:0')
t2 = torch.tensor([1, 24]).to('cuda:0')

预期输出应该是(排序或未排序):

torch.tensor([9, 12, 5])

理想情况下,操作在 GPU 上完成,并且 GPU 和 CPU 之间没有来回。非常感谢!

最佳答案

我遇到了同样的问题,但在使用较大的数组时,建议的解决方案太慢了。以下简单的解决方案适用于 CPU 和 GPU,并且比其他建议的解决方案要快得多:

combined = torch.cat((t1, t2))
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]
intersection = uniques[counts > 1]

关于python - 寻找两个pytorch张量的非交集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55110047/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com