gpt4 book ai didi

python - LayerNorm inside nn.Sequential in torch

转载 作者:行者123 更新时间:2023-12-05 06:11:37 30 4
gpt4 key购买 nike

我正在尝试在 nn.Sequential 中使用 LayerNorm在 torch 中。这就是我要找的-

import torch.nn as nn

class LayerNormCnn(nn.Module):
def __init__(self):
super(LayerNormCnn, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
nn.LayerNorm(),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.LayerNorm(),
nn.ReLU(),
)

def forward(self, x):
x = self.net(x)
return x

不幸的是,它不起作用,因为 LayerNorm需要 normalized_shape 作为输入。上面的代码抛出以下异常-

    nn.LayerNorm(),
TypeError: __init__() missing 1 required positional argument: 'normalized_shape'

现在,这就是我实现它的方式-

import torch.nn as nn
import torch.nn.functional as F


class LayerNormCnn(nn.Module):
def __init__(self, state_shape):
super(LayerNormCnn, self).__init__()
self.conv1 = nn.Conv2d(state_shape[0], 32, kernel_size=3, stride=2, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)

# compute shape by doing a forward pass
with torch.no_grad():
fake_input = torch.randn(1, *state_shape)
out = self.conv1(fake_input)
bn1_size = out.size()[1:]
out = self.conv2(out)
bn2_size = out.size()[1:]

self.bn1 = nn.LayerNorm(bn1_size)
self.bn2 = nn.LayerNorm(bn2_size)

def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
return x

if __name__ == '__main__':
in_shape = (3, 128, 128)
batch_size = 32

model = LayerNormCnn(in_shape)
x = torch.randn((batch_size,) + in_shape)
out = model(x)
print(out.shape)

是否可以在 nn.Sequential 中使用 LayerNorm?

最佳答案

original层归一化论文建议不要在 CNN 中使用层归一化,因为图像边界周围的感受野将具有不同的值,而不是实际图像内容中的感受野。这个问题不会出现在 RNN 中,这是层范数最初测试的目的。您确定要使用 LayerNorm 吗?如果您想将不同的归一化技术与 BatchNorm 进行比较,请考虑 GroupNorm .这摆脱了 LayerNorm 假设,即层中的 所有 channel 对预测的贡献相同,这是有问题的,特别是如果该层是卷积层。相反,每个 channel 被进一步分成组,这仍然允许 GN 层跨 channel 学习不同的统计数据。

请引用here进行相关讨论。

关于python - LayerNorm inside nn.Sequential in torch,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63914843/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com