gpt4 book ai didi

python - 使用 PyTorch 的交叉熵损失函数是否需要 One-Hot 编码?

转载 作者:行者123 更新时间:2023-12-03 15:54:35 24 4
gpt4 key购买 nike

例如,如果我想解决 MNIST 分类问题,我们有 10 个输出类。对于 PyTorch,我想使用 torch.nn.CrossEntropyLoss功能。我是否必须格式化目标以便它们是单热编码的,还是我可以简单地使用数据集附带的类标签?

最佳答案

nn.CrossEntropyLoss需要整数标签。它在内部所做的是,它根本不会对类标签进行一次性编码,而是使用标签索引到输出概率向量中,以计算您决定使用此类作为最终标签时的损失。这个小而重要的细节使计算损失更容易,并且是执行单热编码的等效操作,测量每个输出神经元的输出损失,因为输出层中的每个值都为零,但在目标类中索引的神经元除外.因此,如果您已经提供了标签,则无需对数据进行一次性编码。
文档对此有更多见解:https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html .在文档中你会看到 targets它作为输入参数的一部分。这些是您的标签,它们被描述为:
Targets
这清楚地显示了应该如何塑造输入以及预期的内容。如果您实际上想对数据进行单热编码,则需要使用 torch.nn.functional.one_hot .为了最好地复制交叉熵损失在幕后所做的事情,您还需要 nn.functional.log_softmax作为最终输出,您必须另外编写自己的损失层,因为没有一个 PyTorch 层使用对数 softmax 输入和单热编码目标。然而,nn.CrossEntropyLoss将这两种操作结合​​在一起,如果您的输出只是类标签,则是首选,因此无需进行转换。

关于python - 使用 PyTorch 的交叉熵损失函数是否需要 One-Hot 编码?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62456558/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com