gpt4 book ai didi

python-3.x - 在 Pytorch 内置的自定义 Batchnorm 中更新 running_mean 和 running_var 有问题吗?

转载 作者:行者123 更新时间:2023-12-04 15:30:34 29 4
gpt4 key购买 nike

我一直在尝试实现自定义批归一化函数,以便它可以扩展到多 GPU 版本,特别是 Pytorch 中的 DataParallel 模块。自定义批归一化在使用 1 个 GPU 时工作正常,但是,当扩展到2或更多,运行均值和方差在前向函数中起作用,但当它从网络返回时,均值和方差重新初始化为0和1。

torch.nn.DataParallel 在警告部分提到“在每个转发中,模块在每个设备上复制,因此对转发中运行模块的任何更新都将丢失。例如,如果模块具有计数器属性在每次转发中递增,它将始终保持在初始值,因为更新是在转发后销毁的副本上完成的。”但我不太确定如何保留默认设备的均值和方差。

我已经提供了多 GPU 训练期间获得的结果的代码。此代码利用提供的 Batchnorm here .

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torch.nn.parameter import Parameter

class ptrblck_BatchNorm2d(nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1,
affine=True, track_running_stats=True):
super(ptrblck_BatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats)

def forward(self, input):
self._check_input_dim(input)

exponential_average_factor = 0.0

if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum

# calculate running estimates
if self.training:
mean = input.mean([0, 2, 3])
# use biased var in train
var = input.var([0, 2, 3], unbiased=False)
n = input.numel() / input.size(1)
with torch.no_grad():
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean
# update running_var with unbiased var
self.running_var = exponential_average_factor * var * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_var
else:
mean = self.running_mean
var = self.running_var

input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
if self.affine:
input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

return input


class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.bn1 = ptrblck_BatchNorm2d(64)
print("==> printing bn1 mean when init")
print(self.bn1.running_mean)
print("==> printing bn1 when init")
print(self.bn1.running_mean)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.classifier = nn.Linear(64, 10)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.pool(x)
x = self.avgpool(x)

x = x.view(x.size(0), -1)
x = self.classifier(x)
print("======================================================")
print("==> printing bn1 running mean from NET during forward")
print(net.module.bn1.running_mean)
print("==> printing bn1 running mean from SELF. during forward")
print(self.bn1.running_mean)
print("==> printing bn1 running var from NET during forward")
print(net.module.bn1.running_var)
print("==> printing bn1 running mean from SELF. during forward")
print(self.bn1.running_var)
return x

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])


transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])


trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = net()
net = torch.nn.DataParallel(net).cuda()
print('Number of GPU {}'.format(torch.cuda.device_count()))

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0

for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.cuda(), targets.cuda()
outputs = net(inputs)
loss = criterion(outputs, targets)
print("====================================================")
print("==> printing bn1 running mean FROM net after forward")
print(net.module.bn1.running_mean)
print("==> printing bn1 running var FROM net after forward")
print(net.module.bn1.running_var)

break
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()

# train_loss += loss.item()
# _, predicted = outputs.max(1)
# total += targets.size(0)
# correct += predicted.eq(targets).sum().item()

# break


for epoch in range(0, 1):
train(epoch)

结果:

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..
==> printing bn1 mean when init
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
==> printing bn1 when init
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Number of GPU 2

