gpt4 book ai didi

python - Pytorch:为什么 print(model) 不显示激活函数?

转载 作者:行者123 更新时间:2023-12-04 01:33:08 26 4
gpt4 key购买 nike

我需要从 pytorch 中经过训练的 NN 中提取权重、偏差和至少激活函数的类型。

我知道要提取权重和偏差,命令是:
model.parameters()
但我不知道如何提取图层上使用的激活函数。这是我的网络

class NetWithODE(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output, sampling_interval, scaler_features):
super(NetWithODE, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.predict = torch.nn.Linear(n_hidden, n_output) # output layer
self.sampling_interval = sampling_interval
self.device = torch.device("cpu")
self.dtype = torch.float
self.scaler_features = scaler_features

def forward(self, x):
x0 = x.clone().requires_grad_(True)
# activation function for hidden layer
x = F.relu(self.hidden(x))
# linear output, here r should be the output
r = self.predict(x)
# Now the r enters the integrator
x = self.integrate(r, x0)

return x

def integrate(self, r, x0):
# RK4 steps per interval
M = 4
DT = self.sampling_interval / M
X = x0

for j in range(M):
k1 = self.ode(X, r)
k2 = self.ode(X + DT / 2 * k1, r)
k3 = self.ode(X + DT / 2 * k2, r)
k4 = self.ode(X + DT * k3, r)
X = X + DT / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

return X

def ode(self, x0, r):
qF = r[0, 0]
qA = r[0, 1]
qP = r[0, 2]
mu = r[0, 3]

FRU = x0[0, 0]
AMC = x0[0, 1]
PHB = x0[0, 2]
TBM = x0[0, 3]

fFRU = qF * TBM
fAMC = qA * TBM
fPHB = qP - mu * PHB
fTBM = mu * TBM

return torch.stack((fFRU, fAMC, fPHB, fTBM), 0)

如果我运行命令
print(model)


我明白了
NetWithODE(
(hidden): Linear(in_features=4, out_features=10, bias=True)
(predict): Linear(in_features=10, out_features=4, bias=True)
)

但是我在哪里可以获得激活函数(在这种情况下是 Relu)?

我有 pytorch 1.4。

最佳答案

向网络图添加操作有两种方式:低级功能方式和更高级的对象方式。您需要后者来使您的结构可观察,在第一种情况下只是调用(不完全是,但是......)一个函数而不存储有关它的信息。所以,而不是

    def forward(self, x):
...
x = F.relu(self.hidden(x))

它一定是这样的
def __init__(...):
...
self.myFirstRelu= torch.nn.ReLU()

def forward(self, x):
...
x1 = self.hidden(x)
x2 = self.myFirstRelu(x1)

无论如何,这两种方式的混合通常是个坏主意,尽管 torchvision模型有这样的不一致: models.inception_v3不注册池,例如 >:-( (编辑:它已在 2020 年 6 月修复,谢谢,mitmul!)。

升级版:
- 谢谢,这行得通,现在如果我打印,我会看到 ReLU()。但这似乎只以 中定义的相同顺序打印函数。初始化 .有没有办法获得层和激活函数之间的关联?例如,我想知道哪个激活应用于第 1 层,哪个激活应用于第 2 层,依此类推...

没有统一的方法,但这里有一些技巧:
对象方式:

- 只是按顺序初始化它们

-使用 torch.nn.Sequential

- 像这样的节点上的钩子(Hook)回调 -
def hook( m, i, o):
print( m._get_name() )

for ( mo ) in model.modules():
mo.register_forward_hook(hook)

功能和对象方式:

-利用内部模型图,建立在前向传播上,如 torchviz做( https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py ),或者只使用由所述 torchviz 生成的绘图.

关于python - Pytorch:为什么 print(model) 不显示激活函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60484859/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com