- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章pytorch GAN伪造手写体mnist数据集方式由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
一,mnist数据集 。
形如上图的数字手写体就是mnist数据集.
二,GAN原理(生成对抗网络) 。
GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D) 。
一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:
在v1时的G只不过是 一堆噪声,见过数据集(real images)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想.
三,训练代码 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
|
import
argparse
import
os
import
numpy as np
import
math
import
torchvision.transforms as transforms
from
torchvision.utils
import
save_image
from
torch.utils.data
import
DataLoader
from
torchvision
import
datasets
from
torch.autograd
import
Variable
import
torch.nn as nn
import
torch.nn.functional as F
import
torch
os.makedirs(
"images"
, exist_ok
=
True
)
parser
=
argparse.ArgumentParser()
parser.add_argument(
"--n_epochs"
,
type
=
int
, default
=
200
,
help
=
"number of epochs of training"
)
parser.add_argument(
"--batch_size"
,
type
=
int
, default
=
64
,
help
=
"size of the batches"
)
parser.add_argument(
"--lr"
,
type
=
float
, default
=
0.0002
,
help
=
"adam: learning rate"
)
parser.add_argument(
"--b1"
,
type
=
float
, default
=
0.5
,
help
=
"adam: decay of first order momentum of gradient"
)
parser.add_argument(
"--b2"
,
type
=
float
, default
=
0.999
,
help
=
"adam: decay of first order momentum of gradient"
)
parser.add_argument(
"--n_cpu"
,
type
=
int
, default
=
8
,
help
=
"number of cpu threads to use during batch generation"
)
parser.add_argument(
"--latent_dim"
,
type
=
int
, default
=
100
,
help
=
"dimensionality of the latent space"
)
parser.add_argument(
"--img_size"
,
type
=
int
, default
=
28
,
help
=
"size of each image dimension"
)
parser.add_argument(
"--channels"
,
type
=
int
, default
=
1
,
help
=
"number of image channels"
)
parser.add_argument(
"--sample_interval"
,
type
=
int
, default
=
400
,
help
=
"interval betwen image samples"
)
opt
=
parser.parse_args()
print
(opt)
img_shape
=
(opt.channels, opt.img_size, opt.img_size)
# 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda
=
True
if
torch.cuda.is_available()
else
False
class
Generator(nn.Module):
def
__init__(
self
):
super
(Generator,
self
).__init__()
def
block(in_feat, out_feat, normalize
=
True
):
layers
=
[nn.Linear(in_feat, out_feat)]
if
normalize:
layers.append(nn.BatchNorm1d(out_feat,
0.8
))
layers.append(nn.LeakyReLU(
0.2
, inplace
=
True
))
return
layers
self
.model
=
nn.Sequential(
*
block(opt.latent_dim,
128
, normalize
=
False
),
*
block(
128
,
256
),
*
block(
256
,
512
),
*
block(
512
,
1024
),
nn.Linear(
1024
,
int
(np.prod(img_shape))),
nn.Tanh()
)
def
forward(
self
, z):
img
=
self
.model(z)
img
=
img.view(img.size(
0
),
*
img_shape)
return
img
class
Discriminator(nn.Module):
def
__init__(
self
):
super
(Discriminator,
self
).__init__()
self
.model
=
nn.Sequential(
nn.Linear(
int
(np.prod(img_shape)),
512
),
nn.LeakyReLU(
0.2
, inplace
=
True
),
nn.Linear(
512
,
256
),
nn.LeakyReLU(
0.2
, inplace
=
True
),
nn.Linear(
256
,
1
),
nn.Sigmoid(),
)
def
forward(
self
, img):
img_flat
=
img.view(img.size(
0
),
-
1
)
validity
=
self
.model(img_flat)
return
validity
# Loss function
adversarial_loss
=
torch.nn.BCELoss()
# Initialize generator and discriminator
generator
=
Generator()
discriminator
=
Discriminator()
if
cuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
# Configure data loader
os.makedirs(
"../../data/mnist"
, exist_ok
=
True
)
dataloader
=
torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist"
,
train
=
True
,
download
=
True
,
transform
=
transforms.Compose(
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([
0.5
], [
0.5
])]
),
),
batch_size
=
opt.batch_size,
shuffle
=
True
,
)
# Optimizers
optimizer_G
=
torch.optim.Adam(generator.parameters(), lr
=
opt.lr, betas
=
(opt.b1, opt.b2))
optimizer_D
=
torch.optim.Adam(discriminator.parameters(), lr
=
opt.lr, betas
=
(opt.b1, opt.b2))
Tensor
=
torch.cuda.FloatTensor
if
cuda
else
torch.FloatTensor
# ----------
# Training
# ----------
if
__name__
=
=
'__main__'
:
for
epoch
in
range
(opt.n_epochs):
for
i, (imgs, _)
in
enumerate
(dataloader):
# print(imgs.shape)
# Adversarial ground truths
valid
=
Variable(Tensor(imgs.size(
0
),
1
).fill_(
1.0
), requires_grad
=
False
)
# 全1
fake
=
Variable(Tensor(imgs.size(
0
),
1
).fill_(
0.0
), requires_grad
=
False
)
# 全0
# Configure input
real_imgs
=
Variable(imgs.
type
(Tensor))
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# 清空G网络 上一个batch的梯度
# Sample noise as generator input
z
=
Variable(Tensor(np.random.normal(
0
,
1
, (imgs.shape[
0
], opt.latent_dim))))
# 生成的噪音,均值为0方差为1维度为(64,100)的噪音
# Generate a batch of images
gen_imgs
=
generator(z)
# Loss measures generator's ability to fool the discriminator
g_loss
=
adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
# g_loss用于更新G网络的权值,g_loss于D网络的判断结果 有关
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# 清空D网络 上一个batch的梯度
# Measure discriminator's ability to classify real from generated samples
real_loss
=
adversarial_loss(discriminator(real_imgs), valid)
fake_loss
=
adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss
=
(real_loss
+
fake_loss)
/
2
d_loss.backward()
# d_loss用于更新D网络的权值
optimizer_D.step()
print
(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
%
(epoch, opt.n_epochs, i,
len
(dataloader), d_loss.item(), g_loss.item())
)
batches_done
=
epoch
*
len
(dataloader)
+
i
if
batches_done
%
opt.sample_interval
=
=
0
:
save_image(gen_imgs.data[:
25
],
"images/%d.png"
%
batches_done, nrow
=
5
, normalize
=
True
)
# 保存一个batchsize中的25张
if
(epoch
+
1
)
%
2
=
=
0
:
print
(
'save..'
)
torch.save(generator,
'g%d.pth'
%
epoch)
torch.save(discriminator,
'd%d.pth'
%
epoch)
|
运行结果:
一开始时,G生成的全是杂音:
然后逐渐呈现数字的雏形:
最后一次生成的结果:
四,测试代码:
导入最后保存生成器的模型:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
from
gan
import
Generator,Discriminator
import
torch
import
matplotlib.pyplot as plt
from
torch.autograd
import
Variable
import
numpy as np
from
torchvision.utils
import
save_image
device
=
torch.device(
'cuda'
if
torch.cuda.is_available()
else
'cpu'
)
Tensor
=
torch.cuda.FloatTensor
g
=
torch.load(
'g199.pth'
)
#导入生成器Generator模型
#d = torch.load('d.pth')
g
=
g.to(device)
#d = d.to(device)
z
=
Variable(Tensor(np.random.normal(
0
,
1
, (
64
,
100
))))
#输入的噪音
gen_imgs
=
g(z)
#生产图片
save_image(gen_imgs.data[:
25
],
"images.png"
, nrow
=
5
, normalize
=
True
)
|
生成结果:
以上这篇pytorch GAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我.
原文链接:https://blog.csdn.net/u014453898/article/details/95044228 。
最后此篇关于pytorch GAN伪造手写体mnist数据集方式的文章就讲到这里了,如果你想了解更多关于pytorch GAN伪造手写体mnist数据集方式的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我对 FakeItEasy(或其他模拟对象,因为我相信它们非常相似)有疑问。以下是我的 puesdocode: public class Service { public void Check
我想使用鼠标/键盘伪造操作(或触摸)事件。当我尝试使用以下方法引发事件时: RoutedEventArgs e = new RoutedEventArgs(ManipulationStartedEve
出于某种原因,我的 SVN 存储库的本地副本停止将父目录识别为工作副本。我通常会通过再次 checkout 到另一个文件夹并用我更改的文件覆盖新的工作副本来解决此问题。然后我会从新文件夹中进行提交。
我正在尝试设置我的开发实例,以便我可以伪造一些我拥有的网址。我的网站有许多网址,根据您访问的网址,我的网站的行为会因域的不同而有所不同。 我编辑了我的 C:\WINDOWS\system32\driv
当用户登录并选中“记住我”框时,我会为它生成一个 key (md5 上非常随机的数字)并保存在它的 cookie 上。如果用户未登录,我的代码会检查“记住我的 key ”cookie,如果它与用户匹配
有没有办法强制 Oracle 也“看到”一个表和相关索引比它们实际更大? 换句话说,有没有办法“伪造”数据库统计信息,因此基于成本的优化器会在几乎为空的数据库上做出决策,这更接近于在真实的大型生产数据
这是我使用 tsqlt 的第一天,所以你可能会看到一些含糊的陈述。 我正在尝试测试一个具有 Try Catch Block 的存储过程,但测试中的实际语句是插入和更新命令。 现在我想测试如果出现 Er
我从mockito开始,想知道如何假装添加观察者。我想编写一个测试来确保观察者计数在函数调用后增加。 示例测试代码: MyClassUnderTest instance = new MyClassUn
我是一名 C# 游戏开发人员,我有一个安全功能,我的服务器动态创建一个包含一些 key 的 DLL,并将这个 DLL 上传到 amazon s3,然后向人们提出挑战随机的。当客户收到此质询时,他们有
我正在尝试“伪造”一个 Canvas ,目的是将这个伪造的 Canvas 交给一个可能是任意的框架,以对所有直线、曲线和 moveTo 进行后处理。 为了解决这个问题,我尝试了这段代码,它确实有效,但
我的应用程序需要 SQL Server 2000 作为数据库存储。我真的不想使用 SQL Server 2000,但我可以改用 MySQL Server。 应用程序使用 ODBC 连接到 SQL Se
我有一个下拉菜单,需要一个带有左右边距的滚动条。我正在使用-webkit-scrollbar,但据我所知,它只支持沿滚动轴的边距,所以我一直在用容器内元素的右边距来近似水平边距,并在外部 div,如您
作为我学生小组业余项目的一部分,我正在创建微 Controller 有线网络的模拟,以测试我们编写的算法。每个 Controller 都连接到多个数据端口,每个端口都有一个输入和输出流。我通过给每个端
我已经在 Forge 中安装了自定义 SSL 证书。现在我的网站宕机了 -_-。 Site is not available connection refused 我已经重新启动了我的服务器,但没有任
我正在开发具有 ListView 和详细 View 的应用程序,并且我从 ListView 到详细 View 设置动画。在执行此操作时,我想在某个阶段隐藏状态栏(最好同时在后台显示 ListView
我想用它在 MS-Test 单元测试中伪造 System.Net.Mail.SmtpClient。为此,我添加了一个 System.dll 的 Fakes Assembmly。然后我创建一个 Shim
在我的 Playframework 2.4 项目中,我有这样的方法: public static Result resetValue(int client) { String received
这是我渲染场景的过程: 绑定(bind) MSAA x4 GBuffer(4 种颜色附件、位置、法线、颜色和无光照颜色(仅天空盒。我还有一个深度组件/纹理)。 绘制天空盒 绘制地理 将所有颜色和深度分
我不太确定 $_SESSION 在 PHP 中是如何工作的。我假设它是浏览器上的 cookie 与服务器上的唯一 key 匹配。是否可以伪造并绕过仅使用 session 来识别用户的登录。 如果 $_
大家好,我是沙漠尽头的狼。 本文首发于 Dotnet9 ,介绍使用 Lib.Harmony 库拦截第三方 .NET 库方法,达到不修改其源码并能实现修改方法逻辑、预期行为的效果,并
我是一名优秀的程序员,十分优秀!