gpt4 book ai didi

python - PyTorch [1 if x > 0.5 else 0 for x in output ] 带张量

转载 作者:行者123 更新时间:2023-12-03 22:48:23 25 4
gpt4 key购买 nike

我有一个 sigmoid 函数的列表输出作为 PyTorch 中的张量

例如

output (type) = torch.Size([4]) tensor([0.4481, 0.4014, 0.5820, 0.2877], device='cuda:0',

在进行二元分类时,我想将所有值从 0.5 变为 0,将高于 0.5 的值变为 1。

传统上,您可以使用 NumPy 数组使用列表迭代器:
output_prediction = [1 if x > 0.5 else 0 for x in outputs ]

这会起作用,但是我稍后必须将 output_prediction 转换回张量才能使用
torch.sum(ouput_prediction == labels.data)

其中labels.data 是标签的二进制张量。

有没有办法将列表迭代器与张量一起使用?

最佳答案

prob = torch.tensor([0.3,0.4,0.6,0.7])

out = (prob>0.5).float()
# tensor([0.,0.,1.,1.])

说明:在pytorch中,可以直接使用 prob>0.5得到一个 torch.bool类型的张量。然后您可以通过 .float() 转换为浮点类型。

关于python - PyTorch [1 if x > 0.5 else 0 for x in output ] 带张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58002836/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com