gpt4 book ai didi

python - Pytorch 将张量转换为一个热

转载 作者:行者123 更新时间:2023-12-02 19:28:03 25 4
gpt4 key购买 nike

将用 n 值填充的形状张量 (batch_size, height, width) 转换为形状张量 (batch_size, n, height, width) 的最简单方法是什么?我在下面创建了解决方案,但看起来有更简单快捷的方法可以做到这一点


def batch_tensor_to_onehot(tnsr, classes):
tnsr = tnsr.unsqueeze(1)
res = []
for cls in range(classes):
res.append((tnsr == cls).long())
return torch.cat(res, dim=1)

最佳答案

您可以使用 torch.nn.functional.one_hot .

针对您的情况:

a = torch.nn.functional.one_hot(tnsr, num_classes=classes)
out = a.permute(0, 3, 1, 2)

关于python - Pytorch 将张量转换为一个热,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62245173/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com