python - RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn


I am using a UNet model for semantic segmentation. I have a custom dataset of images and their masks, both in .png format. I have looked through online forums and tried a few things, but nothing worked. Any suggestions on how to fix the error or improve the code would be helpful.

# Imports assumed by the snippets below
import copy
import os

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model.eval()
with torch.no_grad():
    for xb, yb in val_dl:
        yb_pred = model(xb.to(device))
        # yb_pred = yb_pred["out"].cpu()
        print(yb_pred.shape)
        yb_pred = torch.argmax(yb_pred, axis=1)
        break

print(yb_pred.shape)


criteron = nn.CrossEntropyLoss(reduction='sum')
opt = optim.Adam(model.parameters(), lr=3e-4)

def loss_batch(loss_func, output, target, opt=None):
    loss = loss_func(output, target)

    if opt is not None:
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item(), None

lr_scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=20, verbose=1)

def get_lr(opt):
    for param_group in opt.param_groups:
        return param_group['lr']

current_lr = get_lr(opt)
print('current_lr = {}'.format(current_lr))


def loss_epoch(model, loss_func, dataset_dl, sanity_check=False, opt=None):
    running_loss = 0.0
    len_data = len(dataset_dl.dataset)

    for xb, yb in dataset_dl:
        xb = xb.to(device)
        yb = yb.to(device)

        # xb = torch.tensor(xbh, requires_grad=True)

        output = model(xb)

        loss_b, metric_b = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        if sanity_check is True:
            break
    loss = running_loss / float(len_data)
    return loss, None

def train_val(model, params):
    num_epochs = params["num_epochs"]
    loss_func = params["loss_func"]
    opt = params["optimizer"]
    train_dl = params["train_dl"]
    val_dl = params["val_dl"]
    sanity_check = params["sanity_check"]
    lr_scheduler = params["lr_scheduler"]
    path2weights = params["path2weights"]

    loss_history = {"train": [],
                    "val": []}
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        current_lr = get_lr(opt)
        print('Epoch {}/{}, current_lr = {}'.format(epoch, num_epochs - 1, current_lr))

        with torch.enable_grad():
            model.train()
            train_loss, _ = loss_epoch(model, loss_func, train_dl, sanity_check, opt)
        loss_history["train"].append(train_loss)
        model.eval()

        with torch.no_grad():
            val_loss, _ = loss_epoch(model, loss_func, val_dl, sanity_check, opt)
        loss_history["val"].append(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), path2weights)
            print("copied best model weights!!")

        lr_scheduler.step(val_loss)
        if current_lr != get_lr(opt):
            print("Loading best model weights!!")
            model.load_state_dict(best_model_wts)
        print("train Loss: %.6f" % (train_loss))
        print("val_loss: %.6f" % (val_loss))
        print("-" * 20)

    model.load_state_dict(best_model_wts)
    return model, loss_history, None  # metric_history is never defined, so return None in its place


path2models = "./models/"
if not os.path.exists(path2models):
    os.mkdir(path2models)

param_train = {
    "num_epochs": 10,
    "loss_func": criteron,
    "optimizer": opt,
    "train_dl": train_dl,
    "val_dl": val_dl,
    "sanity_check": False,
    "lr_scheduler": lr_scheduler,
    "path2weights": path2models + "weights.pt",
}

model, loss_hist, _ = train_val(model, param_train)

The error message looks like this:

File "", line 10, in <module>
    model, loss_hist, _ = train_val(model, param_train)

File "", line 27, in train_val
    val_loss, _ = loss_epoch(model, loss_func, val_dl, sanity_check, opt)

File "", line 13, in loss_epoch
    loss_b, metric_b = loss_batch(loss_func, output, yb, opt)

File "", line 6, in loss_batch
    loss.backward()

File "C:\Users\W540\anaconda3\lib\site-packages\torch\tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)

File "C:\Users\W540\anaconda3\lib\site-packages\torch\autograd\__init__.py", line 100, in backward
    allow_unreachable=True)  # allow_unreachable flag

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I am not sure which variable I should set requires_grad = True on, or where I should enable grad...

Best Answer

You can try this before calling loss.backward():

loss = Variable(loss, requires_grad=True)

Or, since Variable has been deprecated in PyTorch (it still exists, but is no longer needed), you can do the same thing more simply with:

loss.requires_grad = True
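
Both lines work because, inside torch.no_grad(), the loss is produced without a grad_fn and is therefore a leaf tensor whose requires_grad flag can be set directly; backward() then succeeds, although it computes no gradients for the model parameters. Note also what the traceback shows: the failure comes from the validation call in train_val, where opt is passed to loss_epoch, so loss_batch ends up calling loss.backward() inside the torch.no_grad() block. A minimal sketch of the more direct fix, leaving the rest of train_val unchanged:

model.eval()

with torch.no_grad():
    # Omit opt here: with opt=None, loss_batch skips loss.backward() entirely
    val_loss, _ = loss_epoch(model, loss_func, val_dl, sanity_check)
loss_history["val"].append(val_loss)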

Regarding python - RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/62699306/
