我有一个 TensorFlow v1 版本的 U-Net,使用 SGD(学习率 0.05)训练效果很好。
我用 PyTorch 重写了这个网络,因为我想实现一些在 TensorFlow 中不太容易实现的功能。
我的模型总是预测出空掩码,因此我尝试让模型在单张图像上过拟合。
模型可以在一张示例图像上过拟合并预测出掩码,但只有在使用 Adam(学习率 0.0005)训练 1000 个 epoch 时才行;而我的旧模型大约 10 个 epoch 就能做到。
我看不出任何明显的错误,但一定是哪里做错了,因为这是一个很简单的问题,本不需要太多调参。
import numpy as np
import cv2
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
# Fix the global RNG so weight initialisation is reproducible between runs.
torch.manual_seed(42)
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DoubleConv(nn.Module):
    """Two 3x3 convolutions: (conv => [BN] => [ReLU]) * 2.

    Both convolutions use padding=1, so the spatial size is unchanged.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        relu: if True (default), apply ReLU after each convolution. If False,
            the two convolutions are stacked with no non-linearity between them.
        batch_norm: if True, insert ``BatchNorm2d`` after each convolution.
            Defaults to False, which reproduces the plain conv/ReLU stack
            (same layer indices, same state_dict keys).
    """

    def __init__(self, in_channels, out_channels, relu=True, batch_norm=False):
        super().__init__()
        layers = []
        for conv_in in (in_channels, out_channels):
            # A conv bias is redundant when BatchNorm directly follows it,
            # because BN subtracts the per-channel mean anyway.
            layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3,
                                    padding=1, bias=not batch_norm))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_channels))
            if relu:
                layers.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W), then ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Decoder stage: upscale ``x1``, pad to ``x2``'s size, concatenate, double conv.

    Args:
        in_channels: channels of the concatenated ``[x2, x1]`` tensor fed to
            the double convolution.
        out_channels: channels produced by the double convolution.
        bilinear: if True, upsample with (parameter-free) bilinear
            interpolation; otherwise use a learned ``ConvTranspose2d`` that
            maps ``in_channels // 2`` channels to ``in_channels // 2``.
        relu: forwarded to ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, relu=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, relu=relu)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Tensors are N, C, H, W. Pad the upsampled x1 so its H/W match the
        # skip tensor x2 (they can differ when an encoder input size was odd).
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        # F.pad takes plain ints ordered (left, right, top, bottom) for the
        # last two dims; no need to allocate tensors for the amounts.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Project feature channels to per-class logits with a 1x1 convolution."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits
class UNet(nn.Module):
    """U-Net: four down-sampling stages, four up-sampling stages with skip
    connections, and a final 1x1 convolution producing ``n_classes`` logits.

    Args:
        n_channels: channels of the input image tensor.
        n_classes: number of output logit channels.
        bilinear: passed to every ``Up`` stage to pick the upsampling method.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Sub-modules are created in a fixed order so seeded initialisation
        # and state_dict keys stay stable.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every stage's output; all but the deepest are skips.
        features = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            features.append(down(features[-1]))
        # Decoder: start from the bottleneck and consume skips deepest-first.
        out = features.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, features.pop())
        return self.outc(out)
def decode_segmap(image, num_classes=3):
    """Colour a 2-D map of class indices as an RGB uint8 image.

    Class 0 -> (128, 0, 0), 1 -> (0, 128, 0), 2 -> (0, 0, 128); any index not
    in ``range(num_classes)`` stays black. The result has shape (H, W, 3).
    """
    palette = np.array([(128, 0, 0), (0, 128, 0), (0, 0, 128)])
    channels = [np.zeros_like(image).astype(np.uint8) for _ in range(3)]
    for label in range(num_classes):
        mask = image == label
        for channel, value in zip(channels, palette[label]):
            channel[mask] = value
    return np.stack(channels, axis=2)
def load_batch(batch_size):
    """Load the '0test.png' / '0label.png' pair and return it as a batch.

    The single sample is repeated ``batch_size`` times along the batch
    dimension (a batch of one for ``batch_size=1``).

    Returns:
        (frame, gt): ``frame`` is a float tensor of shape
        (batch_size, 1, H, W), normalised and averaged over channels; ``gt`` is
        a long tensor of shape (batch_size, H, W). Both are moved to the
        module-level ``device``.
    """
    frame_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.2455], std=[0.2684])])
    label_tf = transforms.ToTensor()
    # Context managers make sure both image files are closed after decoding.
    with Image.open('0test.png') as frame_img, Image.open('0label.png') as label_img:
        rotated_frame = frame_tf(frame_img).unsqueeze(0)
        rotated_gt = label_tf(label_img).unsqueeze(0)
    # Average over the channel dimension to get a single-channel tensor.
    rotated_frame = torch.mean(rotated_frame, 1).unsqueeze(1)
    rotated_gt = torch.mean(rotated_gt, 1).unsqueeze(1)
    # NOTE(review): for 8-bit PNGs ToTensor scales pixels into [0, 1] and the
    # long cast below truncates, so only pixels whose channel mean is exactly
    # 1.0 become class 1 and class 2 can never appear. If the label PNG stores
    # class indices 0..C-1 directly, build gt from np.array(label_img) instead
    # -- confirm against how the labels were written.
    rotated_frame = rotated_frame.repeat(batch_size, 1, 1, 1)
    rotated_gt = rotated_gt.repeat(batch_size, 1, 1, 1)
    return rotated_frame.to(device), rotated_gt.type(torch.long).to(device).squeeze(1)
net = UNet(1, 3)
net.to(device=device)

# Loss / optimiser. lr=0.05 is the plain-SGD learning rate that trained the
# TF v1 version of this network; optim.Adam(..., lr=0.0005) is an alternative.
optimizer = optim.SGD(net.parameters(), lr=0.05)
criterion = nn.CrossEntropyLoss()

# Load data
rotated_frame, rotated_gt = load_batch(1)
print(rotated_frame.shape)
print(rotated_gt.shape)

# Train
epochs = 1000
losses = []
for epoch in range(epochs):
    # PyTorch accumulates .grad across backward() calls; without clearing
    # them, every SGD step would apply the running sum of all past gradients.
    optimizer.zero_grad()
    predicted = net(rotated_frame)
    loss = criterion(predicted, rotated_gt)
    loss.backward()
    optimizer.step()
    # Store a Python float: keeping the tensor would also keep a reference
    # to its autograd graph for every epoch.
    losses.append(loss.item())
    print('Epoch {}/{} Loss: {}'.format(epoch, epochs, loss.item()))

# Logits from the last forward pass; index the single batch element
# explicitly rather than squeeze(), which would also drop size-1 H/W dims.
output = torch.argmax(predicted[0], dim=0).detach().cpu().numpy()
a, b = np.min(output), np.max(output)
print('Predicted: min: {} max: {}'.format(a, b))
print(output.shape)
rgb = decode_segmap(output)
# A fresh figure per image so the second plot is not drawn over the first.
plt.figure()
plt.imshow(rgb)
plt.savefig('predicted_argmaxed.png')
plt.close()

gt = rotated_gt[0].detach().cpu().numpy()
a, b = np.min(gt), np.max(gt)
print('Gt: min: {} max: {}'.format(a, b))
rgb = decode_segmap(gt)
plt.figure()
plt.imshow(rgb)
plt.savefig('gt_argmaxed.png')
plt.close()
示例图片在这里:
任何帮助将不胜感激!
如果您使用的是 `CrossEntropyLoss`,是否尝试过为各个类别添加权重?
# One weight per class, in class-index order; the tensor length must equal the
# number of classes C. These values were used for a binary (~70% / ~30%) split.
weights = torch.tensor([0.75, 1], dtype=torch.float)
# Keep the default reduction='mean' so the loss is a scalar and
# loss.backward() works directly; with reduction='none' the per-element
# losses would have to be reduced (e.g. .mean()) before calling backward().
criterion = torch.nn.CrossEntropyLoss(weight=weights).to(device)
如果您的模型生成的是空掩码(例如全白的掩码),这在理论上可能确实能使损失最小化,因为全白对应的类别似乎在图像中占主导地位。根据您的类别数量,尝试为边界类别分配更大的权重。
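上面代码中的权重是我在做二分类时使用的,当时一个类别约占 70%,另一个约占 30%。注意 `weight` 的长度必须等于类别数 C,因此对于问题中输出 3 个类别的 U-Net,需要提供 3 个权重。另外,若设置 `reduction='none'`,得到的是逐像素的损失张量,需要先求均值(或求和)再调用 `backward()`。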
此外,正如纳特提到的,加入批归一化(BN)也可能有帮助。您的学习率似乎也有点太低了。
编辑:为澄清起见,以下内容摘自 PyTorch 文档:
weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C
reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'
我是一名优秀的程序员,十分优秀!