gpt4 book ai didi

deep-learning - 在 pytorch 中使用 BatchNorm 进行训练

转载 作者:行者123 更新时间:2023-12-04 00:02:31 29 4
gpt4 key购买 nike

我想知道在 pytorch 中使用 BatchNorm 进行训练时是否需要做一些特别的事情。据我了解,gammabeta 参数使用梯度更新,这通常由优化器完成。但是,批处理的均值和方差会使用动量缓慢更新。

  1. 那么我们是否需要在均值和方差参数更新时向优化器指定,还是 pytorch 会自动处理这个问题?
  2. 有没有办法访问 BN 层的均值和方差,以便我在训练模型时确保它发生变化。

如果需要,这里是我的模型和训练程序:

def bn_drop_lin(n_in:int, n_out:int, bn:bool=True, p:float=0.):
"Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`."
layers = [nn.BatchNorm1d(n_in)] if bn else []
if p != 0: layers.append(nn.Dropout(p))
layers.append(nn.Linear(n_in, n_out))

return nn.Sequential(*layers)

class Model(nn.Module):
def __init__(self, i, o, h=()):
super().__init__()

nodes = (i,) + h + (o,)
self.layers = nn.ModuleList([bn_drop_lin(i,o, p=0.5)
for i, o in zip(nodes[:-1], nodes[1:])])

def forward(self, x):
x = x.view(x.shape[0], -1)
for layer in self.layers[:-1]:
x = F.relu(layer(x))

return self.layers[-1](x)

培训:

for i, data in enumerate(trainloader):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data

# zero the parameter gradients
optimizer.zero_grad()

# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

最佳答案

根据模型是处于训练模式还是评估模式,Batchnorm 层的行为会有所不同。

net 处于训练模式时(即在调用 net.train() 之后),net 中包含的批处理规范层将使用批处理统计数据以及 gamma 和 beta 参数来缩放和转换每个 mini-batch。在训练模式下,运行均值和方差也会进行调整。这些对运行均值和方差的更新发生在前向传递期间(当调用 net(inputs) 时)。 gamma 和 beta 参数与任何其他 pytorch 参数一样,仅在调用 optimizer.step() 时更新。

net 处于 eval 模式 (net.eval()) 时,batch norm 使用训练期间计算的历史运行均值和运行方差来缩放和转换样本。

您可以通过显示层 running_meanrunning_var 成员来检查批范数层的运行均值和方差,以确保批范数按预期更新它们。可以通过分别显示批处理规范层的 weightbias 成员来访问可学习的 gamma 和 beta 参数。

编辑

下面是一个简单的演示代码,显示 running_mean 在转发过程中被更新。观察它没有被优化器更新。

>>> import torch
>>> import torch.nn as nn
>>> layer = nn.BatchNorm1d(5)
>>> layer.train()
>>> layer.running_mean
tensor([0., 0., 0., 0., 0.])
>>> result = layer(torch.randn(5,5))
>>> layer.running_mean
tensor([ 0.0271, 0.0152, -0.0403, -0.0703, -0.0056])

关于deep-learning - 在 pytorch 中使用 BatchNorm 进行训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57865112/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com