gpt4 book ai didi

pytorch - 用 CPU 加载泡菜保存的 GPU 张量?

转载 作者:行者123 更新时间:2023-12-04 09:25:55 25 4
gpt4 key购买 nike

我在 GPU 上使用 pickle 保存了 Bert 的最后一个隐藏层,以供后续过程使用。

# output is the last hidden layer of bert, transformed on GPU
with open(filename, 'wb') as f:
pk.dump(output, f)
是否可以在没有 GPU 的情况下将其加载到我的个人笔记本电脑上?我尝试了以下代码,但都失败了。
# 1st try
with open(filename, 'rb') as f:
torch.load(f, map_location='cpu')

# 2nd
torch.load(filename, map_location=torch.device('cpu'))
都得到以下错误
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
可以在我的笔记本电脑上加载文件吗?

最佳答案

如果你使用 pytorch,你可以通过保存 state_dict 来为自己省去一些麻烦。模型而不是模型本身。 state_dict是一个有序字典,用于存储神经网络的权重。
保存程序:

import torch
model = MyFabulousPytorchModel()
torch.save(model.state_dict(), "best_model.pt")
加载它需要您先初始化模型:
import torch
device = 'cuda' if torch.cuda.is_available() else 'gpu'
model = MyFabulousPytorchModel()
model.load_state_dict(torch.load(PATH_TO_MODEL))
model.device(device)
保存 state_dict的好处很多而不是直接对象。其中之一与您的问题有关:将您的模型移植到不同的环境并不像您希望的那样轻松。另一个优点是保存检查点要容易得多,让您可以继续训练,就好像训练从未停止过一样。你所要做的就是保存优化器的状态和损失:
保存检查点:
# somewhere in your training loop:
opt.zero_grad()
pred = model(x)
loss = loss_func(pred, target)

torch.save({"model": model.state_dict(), "opt": opt.state_dict(), "loss":loss}, "checkpoing.pt")
我强烈建议查看文档以获取有关如何使用 pytorch 保存和加载模型的更多信息。如果你了解它的内部运作,这是一个相当顺利的过程。 https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference
希望有帮助=)
编辑:
更直接地,为了解决您的问题,我推荐以下内容
1- 在您用来训练模型的计算机上:
import torch
model = torch.load("PATH_TO_MODEL")
torch.save(model.state_dict(), "PATH.pt")
2- 在另一台计算机上:
import torch
from FILE_WHERE_THE_MODEL_CLASS_IS_DEFINED import Model

model = Model() # initialize one instance of the model)
model.load_state_dict(torch.load("PATH.pt")

关于pytorch - 用 CPU 加载泡菜保存的 GPU 张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63008865/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com