gpt4 book ai didi

Trainin, saving and importing a Libtorch Neural Network in C++(Libtorch神经网络的C++训练、存储和导入)

转载 作者:bug小助手 更新时间:2023-10-25 16:41:50 34 4
gpt4 key购买 nike



I'm starting to get into neural networks in C++ using libtorch. I've created a simple game and trained a neural network with the Q learning method. I trained the network in a program, and now that I have a trained model I want to use it in the graphical version of the game so I can see what the model can do. I saved the model to a model.pt file as follows (reference to this forum question):

我开始使用libtorch在C++中学习神经网络。我已经创建了一个简单的游戏,并用Q学习方法训练了一个神经网络。我在一个程序中训练了网络,现在我有了一个训练好的模型,我想在游戏的图形版本中使用它,这样我就可以看到这个模型能做什么。我将模型保存到了一个模型.pt文件,如下所示(引用这个论坛的问题):


string model_path = "model.pt";
torch::serialize::OutputArchive output_archive;
qNetwork.save(output_archive);
output_archive.save_to(model_path);

It generates a model.pt file. Now I'm having trouble importing the model to the other program (which runs an SDL2 graphic version of the game). I've figured out how to properly import libtorch to that file, the problem is I don't know how to use the already trained model. I tried what this article suggests:

它会生成一个mod.pt文件。现在,我在将模型导入到另一个程序(运行游戏的SDL2图形版本)时遇到了麻烦。我已经弄清楚了如何正确地将libtorch导入到该文件中,问题是我不知道如何使用已经训练好的模型。我尝试了这篇文章的建议:


torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}

For reference this is the Torch module:

以下是火炬模块,供参考:


class QNetwork : public torch::nn::Module {
private:
torch::nn::Linear layer1{nullptr};
torch::nn::Linear layer2{nullptr};
torch::nn::Linear outputLayer{nullptr};
public:
QNetwork(int inputSize, int outputSize) {
layer1 = register_module("layer1", torch::nn::Linear(inputSize, 64));
layer2 = register_module("layer2", torch::nn::Linear(64, 64));
outputLayer = register_module("outputLayer", torch::nn::Linear(64, outputSize));
}

torch::Tensor forward(torch::Tensor x) {
x = torch::relu(layer1->forward(x));
x = torch::relu(layer2->forward(x));
x = outputLayer->forward(x);
return x;
}
};

How to resolve this issue?

如何解决这个问题?


更多回答
优秀答案推荐
更多回答

34 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com