gpt4 book ai didi

python - 如何在pytorch模型中加载检查点文件?

转载 作者:行者123 更新时间:2023-12-02 08:59:30 27 4
gpt4 key购买 nike

在我的 pytorch 模型中,我正在像这样初始化我的模型和优化器。

model = MyModelClass(config, shape, x_tr_mean, x_tr,std)
optimizer = optim.SGD(model.parameters(), lr=config.learning_rate)

这是我的检查点文件的路径。

checkpoint_file = os.path.join(config.save_dir, "checkpoint.pth")

为了加载此检查点文件,我检查并查看检查点文件是否存在,然后加载它以及模型和优化器。

if os.path.exists(checkpoint_file):
if config.resume:
torch.load(checkpoint_file)
model.load_state_dict(torch.load(checkpoint_file))
optimizer.load_state_dict(torch.load(checkpoint_file))

此外,这是我保存模型和优化器的方法。

 torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'iter_idx': iter_idx, 'best_va_acc': best_va_acc}, checkpoint_file)

出于某种原因,每当我运行此代码时,我都会收到一个奇怪的错误。

model.load_state_dict(torch.load(checkpoint_file))
File "/home/Josh/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MyModelClass:
Missing key(s) in state_dict: "mean", "std", "attribute.weight", "attribute.bias".
Unexpected key(s) in state_dict: "model", "optimizer", "iter_idx", "best_va_acc"

有谁知道为什么我会收到此错误?

最佳答案

您将模型参数保存在字典中。您应该使用之前保存时使用的 key 来加载模型检查点和 state_dict,如下所示:

if os.path.exists(checkpoint_file):
if config.resume:
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])

可以查看官方tutorial在 PyTorch 网站上了解更多信息。

关于python - 如何在pytorch模型中加载检查点文件?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54677683/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com