gpt4 book ai didi

python - Pytorch:我们可以直接在 forward() 函数中使用 nn.Module 层吗?

转载 作者:行者123 更新时间:2023-12-01 11:11:42 25 4
gpt4 key购买 nike

一般来说,
在构造函数中,我们声明了我们想要使用的所有层。
在 forward 函数中,我们定义了模型将如何运行,从输入到输出。

我的问题是如果调用那些预定义/内置的 nn.模块 直接在 forward()功能?这是Keras function API 的合法样式火炬 ?如果不是,为什么?

更新:测试型号 以这种方式构建的确实运行成功,没有警报。但与传统方式相比,训练损失会缓慢下降。

import torch.nn as nn
from cnn import CNN

class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.num_embeddings = 2020
self.embedding_dim = 51

def forward(self, input):
x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
# CNN is a customized class and nn.Module subclassed
# we will ignore the arguments for its instantiation
x = CNN(...)(x)
x = nn.ReLu()(x)
x = nn.Dropout(p=0.2)(x)
return output = x

最佳答案

您需要考虑可训练参数的范围。

如果你在 forward 中定义一个卷积层你的模型的函数,那么这个“层”的范围和它的可训练参数是函数本地的,并且在每次调用 forward 后都会被丢弃。方法。您无法更新和训练在每个 forward 之后不断丢弃的权重。经过。
但是,当转换层是您的 model 的成员时它的范围超出了forward方法和可训练参数只要 model对象存在。通过这种方式,您可以更新和训练模型及其权重。

关于python - Pytorch:我们可以直接在 forward() 函数中使用 nn.Module 层吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59642925/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com