gpt4 book ai didi

size - ValueError : Target size (torch. Size([16])) 必须与输入大小相同 (torch.Size([16, 1]))

转载 作者:行者123 更新时间:2023-12-03 17:52:37 29 4
gpt4 key购买 nike

ValueError                                Traceback (most recent call last)
<ipython-input-30-33821ccddf5f> in <module>
23 output = model(data)
24 # calculate the batch loss
---> 25 loss = criterion(output, target)
26 # backward pass: compute gradient of the loss with respect to model parameters
27 loss.backward()

C:\Users\mnauf\Anaconda3\envs\federated_learning\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
487 result = self._slow_forward(*input, **kwargs)
488 else:
--> 489 result = self.forward(*input, **kwargs)
490 for hook in self._forward_hooks.values():
491 hook_result = hook(self, input, result)

C:\Users\mnauf\Anaconda3\envs\federated_learning\lib\site-packages\torch\nn\modules\loss.py in forward(self, input, target)
593 self.weight,
594 pos_weight=self.pos_weight,
--> 595 reduction=self.reduction)
596
597

C:\Users\mnauf\Anaconda3\envs\federated_learning\lib\site-packages\torch\nn\functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
2073
2074 if not (target.size() == input.size()):
-> 2075 raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
2076
2077 return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

ValueError: Target size (torch.Size([16])) must be the same as input size (torch.Size([16, 1]))

我正在训练一个 CNN。处理马匹与人类数据集。 This is my code .我正在使用 criterion = nn.BCEWithLogitsLoss()optimizer = optim.RMSprop(model.parameters(), lr=0.01 )。我的最后一层是 self.fc2 = nn.Linear(512, 1) .最后一个神经元,会为马输出 1,为人输出 0,对吗?还是应该选择 2 个神经元作为输出?
16是批量大小。由于错误说 ValueError: Target size (torch.Size([16])) must be the same as input size (torch.Size([16, 1])) .我不明白,我需要在哪里进行更改以纠正错误。

最佳答案

target = target.unsqueeze(1) ,在将目标传递给标准之前,将目标张量大小从 [16] 更改为至 [16,1] .这样做解决了问题。此外,我还需要做 target = target.float()在将它传递给标准之前,因为我们的输出是 float 的。此外,代码中还有另一个错误。我在最后一层使用了 sigmoid 激活函数,但我不应该使用,因为我使用的标准已经内置了 sigmoid。

关于size - ValueError : Target size (torch. Size([16])) 必须与输入大小相同 (torch.Size([16, 1])),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57798033/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com