gpt4 book ai didi

c++ - pytorch/libtorch C++ 中的自定义子模块

转载 作者:行者123 更新时间:2023-12-02 10:17:16 29 4
gpt4 key购买 nike

完全披露,几天前我在 PyTorch 论坛上问了同样的问题,但没有得到回复,所以这在技术上是一个转发,但我相信这仍然是一个好问题,因为我一直无法在网上的任何地方找到答案.开始:

您能否展示一个将 register_module 与自定义模块一起使用的示例?
我在网上找到的唯一示例是将线性层或卷积层注册为子模块。

我尝试编写自己的模块并将其注册到另一个模块,但我无法让它工作。
我的 IDE 告诉我 no instance of overloaded function "MyModel::register_module" matches the argument list -- argument types are: (const char [14], TreeEmbedding)
(TreeEmbedding 是我创建的另一个结构的名称,它扩展了 torch::nn::Module。)

我错过了什么吗?这方面的一个例子会很有帮助。

编辑:附加上下文如下。

我有一个头文件“model.h”,其中包含以下内容:

struct TreeEmbedding : torch::nn::Module {
TreeEmbedding();
torch::Tensor forward(Graph tree);
};

struct MyModel : torch::nn::Module{
size_t embeddingSize;
TreeEmbedding treeEmbedding;

MyModel(size_t embeddingSize=10);
torch::Tensor forward(std::vector<Graph> clauses, std::vector<Graph> contexts);
};

我还有一个 cpp 文件“model.cpp”,其中包含以下内容:

MyModel::MyModel(size_t embeddingSize) :
embeddingSize(embeddingSize)
{
treeEmbedding = register_module("treeEmbedding", TreeEmbedding{});
}

此设置仍然存在与上述相同的错误。文档中的代码确实有效(使用线性层等内置组件),但使用自定义模块却不行。在追踪torch::nn::Linear之后,看起来好像是 ModuleHolder (不管那是什么……)

谢谢,
jack

最佳答案

如果有人可以提供更多详细信息,我会接受更好的答案,但以防万一有人想知道,我想我会提供我能找到的小信息:

register_module 接受一个字符串作为它的第一个参数,它的第二个参数可以是 ModuleHolder (我不知道这是什么......),或者它可以是你模块的 shared_ptr 。所以这是我的例子:

treeEmbedding = register_module<TreeEmbedding>("treeEmbedding", make_shared<TreeEmbedding>());

到目前为止,这似乎对我有用。

关于c++ - pytorch/libtorch C++ 中的自定义子模块,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61515915/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com