gpt4 book ai didi

Pytorch实现WGAN用于动漫头像生成

转载 作者:qq735679552 更新时间:2022-09-29 22:32:09 26 4
gpt4 key购买 nike

CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.

这篇CFSDN的博客文章Pytorch实现WGAN用于动漫头像生成由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.

WGAN与GAN的不同

  • 去除sigmoid
  • 使用具有动量的优化方法,比如使用RMSProp
  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像 

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset
 
batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'
 
# 创建文件夹
if not os.path.exists(dir_path):
   os.mkdir(dir_path)
 
 
def to_img(x):
   """因为我们在生成器里面用了tanh"""
   out = 0.5 * (x + 1 )
   return out
 
 
dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size = 32 , shuffle = False )
 
 
class Generator(nn.Module):
   def __init__( self ):
     super ().__init__()
 
     self .gen = nn.Sequential(
       # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
       nn.ConvTranspose2d( 100 , 512 , 4 , 1 , 0 , bias = False ),
       nn.BatchNorm2d( 512 ),
       nn.ReLU( True ),
       # 上一步的输出形状:(512) x 4 x 4
       nn.ConvTranspose2d( 512 , 256 , 4 , 2 , 1 , bias = False ),
       nn.BatchNorm2d( 256 ),
       nn.ReLU( True ),
       # 上一步的输出形状: (256) x 8 x 8
       nn.ConvTranspose2d( 256 , 128 , 4 , 2 , 1 , bias = False ),
       nn.BatchNorm2d( 128 ),
       nn.ReLU( True ),
       # 上一步的输出形状: (256) x 16 x 16
       nn.ConvTranspose2d( 128 , 64 , 4 , 2 , 1 , bias = False ),
       nn.BatchNorm2d( 64 ),
       nn.ReLU( True ),
       # 上一步的输出形状:(256) x 32 x 32
       nn.ConvTranspose2d( 64 , 3 , 5 , 3 , 1 , bias = False ),
       nn.Tanh() # 输出范围 -1~1 故而采用Tanh
       # nn.Sigmoid()
       # 输出形状:3 x 96 x 96
     )
 
   def forward( self , x):
     x = self .gen(x)
     return x
 
   def weight_init(m):
     # weight_initialization: important for wgan
     class_name = m.__class__.__name__
     if class_name.find( 'Conv' ) ! = - 1 :
       m.weight.data.normal_( 0 , 0.02 )
     elif class_name.find( 'Norm' ) ! = - 1 :
       m.weight.data.normal_( 1.0 , 0.02 )
 
 
class Discriminator(nn.Module):
   def __init__( self ):
     super ().__init__()
     self .dis = nn.Sequential(
       nn.Conv2d( 3 , 64 , 5 , 3 , 1 , bias = False ),
       nn.LeakyReLU( 0.2 , inplace = True ),
       # 输出 (64) x 32 x 32
 
       nn.Conv2d( 64 , 128 , 4 , 2 , 1 , bias = False ),
       nn.BatchNorm2d( 128 ),
       nn.LeakyReLU( 0.2 , inplace = True ),
       # 输出 (128) x 16 x 16
 
       nn.Conv2d( 128 , 256 , 4 , 2 , 1 , bias = False ),
       nn.BatchNorm2d( 256 ),
       nn.LeakyReLU( 0.2 , inplace = True ),
       # 输出 (256) x 8 x 8
 
       nn.Conv2d( 256 , 512 , 4 , 2 , 1 , bias = False ),
       nn.BatchNorm2d( 512 ),
       nn.LeakyReLU( 0.2 , inplace = True ),
       # 输出 (512) x 4 x 4
 
       nn.Conv2d( 512 , 1 , 4 , 1 , 0 , bias = False ),
       nn.Flatten(),
       # nn.Sigmoid() # 输出一个数(概率)
     )
 
   def forward( self , x):
     x = self .dis(x)
     return x
 
   def weight_init(m):
     # weight_initialization: important for wgan
     class_name = m.__class__.__name__
     if class_name.find( 'Conv' ) ! = - 1 :
       m.weight.data.normal_( 0 , 0.02 )
     elif class_name.find( 'Norm' ) ! = - 1 :
       m.weight.data.normal_( 1.0 , 0.02 )
 
 
