作者热门文章
- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章Pytorch实现WGAN用于动漫头像生成由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
|
import
torch
import
torch.nn as nn
import
torchvision.transforms as transforms
from
torch.utils.data
import
DataLoader
from
torchvision.utils
import
save_image
import
os
from
anime_face_generator.dataset
import
ImageDataset
batch_size
=
32
num_epoch
=
100
z_dimension
=
100
dir_path
=
'./wgan_img'
# 创建文件夹
if
not
os.path.exists(dir_path):
os.mkdir(dir_path)
def
to_img(x):
"""因为我们在生成器里面用了tanh"""
out
=
0.5
*
(x
+
1
)
return
out
dataset
=
ImageDataset()
dataloader
=
DataLoader(dataset, batch_size
=
32
, shuffle
=
False
)
class
Generator(nn.Module):
def
__init__(
self
):
super
().__init__()
self
.gen
=
nn.Sequential(
# 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
nn.ConvTranspose2d(
100
,
512
,
4
,
1
,
0
, bias
=
False
),
nn.BatchNorm2d(
512
),
nn.ReLU(
True
),
# 上一步的输出形状:(512) x 4 x 4
nn.ConvTranspose2d(
512
,
256
,
4
,
2
,
1
, bias
=
False
),
nn.BatchNorm2d(
256
),
nn.ReLU(
True
),
# 上一步的输出形状: (256) x 8 x 8
nn.ConvTranspose2d(
256
,
128
,
4
,
2
,
1
, bias
=
False
),
nn.BatchNorm2d(
128
),
nn.ReLU(
True
),
# 上一步的输出形状: (256) x 16 x 16
nn.ConvTranspose2d(
128
,
64
,
4
,
2
,
1
, bias
=
False
),
nn.BatchNorm2d(
64
),
nn.ReLU(
True
),
# 上一步的输出形状:(256) x 32 x 32
nn.ConvTranspose2d(
64
,
3
,
5
,
3
,
1
, bias
=
False
),
nn.Tanh()
# 输出范围 -1~1 故而采用Tanh
# nn.Sigmoid()
# 输出形状:3 x 96 x 96
)
def
forward(
self
, x):
x
=
self
.gen(x)
return
x
def
weight_init(m):
# weight_initialization: important for wgan
class_name
=
m.__class__.__name__
if
class_name.find(
'Conv'
) !
=
-
1
:
m.weight.data.normal_(
0
,
0.02
)
elif
class_name.find(
'Norm'
) !
=
-
1
:
m.weight.data.normal_(
1.0
,
0.02
)
class
Discriminator(nn.Module):
def
__init__(
self
):
super
().__init__()
self
.dis
=
nn.Sequential(
nn.Conv2d(
3
,
64
,
5
,
3
,
1
, bias
=
False
),
nn.LeakyReLU(
0.2
, inplace
=
True
),
# 输出 (64) x 32 x 32
nn.Conv2d(
64
,
128
,
4
,
2
,
1
, bias
=
False
),
nn.BatchNorm2d(
128
),
nn.LeakyReLU(
0.2
, inplace
=
True
),
# 输出 (128) x 16 x 16
nn.Conv2d(
128
,
256
,
4
,
2
,
1
, bias
=
False
),
nn.BatchNorm2d(
256
),
nn.LeakyReLU(
0.2
, inplace
=
True
),
# 输出 (256) x 8 x 8
nn.Conv2d(
256
,
512
,
4
,
2
,
1
, bias
=
False
),
nn.BatchNorm2d(
512
),
nn.LeakyReLU(
0.2
, inplace
=
True
),
# 输出 (512) x 4 x 4
nn.Conv2d(
512
,
1
,
4
,
1
,
0
, bias
=
False
),
nn.Flatten(),
# nn.Sigmoid() # 输出一个数(概率)
)
def
forward(
self
, x):
x
=
self
.dis(x)
return
x
def
weight_init(m):
# weight_initialization: important for wgan
class_name
=
m.__class__.__name__
if
class_name.find(
'Conv'
) !
=
-
1
:
m.weight.data.normal_(
0
,
0.02
)
elif
class_name.find(
'Norm'
) !
=
-
1
:
m.weight.data.normal_(
1.0
,
0.02
)
def
save(model, filename
=
"model.pt"
, out_dir
=
"out/"
):
if
model
is
not
None
:
if
not
os.path.exists(out_dir):
os.mkdir(out_dir)
torch.save({
'model'
: model.state_dict()}, out_dir
+
filename)
else
:
print
(
"[ERROR]:Please build a model!!!"
)
import
QuickModelBuilder as builder
if
__name__
=
=
'__main__'
:
one
=
torch.FloatTensor([
1
]).cuda()
mone
=
-
1
*
one
is_print
=
True
# 创建对象
D
=
Discriminator()
G
=
Generator()
D.weight_init()
G.weight_init()
if
torch.cuda.is_available():
D
=
D.cuda()
G
=
G.cuda()
lr
=
2e
-
4
d_optimizer
=
torch.optim.RMSprop(D.parameters(), lr
=
lr, )
g_optimizer
=
torch.optim.RMSprop(G.parameters(), lr
=
lr, )
d_scheduler
=
torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma
=
0.99
)
g_scheduler
=
torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma
=
0.99
)
fake_img
=
None
# ##########################进入训练##判别器的判断过程#####################
for
epoch
in
range
(num_epoch):
# 进行多个epoch的训练
pbar
=
builder.MyTqdm(epoch
=
epoch, maxval
=
len
(dataloader))
for
i, img
in
enumerate
(dataloader):
num_img
=
img.size(
0
)
real_img
=
img.cuda()
# 将tensor变成Variable放入计算图中
# 这里的优化器是D的优化器
for
param
in
D.parameters():
param.requires_grad
=
True
# ########判别器训练train#####################
# 分为两部分:1、真的图像判别为真;2、假的图像判别为假
# 计算真实图片的损失
d_optimizer.zero_grad()
# 在反向传播之前,先将梯度归0
real_out
=
D(real_img)
# 将真实图片放入判别器中
d_loss_real
=
real_out.mean(
0
).view(
1
)
d_loss_real.backward(one)
# 计算生成图片的损失
z
=
torch.randn(num_img, z_dimension).cuda()
# 随机生成一些噪声
z
=
z.reshape(num_img, z_dimension,
1
,
1
)
fake_img
=
G(z).detach()
# 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
fake_out
=
D(fake_img)
# 判别器判断假的图片,
d_loss_fake
=
fake_out.mean(
0
).view(
1
)
d_loss_fake.backward(mone)
d_loss
=
d_loss_fake
-
d_loss_real
d_optimizer.step()
# 更新参数
# 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
for
parm
in
D.parameters():
parm.data.clamp_(
-
0.01
,
0.01
)
# ==================训练生成器============================
# ###############################生成网络的训练###############################
for
param
in
D.parameters():
param.requires_grad
=
False
# 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
g_optimizer.zero_grad()
# 梯度归0
z
=
torch.randn(num_img, z_dimension).cuda()
z
=
z.reshape(num_img, z_dimension,
1
,
1
)
fake_img
=
G(z)
# 随机噪声输入到生成器中,得到一副假的图片
output
=
D(fake_img)
# 经过判别器得到的结果
# g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
g_loss
=
torch.mean(output).view(
1
)
# bp and optimize
g_loss.backward(one)
# 进行反向传播
g_optimizer.step()
# .step()一般用在反向传播后面,用于更新生成网络的参数
# 打印中间的损失
pbar.set_right_info(d_loss
=
d_loss.data.item(),
g_loss
=
g_loss.data.item(),
real_scores
=
real_out.data.mean().item(),
fake_scores
=
fake_out.data.mean().item(),
)
pbar.update()
try
:
fake_images
=
to_img(fake_img.cpu())
save_image(fake_images, dir_path
+
'/fake_images-{}.png'
.
format
(epoch
+
1
))
except
:
pass
if
is_print:
is_print
=
False
real_images
=
to_img(real_img.cpu())
save_image(real_images, dir_path
+
'/real_images.png'
)
pbar.finish()
d_scheduler.step()
g_scheduler.step()
save(D,
"wgan_D.pt"
)
save(G,
"wgan_G.pt"
)
|
到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索我以前的文章或继续浏览下面的相关文章希望大家以后多多支持我! 。
原文链接:https://blog.csdn.net/bu_fo/article/details/109808354 。
最后此篇关于Pytorch实现WGAN用于动漫头像生成的文章就讲到这里了,如果你想了解更多关于Pytorch实现WGAN用于动漫头像生成的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我正在研究 WGAN,并希望实现 WGAN-GP。 在其原始论文中,由于 1-Lipschitiz 约束,WGAN-GP 是通过梯度惩罚来实现的。但是像 Keras 这样的包可以将梯度范数限制为 1(
这是WGAN-GP的损失函数 gen_sample = model.generator(input_gen) disc_real = model.discriminator(real_image, r
我正在研究在 PyTorch 中使用带有梯度惩罚的 Wasserstein GAN,但始终得到大的、正的生成器损失,并且随着时间的推移而增加。 我从 Caogang's implementation
我是一名优秀的程序员,十分优秀!