gpt4 book ai didi

python - 参数组合无效 - eq()

转载 作者:行者123 更新时间:2023-11-30 09:43:54 25 4
gpt4 key购买 nike

我正在使用共享代码 here测试 CNN 图像分类器。当我调用测试函数时,我在 line 155 上收到此错误:

test_acc += torch.sum(prediction == labels.data)
TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of:
* (Tensor other)
didn't match because some of the arguments have invalid types: ([31;1mnumpy.ndarray[0m)
* (Number other)
didn't match because some of the arguments have invalid types: ([31;1mnumpy.ndarray[0m)

test 函数的片段:

def test():
model.eval()
test_acc = 0.0
for i, (images, labels) in enumerate(test_loader):

if cuda_avail:
images = Variable(images.cuda())
labels = Variable(labels.cuda())

#Predict classes using images from the test set
outputs = model(images)
_,prediction = torch.max(outputs.data, 1)
prediction = prediction.cpu().numpy()
test_acc += torch.sum(prediction == labels.data) #line 155



#Compute the average acc and loss over all 10000 test images
test_acc = test_acc / 10000

return test_acc

经过快速搜索,我发现该错误可能与预测标签之间的比较有关,如SO question所示.

知道如何解决这个问题吗?

最佳答案

为什么这里有 .numpy() prediction = Prediction.cpu().numpy()?这样您就可以将 PyTorch 张量转换为 NumPy 数组,使其与 labels.data 进行比较时的类型不兼容。

删除 .numpy() 部分应该可以解决问题。

关于python - 参数组合无效 - eq(),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55147511/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com