gpt4 book ai didi

nn.CrossEntropyLoss() 的 Pytorch 输入

转载 作者:行者123 更新时间:2023-12-04 15:29:22 24 4
gpt4 key购买 nike

我正在尝试在 PyTorch 中对一个简单的 0,1 标记数据集执行逻辑回归。标准或损失定义为:criterion = nn.CrossEntropyLoss() .型号为:model = LogisticRegression(1,2)
我有一个数据点,它是一对:dat = (-3.5, 0) ,第一个元素是数据点,第二个元素是相应的标签。
然后我将输入的第一个元素转换为张量:tensor_input = torch.Tensor([dat[0]]) .
然后我将模型应用于 tensor_input:outputs = model(tensor_input) .
然后我将标签转换为张量:tensor_label = torch.Tensor([dat[1]]) .
现在,当我尝试这样做时,事情会中断:loss = criterion(outputs, tensor_label) .它给出错误:RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

import torch
import torch.nn as nn

class LogisticRegression(nn.Module):
def __init__(self, input_size, num_classes):
super(LogisticRegression, self).__init__()
self.linear = nn.Linear(input_size, num_classes)

def forward(self, x):
out = self.linear(x)
return out

model = LogisticRegression(1,2)
criterion = nn.CrossEntropyLoss()
dat = (-3.5,0)
tensor_input = torch.Tensor([dat[0]])
outputs = binary_model(tensor_input)
tensor_label = torch.Tensor([dat[1]])
loss = criterion(outputs, tensor_label)

我一辈子都搞不清楚。

最佳答案

在大多数情况下,PyTorch 文档在解释不同功能方面做得非常出色;它们通常包括预期的输入维度,以及一些简单的例子。
您可以找到 nn.CrossEntropyLoss() 的说明here .

为了完成您的具体示例,让我们首先查看预期的输入维度:

Input: (N,C) where C = number of classes. [...]



除此之外,N 通常是指批量大小(样本数)。将此与您目前拥有的进行比较:
outputs.shape
>>> torch.Size([2])

IE。目前我们只有 (2,) 的输入维度,而不是 (1,2) ,正如 PyTorch 所预期的那样。我们可以通过简单地使用 .unsqueeze() 为我们当前的张量添加一个“假”维度来缓解这个问题。像这样:
outputs = binary_model(tensor_input).unsqueeze(dim=0)
outputs.shape
>>> torch.Size([1,2])

现在我们知道了,让我们看看目标的预期输入:

Target: (N) [...]



所以我们已经得到了正确的形状。但是,如果我们尝试这样做,我们仍然会遇到错误:
RuntimeError: Expected object of scalar type Long but got scalar type Float 
for argument #2 'target'.

同样,错误消息相当有表现力。这里的问题是 PyTorch 张量(默认情况下)被解释为 torch.FloatTensors ,但输入应该是整数(或 Long)。我们可以通过在张量创建期间指定确切类型来简单地做到这一点:
tensor_label = torch.LongTensor([dat[1]])

我在 Linux 下使用 PyTorch 1.0 仅供引用。

关于nn.CrossEntropyLoss() 的 Pytorch 输入,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53936136/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com