gpt4 book ai didi

python - 如何使用 torch.hub.load 加载本地模型?

转载 作者:行者123 更新时间:2023-12-04 07:43:48 68 4
gpt4 key购买 nike

我需要避免从网上下载模型(由于安装的机器的限制)。
这有效,但从网上下载模型

model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)
我已经放置了 .pth文件和 hubconf.py/tmp/文件夹中的文件并将我的代码更改为
model = torch.hub.load('/tmp/', 'deeplabv3_resnet101', pretrained=True,s ource='local')
但令我惊讶的是它仍然从互联网上下载模型。我究竟做错了什么?如何在本地加载模型。
只是为了给您提供更多详细信息,我在运行时具有只读卷的 Docker 容器中执行所有这些操作,因此这就是新文件下载失败的原因。
谢谢,
约翰

最佳答案

有两种方法可以在没有互联网的机器上获得可交付的模型。
1.在普通机器上加载带有预训练模型的DeepLab,使用jit编译器将其导出为图形,放入机器中。脚本很容易遵循:

# To export
model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True).eval()
traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W))
traced_graph.save('DeepLab.pth')

# To load
model = torch.jit.load('DeepLab.pth').eval().to(device)
在这种情况下,权重和网络结构保存为计算图,因此您不需要任何额外的文件。
  • 看看torchvision的github repo .

  • 有一个 download url用于带有 Resnet101 主干权重的 DeepLabV3。
    您可以下载一次这些权重,然后使用来自 torchvision 的 deeplab 和 预训练=假手动标记和加载重量。
    model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=False)
    model.load_state_dict(torch.load('downloaded weights path'))
    考虑到,可能有 ['state_dict'] 或 state dict 中的一些类似父键,您将在其中使用:
    model.load_state_dict(torch.load('downloaded weights path')['state_dict'])

    关于python - 如何使用 torch.hub.load 加载本地模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67302634/

    68 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com