gpt4 book ai didi

python - 如何解决这个问题(Pytorch RuntimeError : 1D target tensor expected, multi-target not supported)

转载 作者:行者123 更新时间:2023-12-04 07:57:21 25 4
gpt4 key购买 nike

我是 pytorch 和深度学习的新手
我的数据集 53502 x 58,
我的代码有问题

model = nn.Sequential(
nn.Linear(58,64),
nn.ReLU(),
nn.Linear(64,32),
nn.ReLU(),
nn.Linear(32,16),
nn.ReLU(),
nn.Linear(16,2),
nn.LogSoftmax(1)
)

criterion = nn.NLLLoss()
optimizer = optim.AdamW(model.parameters(), lr = 0.0001)
epoch = 500
train_cost, test_cost = [], []
for i in range(epoch):
model.train()
cost = 0
for feature, target in trainloader:
output = model(feature) #feedforward
loss = criterion(output, target) #loss
loss.backward() #backprop

optimizer.step() #update weight
optimizer.zero_grad() #zero grad

cost += loss.item() * feature.shape[0]
train_cost.append(cost / len(train_set))

with torch.no_grad():
model.eval()
cost = 0
for feature, target in testloader:
output = model(feature) #feedforward
loss = criterion(output, target) #loss

cost += loss.item() * feature.shape
test_cost.append(cost / len(test_set))

print(f'\repoch {i+1}/{epoch} | train_cost: {train_cost[-1]} | test_cost : {test_cost[-1]}', end = "")
然后我遇到了这样的问题
   2262                          .format(input.size(0), target.size(0)))
2263 if dim == 2:
-> 2264 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
2265 elif dim == 4:
2266 ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: 1D target tensor expected, multi-target not supported
怎么了?
如何解决这个问题呢?
为什么会这样?
非常感谢您提前!

最佳答案

使用时 NLLLoss目标张量必须包含标签的索引表示而不是 one-hot。例如:
我想这就是你的目标的样子:

target = [0, 0, 1, 0]
只需将其转换为仅作为 1 索引的数字即可:
[0, 0, 1, 0] -> [2]
[1, 0, 0, 0] -> [0]
[0, 0, 0, 1] -> [3]
然后将其转换为长张量,即:
target = [2]
target = torch.Tensor(target).type(torch.LongTensor)
可能会令人困惑,您的输出是一个具有类长度的张量,而您的目标是一个数字,但事实就是如此。
您可以自己查看 here .

关于python - 如何解决这个问题(Pytorch RuntimeError : 1D target tensor expected, multi-target not supported),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66635987/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com