gpt4 book ai didi

python - 越来越大的正 WGAN-GP 损失

转载 作者:行者123 更新时间:2023-11-28 21:34:19 32 4
gpt4 key购买 nike

我正在研究在 PyTorch 中使用带有梯度惩罚的 Wasserstein GAN,但始终得到大的、正的生成器损失,并且随着时间的推移而增加。
我从 Caogang's implementation 大量借钱,但我使用了 this implementation 中使用的鉴别器和生成器损失因为我得到 Invalid gradient at index 0 - expected shape[] but got [1]如果我尝试调用 .backward()onemone草纲实现中使用的参数。

我正在对增强的 WikiArt 数据集(> 400k 64x64 图像)和 CIFAR-10 进行训练,并且得到了一个正常的 WGAN(使用权重裁剪)[即尽管 D 和 G 损失都徘徊在 3 左右 [我使用 torch.mean(D_real) 计算它们],但它在 25 个时期后产生了可通过的图像]等]适用于所有时代。然而,在 WGAN-GP 版本中,生成器损失在 WikiArt 和 CIFAR-10 数据集上都急剧增加,并且完全无法在 WikiArt 上生成噪声以外的任何内容。

以下是 CIFAR-10 上 25 个时期后损失的示例:
WGAN-GP loss

我没有使用任何技巧,比如单边标签平滑,我使用默认学习率 0.001、Adam 优化器和我为每次生成器更新训练鉴别器 5 次。为什么会发生这种疯狂的减重行为,为什么正常的减重 WGAN 在 WikiArt 上仍然“有效”但 WANGP 完全失败?

这与结构无关,无论 G 和 D 是 DCGAN 还是使用 this modified DCGAN, the Creative Adversarial Network 时都会发生这种情况。 ,这要求 D 能够对图像进行分类,而 G 生成模糊图像。

以下是我目前train的相关部分方法:

self.generator = Can64Generator(self.z_noise, self.channels, self.num_gen_filters).to(self.device)
self.discriminator =WCan64Discriminator(self.channels,self.y_dim, self.num_disc_filters).to(self.device)
style_criterion = nn.CrossEntropyLoss()

self.disc_optimizer = optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.9))
self.gen_optimizer = optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.9))


while i < len(dataloader):
j = 0
disc_loss_epoch = []
gen_loss_epoch = []
if self.type == "can":
disc_class_loss_epoch = []
gen_class_loss_epoch = []

if self.gradient_penalty == False:
# critic training methodology in official WGAN implementation
if gen_iterations < 25 or (gen_iterations % 500 == 0):
disc_iters = 100
else:
disc_iters = self.disc_iterations

while j < disc_iters and (i < len(dataloader)):
# if using wgan with weight clipping
if self.gradient_penalty == False:
# Train Discriminator
for param in self.discriminator.parameters():
param.data.clamp_(self.lower_clamp,self.upper_clamp)


for param in self.discriminator.parameters():
param.requires_grad_(True)

j+=1
i+=1
data = data_iterator.next()
self.discriminator.zero_grad()
real_images, image_labels = data
# image labels are the the image's classes (e.g. Impressionism)
real_images = real_images.to(self.device)
batch_size = real_images.size(0)
real_image_labels = torch.LongTensor(batch_size).to(self.device)
real_image_labels.copy_(image_labels)

labels = torch.full((batch_size,),real_label,device=self.device)

if self.type == 'can':
predicted_output_real, predicted_styles_real = self.discriminator(real_images.detach())
predicted_styles_real = predicted_styles_real.to(self.device)
disc_class_loss = style_criterion(predicted_styles_real,real_image_labels)
disc_class_loss.backward(retain_graph=True)

else:
predicted_output_real = self.discriminator(real_images.detach())

disc_loss_real = -torch.mean(predicted_output_real)


# fake

noise = torch.randn(batch_size,self.z_noise,1,1,device=self.device)
with torch.no_grad():
noise_g = noise.detach()
fake_images = self.generator(noise_g)
labels.fill_(fake_label)

if self.type == 'can':
predicted_output_fake, predicted_styles_fake = self.discriminator(fake_images)

else:
predicted_output_fake = self.discriminator(fake_images)



disc_gen_z_1 = predicted_output_fake.mean().item()

disc_loss_fake = torch.mean(predicted_output_fake)


#via https://github.com/znxlwm/pytorch-generative-model-collections/blob/master/WGAN_GP.py
if self.gradient_penalty:
# gradient penalty
alpha = torch.rand((real_images.size()[0], 1, 1, 1)).to(self.device)
x_hat = alpha * real_images.data + (1 - alpha) * fake_images.data
x_hat.requires_grad_(True)
if self.type == 'can':
pred_hat, _ = self.discriminator(x_hat)
else:
pred_hat = self.discriminator(x_hat)
gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).to(self.device),
create_graph=True, retain_graph=True, only_inputs=True)[0]

gradient_penalty = lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()
disc_loss = disc_loss_fake + disc_loss_real + gradient_penalty
else:
disc_loss = disc_loss_fake + disc_loss_real


if self.type == 'can':
disc_loss += disc_class_loss.mean()

disc_x = disc_loss.mean().item()
disc_loss.backward(retain_graph=True)
self.disc_optimizer.step()



# train generator
for param in self.discriminator.parameters():
param.requires_grad_(False)

self.generator.zero_grad()
labels.fill_(real_label)

