gpt4 book ai didi

python - 在 pytorch 中训练期间 best_state 随模型变化

转载 作者:行者123 更新时间:2023-12-01 07:44:31 24 4
gpt4 key购买 nike

我想保存最好的模型,然后在测试期间加载它。所以我使用了以下方法:

def train():  
#training steps …
if acc > best_acc:
best_state = model.state_dict()
best_acc = acc
return best_state

然后,在主函数中我使用:

model.load_state_dict(best_state)  

恢复模型。

但是,我发现best_state始终与训练时的最后状态相同,而不是最佳状态。有谁知道原因以及如何避免它?

顺便说一句,我知道我可以使用torch.save(the_model.state_dict(), PATH),然后通过以下方式加载模型the_model.load_state_dict(torch.load(PATH))。但是,我不想将参数保存到文件中,因为训练和测试函数位于一个文件中。

最佳答案

model.state_dict()OrderedDict

from collections import OrderedDict

您可以使用:

from copy import deepcopy

解决问题

相反:

best_state = model.state_dict() 

您应该使用:

best_state = copy.deepcopy(model.state_dict())

深(而不是浅)复制使可变的 OrderedDict 实例不会改变 best_state

你可以查看我的other answer在 PyTorch 中保存状态字典。

关于python - 在 pytorch 中训练期间 best_state 随模型变化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56526698/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com