gpt4 book ai didi

pytorch - Pytorch 中软标签的交叉熵

转载 作者:行者123 更新时间:2023-12-05 04:39:32 28 4
gpt4 key购买 nike

我正在尝试定义二分类问题的损失函数。但是,目标标签不是硬标签0,1,而是0~1之间的一个 float 。

Pytorch 中的 torch.nn.CrossEntropy 不支持软标签,所以我想自己写一个交叉熵函数。

我的函数是这样的

def cross_entropy(self, pred, target):
loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
return loss

def step(self, batch: Any):
x, y = batch
logits = self.forward(x)
loss = self.criterion(logits, y)
preds = logits
# torch.argmax(logits, dim=1)
return loss, preds, y

然而它根本不起作用。

谁能给我一个建议,我的损失函数有没有错误?

最佳答案

好像BCELoss和健壮的版本 BCEWithLogitsLoss正在“开箱即用”地处理模糊目标。他们不希望 target 是二进制的“0 到 1 之间的任何数字都可以。
请阅读文档。

关于pytorch - Pytorch 中软标签的交叉熵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70429846/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com