gpt4 book ai didi

deep-learning - 多标签分类Focal loss的实现

转载 作者:行者123 更新时间:2023-12-03 16:50:45 40 4
gpt4 key购买 nike

尝试为多标签分类编写焦点损失

class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=0.25):
self._gamma = gamma
self._alpha = alpha

def forward(self, y_true, y_pred):
cross_entropy_loss = torch.nn.BCELoss(y_true, y_pred)
p_t = ((y_true * y_pred) +
((1 - y_true) * (1 - y_pred)))
modulating_factor = 1.0
if self._gamma:
modulating_factor = torch.pow(1.0 - p_t, self._gamma)
alpha_weight_factor = 1.0
if self._alpha is not None:
alpha_weight_factor = (y_true * self._alpha +
(1 - y_true) * (1 - self._alpha))
focal_cross_entropy_loss = (modulating_factor * alpha_weight_factor *
cross_entropy_loss)
return focal_cross_entropy_loss.mean()

但是当我运行这个时,我得到
  File "train.py", line 82, in <module>
loss = loss_fn(output, target)
File "/home/bubbles/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 538, in __call__
for hook in self._forward_pre_hooks.values():
File "/home/bubbles/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
type(self).__name__, name))
AttributeError: 'FocalLoss' object has no attribute '_forward_pre_hooks'

任何建议都会非常有帮助,在此先感谢。

最佳答案

你不应该继承 torch.nn.Module因为它是为具有可学习参数的模块(例如神经网络)而设计的。

只需创建普通的仿函数或函数就可以了。

顺便提一句。如果你继承它,你应该调用 super().__init__()在您的__init__() 中的某处.

关于deep-learning - 多标签分类Focal loss的实现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57635169/

40 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com