gpt4 book ai didi

c++ - 线性、Conv1d、Conv2d、...、LSTM 的通用类,

转载 作者:塔克拉玛干 更新时间:2023-11-03 07:39:12 24 4
gpt4 key购买 nike

有没有什么类全部torch::nn::Linear, torch::nn::Conv1d, torch::nn::Conv2d, ... torch::nn::GRU, ....都继承自那个? torch::nn::Module似乎是一个不错的选择,虽然有一个中间类,叫做torch::nn::Cloneable ,所以 torch::nn::Module不起作用。另外,torch::nn::Cloneable本身是一个模板,因此需要在声明中键入。我想创建一个通用的 class model , 其中有 std::vector<the common class> layers ,以便稍后我可以填写 layers使用我想要的任何类型的图层,例如 Linear, LSTM,等等,现在的API有没有这样的能力?这可以在 python 中轻松完成,尽管这里我们需要声明,这阻碍了 python 的简单性。

谢谢,阿夫欣

最佳答案

我发现 nn::sequential可以用于这个目的,不需要前向实现,可以是正点,同时也是负点。 nn::sequential已经要求每个模块都有一个前向实现,并按照它们添加的顺序调用前向函数。因此,虽然它已经足够好了,但不能像 Dense-Net 这样创建一个特别的非常规前向传递用于一般用途。

此外,似乎nn::sequential只使用 std::vector<nn::AnyModule>作为其底层模块列表。所以,std::vector<nn::AnyModule>也可能会被使用。

关于c++ - 线性、Conv1d、Conv2d、...、LSTM 的通用类,,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55223728/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com