python - How to save and load nn.Parameter() in PyTorch so that it can continue to be optimized?

I know how to store and load an nn.Module, but I could not find how to checkpoint an nn.Parameter. I tried the version below, but after restoring the checkpoint, the optimizer no longer changes the nn.Parameter values.

from torch import nn as nn
import torch
from torch.optim import Adam

alpha = torch.ones(10)
lr = 0.001
alpha = nn.Parameter(alpha)
print(alpha)
alpha_optimizer = Adam([alpha], lr=lr)

# train for a few steps
for i in range(10):
    alpha_loss = - alpha.mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
print(alpha)

# save a checkpoint
path = "./test.pt"
state = dict(alpha_optimizer=alpha_optimizer.state_dict(),
             alpha=alpha)
torch.save(state, path)

# restore and try to continue training
checkpoint = torch.load(path)
alpha = checkpoint["alpha"]
alpha_optimizer.load_state_dict(checkpoint["alpha_optimizer"])
for i in range(10):
    alpha_loss = - alpha.mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
print(alpha)
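
For reference, the nn.Module checkpointing pattern the question alludes to usually looks like the sketch below (standard torch.save / load_state_dict usage; the Linear model and file name are only illustrative stand-ins).

# Minimal sketch of the usual nn.Module checkpointing pattern (illustrative model):
model = nn.Linear(10, 1)
model_optimizer = Adam(model.parameters(), lr=lr)
torch.save({"model": model.state_dict(),
            "optimizer": model_optimizer.state_dict()}, "./module_ckpt.pt")

ckpt = torch.load("./module_ckpt.pt")
model.load_state_dict(ckpt["model"])                # copies values into the existing tensors in-place
model_optimizer.load_state_dict(ckpt["optimizer"])  # optimizer still references the same parameters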

Best Answer

The problem is that the optimizer still references the old alpha (compare id(alpha) with id(alpha_optimizer.param_groups[0]["params"][0]) just before the last for loop), whereas alpha = checkpoint["alpha"] binds the name alpha to a new object loaded from the checkpoint.
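
A quick way to see the mismatch, reusing the variables from the snippet above:

# After `alpha = checkpoint["alpha"]`, these two ids differ: the optimizer
# keeps updating the old tensor, not the one just loaded from the checkpoint.
print(id(alpha))
print(id(alpha_optimizer.param_groups[0]["params"][0]))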

You need to update the optimizer's parameters before loading its state:

# ....
torch.save(state, path)
checkpoint = torch.load(path)

# here's where the reference of alpha changes, and the source of the problem
alpha = checkpoint["alpha"]

# reset optim
alpha_optimizer = Adam([alpha], lr=lr)
alpha_optimizer.load_state_dict(checkpoint["alpha_optimizer"])

for i in range(10):
    alpha_loss = - alpha.mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
print(alpha)
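
As an alternative (my own sketch, not part of the answer above): instead of re-creating the optimizer, you could copy the checkpointed values into the existing alpha in-place, so the optimizer's reference stays valid and no new Adam instance is needed.

# Alternative sketch (assumption, not from the accepted answer): keep the
# original `alpha` object and copy the saved values into it in-place, so
# `alpha_optimizer` still points at the parameter it will update.
checkpoint = torch.load(path)
with torch.no_grad():
    alpha.copy_(checkpoint["alpha"])   # in-place copy; the object identity is unchanged
alpha_optimizer.load_state_dict(checkpoint["alpha_optimizer"])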

Regarding python - How to save and load nn.Parameter() in PyTorch so that it can continue to be optimized?, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/67505670/
