gpt4 book ai didi

python - ValueError : not enough values to unpack (expected 3, 在 Pytorch 中得到 2)

转载 作者:行者123 更新时间:2023-12-05 06:00:11 24 4
gpt4 key购买 nike

这是我定义的验证函数
当我加载模型并使用此代码开始预测时,我收到了使用 PyTorch 的错误。在此之后,我迭代了纪元循环和批处理循环,但遇到了这个错误。

def validate_epoch(net, val_loader,loss_type='CE'):
net.train(False)
running_loss = 0.0
sm = nn.Softmax(dim=1)

truth = []
preds = []
bar = tqdm(total=len(val_loader), desc='Processing', ncols=90)
names_all = []
n_batches = len(val_loader)
for i, (batch, targets, names) in enumerate(val_loader):
if loss_type == 'CE':
labels = Variable(targets.float())
inputs = Variable(batch)
elif loss_type == 'MSE':
labels = Variable(targets.float())
inputs = Variable(batch)

outputs = net(inputs)
labels = labels.long()
loss = criterion(outputs, labels)
if loss_type =='CE':
probs = sm(outputs).data.cpu().numpy()
elif loss_type =='MSE':
probs = outputs
probs[probs < 0] = 0
probs[probs > 4] = 4
probs = probs.view(1,-1).squeeze(0).round().data.cpu().numpy()
preds.append(probs)
truth.append(targets.cpu().numpy())
names_all.extend(names)
running_loss += loss.item()
bar.update(1)
gc.collect()
gc.collect()
bar.close()
if loss_type =='CE':
preds = np.vstack(preds)
else:
preds = np.hstack(preds)
truth = np.hstack(truth)
return running_loss / n_batches, preds, truth, names_all

这是我调用验证函数的主要函数,在加载模型并开始在测试加载器上进行预测时获取错误

criterion = nn.CrossEntropyLoss()

model.eval()

test_losses = []
test_mse = []
test_kappa = []
test_acc = []


test_started = time.time()

test_loss, probs, truth, file_names = validate_epoch(model, test_iterator)

正如您在回溯错误中看到的那样,它给出了一些终端显示错误:

ValueError                                Traceback (most recent call last)
<ipython-input-27-d2b4a1ca3852> in <module>
12 test_started = time.time()
13
---> 14 test_loss, probs, truth, file_names = validate_epoch(model, test_iterator)
15 preds = probs.argmax(1)
16

<ipython-input-25-34e29e0ff6ed> in validate_epoch(net, val_loader, loss_type)
9 names_all = []
10 n_batches = len(val_loader)
---> 11 for i, (batch, targets, names) in enumerate(val_loader):
12 if loss_type == 'CE':
13 labels = Variable(targets.float())

ValueError: not enough values to unpack (expected 3, got 2)

最佳答案

来自 torchvision.datasets.ImageFolder documentation :

“返回:(sample, target),其中 target 是目标类的 class_index。”

因此,非常简单,您当前使用的数据集对象返回一个包含 2 个项目的元组。如果您尝试将此元组存储在 3 个变量中,则会出现错误。正确的行是:

for i, (batch, targets) in enumerate(val_loader):

如果您真的需要名称(我假设这是每个图像的文件路径),您可以定义一个新的数据集对象,该对象继承自 ImageFolder 数据集并重载 __getitem__ 函数也返回此信息。

关于python - ValueError : not enough values to unpack (expected 3, 在 Pytorch 中得到 2),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67815772/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com