gpt4 book ai didi

lstm - Pytorch LSTM : Target Dimension in Calculating Cross Entropy Loss

转载 作者:行者123 更新时间:2023-12-01 01:43:48 25 4
gpt4 key购买 nike

我一直在尝试在 Pytorch 中使用 LSTM(LSTM 后跟自定义模型中的线性层),但在计算损失时出现以下错误:
Assertion cur_target >= 0 && cur_target < n_classes' failed.
我定义了损失函数:

criterion = nn.CrossEntropyLoss()

然后用
loss += criterion(output, target)

我给目标的维度是 [sequence_length, number_of_classes],输出的维度是 [sequence_length, 1, number_of_classes]。

我所遵循的示例似乎在做同样的事情,但在 Pytorch docs on cross entropy loss. 上却有所不同。

文档说目标应该是维度 (N),其中每个值是 0 ≤ targets[i] ≤ C−1 并且 C 是类的数量。我将目标更改为该形式,但现在我收到一条错误消息(序列长度为 75,并且有 55 个类):
Expected target size (75, 55), got torch.Size([75])

我已经尝试查看这两个错误的解决方案,但仍然无法正常工作。我对目标的正确尺寸以及第一个错误背后的实际含义感到困惑(不同的搜索对错误给出了非常不同的含义,没有任何修复工作)。

谢谢

最佳答案

您可以使用 squeeze()在您的 output张量,这将返回一个张量,其中删除了大小为 1 的所有维度。

此简短代码使用您在问题中提到的形状:

sequence_length   = 75
number_of_classes = 55
# creates random tensor of your output shape
output = torch.rand(sequence_length, 1, number_of_classes)
# creates tensor with random targets
target = torch.randint(55, (75,)).long()

# define loss function and calculate loss
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss)

导致您描述的错误:

ValueError: Expected target size (75, 55), got torch.Size([75])

所以使用 squeeze()在您的 output张量通过使其正确形状来解决您的问题。

修正形状的示例:

sequence_length   = 75
number_of_classes = 55
# creates random tensor of your output shape
output = torch.rand(sequence_length, 1, number_of_classes)
# creates tensor with random targets
target = torch.randint(55, (75,)).long()

# define loss function and calculate loss
criterion = nn.CrossEntropyLoss()

# apply squeeze() on output tensor to change shape form [75, 1, 55] to [75, 55]
loss = criterion(output.squeeze(), target)
print(loss)

输出:
tensor(4.0442)

使用 squeeze()[75, 1, 55] 改变你的张量形状至 [75, 55]所以输出和目标形状匹配!

你也可以使用其他方法来 reshape 你的张量,重要的是你有 [sequence_length, number_of_classes] 的形状。而不是 [sequence_length, 1, number_of_classes] .

您的目标应该是 LongTensor分别 torch.long 类型的张量包含类。这里的形状是 [sequence_length] .

编辑:
传递给交叉熵函数时上面示例中的形状:

输出: torch.Size([75, 55])目标: torch.Size([75])
下面是一个更一般的示例,CE 的输出和目标应该是什么样的。在这种情况下,我们假设我们有 5 个不同的目标类,长度为 1、2 和 3 的序列有三个示例:

# init CE Loss function
criterion = nn.CrossEntropyLoss()

# sequence of length 1
output = torch.rand(1, 5)
# in this case the 1th class is our target, index of 1th class is 0
target = torch.LongTensor([0])
loss = criterion(output, target)
print('Sequence of length 1:')
print('Output:', output, 'shape:', output.shape)
print('Target:', target, 'shape:', target.shape)
print('Loss:', loss)

# sequence of length 2
output = torch.rand(2, 5)
# targets are here 1th class for the first element and 2th class for the second element
target = torch.LongTensor([0, 1])
loss = criterion(output, target)
print('\nSequence of length 2:')
print('Output:', output, 'shape:', output.shape)
print('Target:', target, 'shape:', target.shape)
print('Loss:', loss)

# sequence of length 3
output = torch.rand(3, 5)
# targets here 1th class, 2th class and 2th class again for the last element of the sequence
target = torch.LongTensor([0, 1, 1])
loss = criterion(output, target)
print('\nSequence of length 3:')
print('Output:', output, 'shape:', output.shape)
print('Target:', target, 'shape:', target.shape)
print('Loss:', loss)

输出:

Sequence of length 1:
Output: tensor([[ 0.1956, 0.0395, 0.6564, 0.4000, 0.2875]]) shape: torch.Size([1, 5])
Target: tensor([ 0]) shape: torch.Size([1])
Loss: tensor(1.7516)

Sequence of length 2:
Output: tensor([[ 0.9905, 0.2267, 0.7583, 0.4865, 0.3220],
[ 0.8073, 0.1803, 0.5290, 0.3179, 0.2746]]) shape: torch.Size([2, 5])
Target: tensor([ 0, 1]) shape: torch.Size([2])
Loss: tensor(1.5469)

Sequence of length 3:
Output: tensor([[ 0.8497, 0.2728, 0.3329, 0.2278, 0.1459],
[ 0.4899, 0.2487, 0.4730, 0.9970, 0.1350],
[ 0.0869, 0.9306, 0.1526, 0.2206, 0.6328]]) shape: torch.Size([3, 5])
Target: tensor([ 0, 1, 1]) shape: torch.Size([3])
Loss: tensor(1.3918)

我希望这有帮助!

关于lstm - Pytorch LSTM : Target Dimension in Calculating Cross Entropy Loss,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53455780/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com