gpt4 book ai didi

pytorch - 正确使用交叉熵作为元素序列的损失函数

转载 作者:行者123 更新时间:2023-12-05 05:55:46 29 4
gpt4 key购买 nike

我有一个序列标记任务。

作为输入,我有一个形状为 [batch_size, sequence_length] 的元素序列,并且该序列的每个元素都应该分配有某个类。

作为神经网络训练过程中的损失函数,我使用 Cross-entropy .

如何正确使用?我的变量 target_predictions 的形状为 [batch_size, sequence_length, number_of_classes]target 的形状为 [batch_size, sequence_length]

文档说:

enter image description here

我知道如果我使用 CrossEntropyLoss(target_predictions.permute(0, 2, 1), target),一切都会正常进行。但我担心 torch 将我的 sequence_length 解释为屏幕截图中的 d_1 变量,并认为这是多维损失,但事实并非如此。

我应该如何正确操作?

最佳答案

使用 CE Loss 会给你损失而不是标签。默认情况下,将采用平均值,这可能是您所追求的,并且带有置换的代码段会很好(使用此损失,您可以通过向后训练您的神经网络)。

要获得预测的类,只需在适当的维度上取 argmax,在没有排列的情况下,它将是:

labels = torch.argmax(target_predictions, dim=-1)

这将为您提供包含类的 (batch, sequence_length) 输出。

关于pytorch - 正确使用交叉熵作为元素序列的损失函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69367671/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com