gpt4 book ai didi

python - 在 PyTorch 中,nll_loss 的输入是什么?

转载 作者:行者123 更新时间:2023-12-01 07:30:22 28 4
gpt4 key购买 nike

我正在看这里的教程:https://pytorch.org/tutorials/beginner/fgsm_tutorial.html

import torch.nn.functional as F
loss = F.nll_loss(output, target)

上面两行代码中,“目标”到底是什么?他们加载目标数据集,但从不讨论它到底是什么。文档也很难理解。

最佳答案

通过运行以下代码来检查自己:

test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
transforms.ToTensor(),
])),
batch_size=1, shuffle=True)
for data, target in test_loader:
print(data, target)
break

这里,data 基本上是灰度 MNIST 图像,target09 之间的标签。

因此,在 loss = F.nll_loss(output, target) 中,output 是模型预测(模型在给出图像/数据时预测的内容),并且 target 是给定图像的实际标签。

此外,在上面的示例中,检查以下行:

output = model(data) # shape [1, 10]
init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

# If the initial prediction is wrong, don't bother attacking, just move on
if init_pred.item() != target.item():
continue

# Calculate the loss
loss = F.nll_loss(output, target)

在上面的代码中,只有那些 output-target 对被传递到 F.nll_loss 损失函数,其中模型预测正确。如果无法正确预测标签,则跳过此后的所有操作(包括损失计算)并继续 test_loader 中的下一个示例。

关于python - 在 PyTorch 中,nll_loss 的输入是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57229669/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com