def save(model, filename = "model.pt" , out_dir = "out/" ):
   if model is not None :
     if not os.path.exists(out_dir):
       os.mkdir(out_dir)
     torch.save({ 'model' : model.state_dict()}, out_dir + filename)
   else :
     print ( "[ERROR]:Please build a model!!!" )
 
 
import QuickModelBuilder as builder
 
if __name__ = = '__main__' :
   one = torch.FloatTensor([ 1 ]).cuda()
   mone = - 1 * one
 
   is_print = True
   # 创建对象
   D = Discriminator()
   G = Generator()
   D.weight_init()
   G.weight_init()
 
   if torch.cuda.is_available():
     D = D.cuda()
     G = G.cuda()
 
   lr = 2e - 4
   d_optimizer = torch.optim.RMSprop(D.parameters(), lr = lr, )
   g_optimizer = torch.optim.RMSprop(G.parameters(), lr = lr, )
   d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma = 0.99 )
   g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma = 0.99 )
 
   fake_img = None
 
   # ##########################进入训练##判别器的判断过程#####################
   for epoch in range (num_epoch): # 进行多个epoch的训练
     pbar = builder.MyTqdm(epoch = epoch, maxval = len (dataloader))
     for i, img in enumerate (dataloader):
       num_img = img.size( 0 )
       real_img = img.cuda() # 将tensor变成Variable放入计算图中
       # 这里的优化器是D的优化器
       for param in D.parameters():
         param.requires_grad = True
       # ########判别器训练train#####################
       # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
 
       # 计算真实图片的损失
       d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
       real_out = D(real_img) # 将真实图片放入判别器中
       d_loss_real = real_out.mean( 0 ).view( 1 )
       d_loss_real.backward(one)
 
       # 计算生成图片的损失
       z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声
       z = z.reshape(num_img, z_dimension, 1 , 1 )
       fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
       fake_out = D(fake_img) # 判别器判断假的图片,
       d_loss_fake = fake_out.mean( 0 ).view( 1 )
       d_loss_fake.backward(mone)
 
       d_loss = d_loss_fake - d_loss_real
       d_optimizer.step() # 更新参数
 
       # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
       for parm in D.parameters():
         parm.data.clamp_( - 0.01 , 0.01 )
 
       # ==================训练生成器============================
       # ###############################生成网络的训练###############################
       for param in D.parameters():
         param.requires_grad = False
 
       # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
       g_optimizer.zero_grad() # 梯度归0
 
       z = torch.randn(num_img, z_dimension).cuda()
       z = z.reshape(num_img, z_dimension, 1 , 1 )
       fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
       output = D(fake_img) # 经过判别器得到的结果
       # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
       g_loss = torch.mean(output).view( 1 )
       # bp and optimize
       g_loss.backward(one) # 进行反向传播
       g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数
 
       # 打印中间的损失
       pbar.set_right_info(d_loss = d_loss.data.item(),
                 g_loss = g_loss.data.item(),
                 real_scores = real_out.data.mean().item(),
                 fake_scores = fake_out.data.mean().item(),
                 )
       pbar.update()
       try :
         fake_images = to_img(fake_img.cpu())
         save_image(fake_images, dir_path + '/fake_images-{}.png' . format (epoch + 1 ))
       except :
         pass
       if is_print:
         is_print = False
         real_images = to_img(real_img.cpu())
         save_image(real_images, dir_path + '/real_images.png' )
     pbar.finish()
     d_scheduler.step()
     g_scheduler.step()
     save(D, "wgan_D.pt" )
     save(G, "wgan_G.pt" )

到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索我以前的文章或继续浏览下面的相关文章希望大家以后多多支持我! 。

原文链接:https://blog.csdn.net/bu_fo/article/details/109808354 。

最后此篇关于Pytorch实现WGAN用于动漫头像生成的文章就讲到这里了,如果你想了解更多关于Pytorch实现WGAN用于动漫头像生成的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com