if self.type == 'can':
predicted_output_fake, predicted_styles_fake = self.discriminator(fake_images)
predicted_styles_fake = predicted_styles_fake.to(self.device)

else:
predicted_output_fake = self.discriminator(fake_images)

gen_loss = -torch.mean(predicted_output_fake)
disc_gen_z_2 = gen_loss.mean().item()

if self.type == 'can':
fake_batch_labels = 1.0/self.y_dim * torch.ones_like(predicted_styles_fake)
fake_batch_labels = torch.mean(fake_batch_labels,1).long().to(self.device)
gen_class_loss = style_criterion(predicted_styles_fake,fake_batch_labels)
gen_class_loss.backward(retain_graph=True)
gen_loss += gen_class_loss.mean()

gen_loss.backward()
gen_iterations += 1

这是(DCGAN)生成器的代码:
class Can64Generator(nn.Module):
def __init__(self, z_noise, channels, num_gen_filters):
super(Can64Generator,self).__init__()
self.ngpu = 1
self.main = nn.Sequential(
nn.ConvTranspose2d(z_noise, num_gen_filters * 16, 4, 1, 0, bias=False),
nn.BatchNorm2d(num_gen_filters * 16),
nn.ReLU(True),
nn.ConvTranspose2d(num_gen_filters * 16, num_gen_filters * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(num_gen_filters * 4),
nn.ReLU(True),
nn.ConvTranspose2d(num_gen_filters * 4, num_gen_filters * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(num_gen_filters * 2),
nn.ReLU(True),
nn.ConvTranspose2d(num_gen_filters * 2, num_gen_filters, 4, 2, 1, bias=False),
nn.BatchNorm2d(num_gen_filters),
nn.ReLU(True),
nn.ConvTranspose2d(num_gen_filters, 3, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, inp):
output = self.main(inp)
return output

这是(当前的)CAN 鉴别器,它有额外的层
风格(图像类)分类):
class Can64Discriminator(nn.Module):

def __init__(self, channels,y_dim, num_disc_filters):
super(Can64Discriminator, self).__init__()
self.ngpu = 1
self.conv = nn.Sequential(
nn.Conv2d(channels, num_disc_filters // 2, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(num_disc_filters // 2, num_disc_filters, 4, 2, 1, bias=False),
nn.BatchNorm2d(num_disc_filters),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(num_disc_filters, num_disc_filters * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(num_disc_filters * 2),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(num_disc_filters * 2, num_disc_filters * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(num_disc_filters * 4),
nn.LeakyReLU(0.2, inplace=True),

nn.Conv2d(num_disc_filters * 4, num_disc_filters * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(num_disc_filters * 8),
nn.LeakyReLU(0.2, inplace=True),

)
# was this
#self.final_conv = nn.Conv2d(num_disc_filters * 8, num_disc_filters * 8, 4, 2, 1, bias=False)

self.real_fake_head = nn.Linear(num_disc_filters * 8, 1)

# no bn and lrelu needed
self.sig = nn.Sigmoid()
self.fc = nn.Sequential()
self.fc.add_module("linear_layer{0}".format(num_disc_filters*16),nn.Linear(num_disc_filters*8,num_disc_filters*16))
self.fc.add_module("linear_layer{0}".format(num_disc_filters*8),nn.Linear(num_disc_filters*16,num_disc_filters*8))
self.fc.add_module("linear_layer{0}".format(num_disc_filters),nn.Linear(num_disc_filters*8,y_dim))
self.fc.add_module('softmax',nn.Softmax(dim=1))

def forward(self, inp):
x = self.conv(inp)
x = x.view(x.size(0),-1)
real_out = self.sig(self.real_fake_head(x))
real_out = real_out.view(-1,1).squeeze(1)
style = self.fc(x)
#style = torch.mean(style,1) # CrossEntropyLoss requires input be (N,C)
return real_out,style

WANGP 版本和我的 GAN 的 WGAN 版本之间的唯一区别是 WGAN 版本使用 RMSproplr=0.00005并根据 WGAN 论文剪裁鉴别器​​的权重。

什么可能导致这种情况?我想做出尽可能小的改变,因为我想单独比较损失函数。即使在 CIFAR-10 上使用未修改的 DCGAN 鉴别器时也会遇到同样的问题。我遇到这个可能是因为我目前只训练了 25 个时期,还是有其他原因?有趣的是,当使用 LSGAN ( nn.MSELoss() ) 时,我的 GAN 也完全无法产生噪音以外的任何东西。

提前致谢!

最佳答案

鉴别器中的批量归一化通过梯度惩罚打破了 Wasserstein GAN。作者自己提倡使用层归一化,但这在他们的论文 (https://papers.nips.cc/paper/7159-improved-training-of-wasserstein-gans.pdf) 中用粗体清楚地写了。很难说您的代码中是否还有其他错误,但我建议您彻底阅读 DCGAN 和 Wasserstein GAN 论文,并对超参数进行真正的笔记。弄错它们真的会破坏 GAN 的性能,并且进行超参数搜索会很快变得昂贵。

顺便说一下,转置卷积会在您的输出图像中产生阶梯状伪影。改用图像调整大小。对于这种现象的深入解释,我可以推荐以下资源(https://distill.pub/2016/deconv-checkerboard/)。

关于python - 越来越大的正 WGAN-GP 损失,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53479523/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com