gpt4 book ai didi

python - Pytorch:有没有类似torch.argmax的函数,真的可以保持原始数据的维度?

转载 作者:太空宇宙 更新时间:2023-11-04 04:28:24 31 4
gpt4 key购买 nike

例如,代码是

input = torch.randn(3, 10)
result = torch.argmax(input, dim=0, keepdim=True)

输入

tensor([[ 1.5742,  0.8183, -2.3005, -1.1650, -0.2451],
[ 1.0553, 0.6021, -0.4938, -1.5379, -1.2054],
[-0.1728, 0.8372, -1.9181, -0.9110, 0.2422]])

结果

tensor([[ 0,  2,  1,  2,  2]])

但是,我想要这样的结果

tensor([[ 1,  0,  0,  0,  0],
[ 0, 0, 1, 0, 0],
[ 0, 1, 0, 1, 1]])

最佳答案

终于解决了。但这种解决方案可能效率不高。代码如下,

input = torch.randn(3, 10)
result = torch.argmax(input, dim=0, keepdim=True)
result_0 = result == 0
result_1 = result == 1
result_2 = result == 2
result = torch.cat((result_0, result_1, result_2), 0)

关于python - Pytorch:有没有类似torch.argmax的函数,真的可以保持原始数据的维度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53116477/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com