gpt4 book ai didi

python - 尝试理解 PyTorch 中的 cross_entropy 损失

转载 作者:行者123 更新时间:2023-11-30 09:42:23 24 4
gpt4 key购买 nike

这是一个非常新手的问题,但我正在尝试解决 Torch 中的 cross_entropy 损失问题,因此我创建了以下代码:

x = torch.FloatTensor([
[1.,0.,0.]
,[0.,1.,0.]
,[0.,0.,1.]
])

print(x.argmax(dim=1))

y = torch.LongTensor([0,1,2])
loss = torch.nn.functional.cross_entropy(x, y)

print(loss)

输出以下内容:

tensor([0, 1, 2])
tensor(0.5514)

我不明白的是,如果我的输入与预期输出匹配,为什么损失不为 0?

最佳答案

这是因为您给交叉熵函数的输入不是像您那样的概率,而是使用以下公式将 logits 转换为概率:

probas = np.exp(logits)/np.sum(np.exp(logits), axis=1)

所以这里 pytorch 将在您的情况下使用的概率矩阵是:

[0.5761168847658291,  0.21194155761708547,  0.21194155761708547]
[0.21194155761708547, 0.5761168847658291, 0.21194155761708547]
[0.21194155761708547, 0.21194155761708547, 0.5761168847658291]

关于python - 尝试理解 PyTorch 中的 cross_entropy 损失,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57161524/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com