Epoch: 0
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([ 0.0053, 0.0010, -0.0077, -0.0290, 0.0241, 0.0258, -0.0048, 0.0151,
-0.0133, 0.0080, 0.0197, -0.0042, -0.0188, 0.0233, 0.0310, -0.0230,
-0.0133, 0.0222, 0.0119, -0.0042, -0.0220, -0.0169, -0.0342, -0.0025,
0.0338, -0.0070, 0.0202, 0.0050, 0.0108, 0.0008, 0.0363, 0.0347,
-0.0106, 0.0082, 0.0128, 0.0074, 0.0111, -0.0030, -0.0089, 0.0070,
-0.0262, -0.0029, 0.0053, -0.0136, -0.0183, 0.0045, -0.0014, -0.0221,
0.0132, 0.0064, 0.0388, -0.0220, -0.0008, 0.0400, -0.0187, 0.0397,
-0.0131, -0.0176, 0.0035, 0.0055, -0.0270, 0.0066, -0.0149, 0.0135],
device='cuda:0')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9665, 0.9073, 0.9220, 1.0947, 1.0687, 0.9624, 0.9252, 0.9131, 0.9066,
0.9536, 0.9258, 0.9203, 1.0359, 0.9690, 1.1066, 1.0636, 0.9135, 0.9644,
0.9373, 0.9846, 0.9696, 0.9454, 1.0459, 0.9245, 0.9778, 0.9709, 0.9352,
0.9995, 0.9657, 0.9510, 1.0943, 1.0171, 0.9298, 1.0747, 0.9341, 0.9635,
0.9978, 0.9303, 0.9261, 0.9137, 0.9569, 1.0066, 1.0463, 0.9955, 0.9621,
0.9172, 0.9836, 0.9817, 0.9086, 0.9576, 1.0905, 0.9861, 0.9661, 1.1773,
0.9345, 1.0904, 0.9133, 1.0660, 0.9164, 0.9058, 0.9446, 0.9225, 1.0914,
0.9292], device='cuda:0')
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([-0.0020, 0.0002, -0.0103, -0.0426, 0.0386, 0.0311, -0.0059, 0.0151,
-0.0140, 0.0145, 0.0218, -0.0029, -0.0281, 0.0284, 0.0449, -0.0329,
-0.0107, 0.0278, 0.0135, -0.0123, -0.0260, -0.0214, -0.0423, -0.0035,
0.0410, -0.0097, 0.0276, 0.0102, 0.0197, -0.0001, 0.0483, 0.0451,
-0.0078, 0.0190, 0.0135, -0.0004, 0.0196, -0.0028, -0.0140, 0.0070,
-0.0332, -0.0110, 0.0151, -0.0210, -0.0226, 0.0074, -0.0088, -0.0314,
0.0125, -0.0003, 0.0505, -0.0312, 0.0086, 0.0544, -0.0245, 0.0528,
-0.0086, -0.0290, 0.0063, 0.0042, -0.0339, 0.0061, -0.0277, 0.0092],
device='cuda:1')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9665, 0.9072, 0.9211, 1.0999, 1.0714, 0.9610, 0.9209, 0.9125, 0.9063,
0.9553, 0.9260, 0.9189, 1.0386, 0.9706, 1.1139, 1.0610, 0.9121, 0.9660,
0.9366, 0.9886, 0.9683, 0.9454, 1.0511, 0.9227, 0.9792, 0.9704, 0.9330,
0.9989, 0.9657, 0.9476, 1.1008, 1.0191, 0.9294, 1.0814, 0.9320, 0.9642,
1.0006, 0.9287, 0.9254, 0.9128, 0.9559, 1.0100, 1.0521, 0.9972, 0.9621,
0.9168, 0.9849, 0.9803, 0.9083, 0.9556, 1.0946, 0.9865, 0.9651, 1.1880,
0.9330, 1.0959, 0.9116, 1.0706, 0.9149, 0.9057, 0.9450, 0.9215, 1.0972,
0.9261], device='cuda:1')
====================================================
==> printing bn1 running mean FROM net after forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0')
==> printing bn1 running var FROM net after forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')

如何确保使用默认设备的运行估计值?目前,我没有致力于同步 Batchnorm。

最佳答案

替换

self.running_mean = (...)

self.running_mean.copy_(...)

完成了任务。

Reference

关于python-3.x - 在 Pytorch 内置的自定义 Batchnorm 中更新 running_mean 和 running_var 有问题吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61311334/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com