gpt4 book ai didi

python - 尽管形状相同 : if not (target. size() == input.size()) 但出现类型错误: 'int' 对象不可调用

转载 作者:行者123 更新时间:2023-12-01 00:36:23 28 4
gpt4 key购买 nike

这是我收到的错误消息。在第一行中,我输出了预测目标的形状。根据我的理解,错误是由于这些形状不一样而产生的,但在这里它们显然是一样的。

torch.Size([6890, 3]) torch.Size([6890, 3])

Traceback (most recent call last):
File "train.py", line 251, in <module>
main()
File "train.py", line 230, in main
train(net, training_dataset, targets, device, criterion, optimizer, epoch, args.epochs)
File "train.py", line 101, in train
loss = criterion(predicted, target.detach().cpu().numpy())
File "/home/hb119056/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/hb119056/.local/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 443, in forward
return F.mse_loss(input, target, reduction=self.reduction)
File "/home/hb119056/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 2244, in mse_loss
if not (target.size() == input.size()):
TypeError: 'int' object is not callable

我希望提供所有相关的上下文信息,如果没有,请告诉我。感谢您的任何建议!

编辑:这是发生此错误的代码部分:

        target = torch.from_numpy(np.load(file_dir + '/points/points{:03}.npy'.format(i))).to(device)
rv = torch.zeros(12 * outputs.shape[0])

for j in [x for x in range(10) if x != i]:
source = torch.from_numpy(np.load(file_dir + '/points/points{:03}.npy'.format(j))).to(device)
rv = factor.ransac(source, target, prob, n_iter, tol, device) # some self-written RANSAC-like method

predicted = factor.predict(source, rv, outputs)
print(target.shape, predicted.shape)
loss = criterion(predicted, target.detach().cpu().numpy()) ## error occurs here

标准nn.MSELoss()

最佳答案

有点晚了,但也许对其他人有帮助。刚刚为自己解决了同样的问题。

正如 Alpha 在他的回答中所说,我们不能为 numpy 数组调用 .size() 。但我们可以为张量调用 .size() 。因此,我们需要将我们的目标设为张量。你可以这样做:

target = torch.from_numpy(target)

我正在使用 GPU,因此我还需要将目标发送到 GPU。你可以这样做:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
target = target.to(device)

然后损失函数必须完美运行。

关于python - 尽管形状相同 : if not (target. size() == input.size()) 但出现类型错误: 'int' 对象不可调用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57724134/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com