gpt4 book ai didi

python - 什么是PyTorch中的运行损失以及如何计算

转载 作者:行者123 更新时间:2023-12-03 13:44:00 25 4
gpt4 key购买 nike

我查看了PyTorch文档中的this教程,以了解迁移学习。我只有一条线听不懂。

使用loss = criterion(outputs, labels)计算损失后,使用running_loss += loss.item() * inputs.size(0)计算运行损失,最后使用running_loss / dataset_sizes[phase]计算纪元损失。
loss.item()是否不应该用于整个微型批处理(如果我错了,请纠正我)。也就是说,如果batch_size为4,则loss.item()会给整个4张图片带来损失。如果是这样,为什么在计算loss.item()时将inputs.size(0)running_loss相乘?在这种情况下,这一步骤难道不是一个额外的乘法吗?

任何帮助,将不胜感激。谢谢!

最佳答案

这是因为 CrossEntropy 或其他损失函数给定的损失除以元素数,即默认情况下减少参数为mean

torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')



因此, loss.item()包含整个微型批处理的损失,但除以批处理大小。这就是为什么 loss.item()乘以批处理大小(由 inputs.size(0)给出),同时计算 running_loss的原因。

关于python - 什么是PyTorch中的运行损失以及如何计算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61092523/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com