gpt4 book ai didi

python - PyTorch:运行时错误:输入、输出和索引必须在当前设备上

转载 作者:行者123 更新时间:2023-12-03 15:47:24 24 4
gpt4 key购买 nike

我正在火炬上运行 BERT 模型。这是一个多类情感分类任务,大约有 30,000 行。我已经把所有东西都放在了 cuda 上,但不知道为什么会出现以下运行时错误。这是我的代码:

for epoch in tqdm(range(1, epochs+1)):

model.train()

loss_train_total = 0

progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
for batch in progress_bar:

model.zero_grad()

batch = tuple(b.to(device) for b in batch)

inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'labels': batch[2],
}

outputs = model(**inputs)

loss = outputs[0]
loss_train_total += loss.item()
loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

optimizer.step()
scheduler.step()

progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(batch))})


torch.save(model.state_dict(), f'finetuned_BERT_epoch_{epoch}.model')

tqdm.write(f'\nEpoch {epoch}')

loss_train_avg = loss_train_total/len(dataloader_train)
tqdm.write(f'Training loss: {loss_train_avg}')

val_loss, predictions, true_vals = evaluate(dataloader_validation)
val_f1 = f1_score_func(predictions, true_vals)
tqdm.write(f'Validation loss: {val_loss}')
tqdm.write(f'F1 Score (Weighted): {val_f1}')

---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-67-9306225bb55a> in <module>()
17 }
18
---> 19 outputs = model(**inputs)
20
21 loss = outputs[0]

8 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
1850 # remove once script supports set_grad_enabled
1851 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1852 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
1853
1854

RuntimeError: Input, output and indices must be on the current device
任何建议将不胜感激。谢谢!

最佳答案

您应该将您的模型放在设备上,这可能是 cuda:

device = "cuda:0"
model = model.to(device)

然后确保模型的输入(输入)也在同一设备上:
input = input.to(device)
它应该工作!

关于python - PyTorch:运行时错误:输入、输出和索引必须在当前设备上,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64914598/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com