gpt4 book ai didi

python - 在无法访问模型类代码的情况下保存 PyTorch 模型

转载 作者:行者123 更新时间:2023-12-03 15:25:03 25 4
gpt4 key购买 nike

如何在不需要在某处定义模型类的情况下保存 PyTorch 模型?

免责声明 :
Best way to save a trained model in PyTorch? ,还有无需访问模型类代码即可保存模型的解决方案(或工作解决方案)。

最佳答案

如果您打算使用可用的 Pytorch 库(即 Python、C++ 或它支持的其他平台中的 Pytorch)进行推理,那么最好的方法是通过 TorchScript .

我觉得最简单的就是用trace = torch.jit.trace(model, typical_input)然后 torch.jit.save(trace, path) .然后,您可以使用 torch.jit.load(path) 加载跟踪模型。 .

这是一个非常简单的例子。我们制作两个文件:
train.py :

import torch

class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = torch.nn.Linear(4, 4)

def forward(self, x):
x = torch.relu(self.linear(x))
return x

model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
print(model(x))
traced_cell = torch.jit.trace(model, (x))
torch.jit.save(traced_cell, "model.pth")
infer.py :

import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
print(loaded_trace(x))

按顺序运行这些会得到结果:

python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
[0.0000, 0.5272, 0.3481, 0.1743]])

python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
[0.0000, 0.5272, 0.3481, 0.1743]])

结果是一样的,所以我们很好。 (请注意,由于 nn.Linear 层初始化的随机性,这里每次的结果都会不同)。

TorchScript 提供了将更复杂的架构和图形定义(包括 if 语句、while 循环等)保存在单个文件中,而无需在推理时重新定义图形。有关更高级的可能性,请参阅文档(上面链接)。

关于python - 在无法访问模型类代码的情况下保存 PyTorch 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59287728/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com