gpt4 book ai didi

python - pytorch torch.jit.trace 返回函数而不是 torch.jit.ScriptModule

转载 作者:行者123 更新时间:2023-11-28 17:00:51 28 4
gpt4 key购买 nike

我需要在 C++ 中运行预训练的 pytorch 神经网络模型(在 python 中训练)来进行预测。

为此,我按照此处给出的有关如何在 C++ 中加载 pytorch 模型的说明进行操作:https://pytorch.org/tutorials/advanced/cpp_export.html

但是当我尝试按照教程第一步中所述通过跟踪获取 torch.jit.ScriptModule 时:

    traced_script_module =
torch.jit.trace(model, (input_tensor_1, input_tensor_2))

它不返回 torch.jit.ScriptModule,而是返回一个函数:

    print(type(traced_script_module))
<type 'function'>

当我运行时:

    traced_script_module.save("model.pt")

然后会导致以下错误:

Traceback (most recent call last):
File "serialize_model.py", line 60, in <module>
traced_script_module.save("model.pt")
AttributeError: 'function' object has no attribute 'save'

对我做错了什么有什么想法吗?

最佳答案

感谢询问Jatentaki .我在 Python 中使用 PyTorch 0.4,当我更新到 1.0 时它可以正常工作。

关于python - pytorch torch.jit.trace 返回函数而不是 torch.jit.ScriptModule,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54650423/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com