作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
嗨,只是在玩代码,我得到了交叉熵损失权重实现的意外结果。
pred=torch.tensor([[8,5,3,2,6,1,6,8,4],[2,5,1,3,4,6,2,2,6],[1,1,5,8,9,2,5,2,8],[2,2,6,4,1,1,7,8,3],[2,2,2,7,1,7,3,4,9]]).float()
label=torch.tensor([[3],[7],[8],[2],[5]],dtype=torch.int64)
weights=torch.tensor([1,1,1,10,1,6,1,1,1],dtype=torch.float32)
使用这种样本变量,pytorch 的交叉熵损失为 4.7894
loss = F.cross_entropy(pred, label, weight=weights,reduction='mean')
> 4.7894
我手动实现了交叉熵损失代码,如下所示
one_hot = torch.zeros_like(pred).scatter(1, label.view(-1, 1), 1)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb).sum(dim=1).mean()
如果没有权重值,这种实现与 pytorch 的交叉熵函数给出相同的结果。但是有重量值
one_hot = torch.zeros_like(pred).scatter(1, label.view(-1, 1), 1)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb)*weights.sum(dim=1).sum()/weights.sum()
> 3.9564
它使用 pytorch 模块(4.7894)给出不同的损失值。
最佳答案
我发现了问题。这很简单...
我不应该除以权重的总和。
而不是用 wt.sum()
分割( wt=one_hot*weight
) 得到了 4.7894。
>>> wt = one_hot*weights
>>> loss = -(one_hot * log_prb * weights).sum(dim=1).sum() / wt.sum()
4.7894
分母仅与“相关”权重值有关,而不是整体。
关于python - 使用权重手动计算的交叉熵损失,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68727252/
我是一名优秀的程序员,十分优秀!