- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章pytorch:实现简单的GAN示例(MNIST数据集)由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
我就废话不多说了,直接上代码吧! 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
|
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
import
torch
from
torch
import
nn
from
torch.autograd
import
Variable
import
torchvision.transforms as tfs
from
torch.utils.data
import
DataLoader, sampler
from
torchvision.datasets
import
MNIST
import
numpy as np
import
matplotlib.pyplot as plt
import
matplotlib.gridspec as gridspec
plt.rcParams[
'figure.figsize'
]
=
(
10.0
,
8.0
)
# 设置画图的尺寸
plt.rcParams[
'image.interpolation'
]
=
'nearest'
plt.rcParams[
'image.cmap'
]
=
'gray'
def
show_images(images):
# 定义画图工具
images
=
np.reshape(images, [images.shape[
0
],
-
1
])
sqrtn
=
int
(np.ceil(np.sqrt(images.shape[
0
])))
sqrtimg
=
int
(np.ceil(np.sqrt(images.shape[
1
])))
fig
=
plt.figure(figsize
=
(sqrtn, sqrtn))
gs
=
gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace
=
0.05
, hspace
=
0.05
)
for
i, img
in
enumerate
(images):
ax
=
plt.subplot(gs[i])
plt.axis(
'off'
)
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect(
'equal'
)
plt.imshow(img.reshape([sqrtimg,sqrtimg]))
return
def
preprocess_img(x):
x
=
tfs.ToTensor()(x)
return
(x
-
0.5
)
/
0.5
def
deprocess_img(x):
return
(x
+
1.0
)
/
2.0
class
ChunkSampler(sampler.Sampler):
# 定义一个取样的函数
"""Samples elements sequentially from some offset.
Arguments:
num_samples: # of desired datapoints
start: offset where we should start selecting from
"""
def
__init__(
self
, num_samples, start
=
0
):
self
.num_samples
=
num_samples
self
.start
=
start
def
__iter__(
self
):
return
iter
(
range
(
self
.start,
self
.start
+
self
.num_samples))
def
__len__(
self
):
return
self
.num_samples
NUM_TRAIN
=
50000
NUM_VAL
=
5000
NOISE_DIM
=
96
batch_size
=
128
train_set
=
MNIST(
'E:/data'
, train
=
True
, transform
=
preprocess_img)
train_data
=
DataLoader(train_set, batch_size
=
batch_size, sampler
=
ChunkSampler(NUM_TRAIN,
0
))
val_set
=
MNIST(
'E:/data'
, train
=
True
, transform
=
preprocess_img)
val_data
=
DataLoader(val_set, batch_size
=
batch_size, sampler
=
ChunkSampler(NUM_VAL, NUM_TRAIN))
imgs
=
deprocess_img(train_data.__iter__().
next
()[
0
].view(batch_size,
784
)).numpy().squeeze()
# 可视化图片效果
show_images(imgs)
#判别网络
def
discriminator():
net
=
nn.Sequential(
nn.Linear(
784
,
256
),
nn.LeakyReLU(
0.2
),
nn.Linear(
256
,
256
),
nn.LeakyReLU(
0.2
),
nn.Linear(
256
,
1
)
)
return
net
#生成网络
def
generator(noise_dim
=
NOISE_DIM):
net
=
nn.Sequential(
nn.Linear(noise_dim,
1024
),
nn.ReLU(
True
),
nn.Linear(
1024
,
1024
),
nn.ReLU(
True
),
nn.Linear(
1024
,
784
),
nn.Tanh()
)
return
net
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
bce_loss
=
nn.BCEWithLogitsLoss()
#交叉熵损失函数
def
discriminator_loss(logits_real, logits_fake):
# 判别器的 loss
size
=
logits_real.shape[
0
]
true_labels
=
Variable(torch.ones(size,
1
)).
float
()
false_labels
=
Variable(torch.zeros(size,
1
)).
float
()
loss
=
bce_loss(logits_real, true_labels)
+
bce_loss(logits_fake, false_labels)
return
loss
def
generator_loss(logits_fake):
# 生成器的 loss
size
=
logits_fake.shape[
0
]
true_labels
=
Variable(torch.ones(size,
1
)).
float
()
loss
=
bce_loss(logits_fake, true_labels)
return
loss
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def
get_optimizer(net):
optimizer
=
torch.optim.Adam(net.parameters(), lr
=
3e
-
4
, betas
=
(
0.5
,
0.999
))
return
optimizer
def
train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every
=
250
,
noise_size
=
96
, num_epochs
=
10
):
iter_count
=
0
for
epoch
in
range
(num_epochs):
for
x, _
in
train_data:
bs
=
x.shape[
0
]
# 判别网络
real_data
=
Variable(x).view(bs,
-
1
)
# 真实数据
logits_real
=
D_net(real_data)
# 判别网络得分
sample_noise
=
(torch.rand(bs, noise_size)
-
0.5
)
/
0.5
# -1 ~ 1 的均匀分布
g_fake_seed
=
Variable(sample_noise)
fake_images
=
G_net(g_fake_seed)
# 生成的假的数据
logits_fake
=
D_net(fake_images)
# 判别网络得分
d_total_error
=
discriminator_loss(logits_real, logits_fake)
# 判别器的 loss
D_optimizer.zero_grad()
d_total_error.backward()
D_optimizer.step()
# 优化判别网络
# 生成网络
g_fake_seed
=
Variable(sample_noise)
fake_images
=
G_net(g_fake_seed)
# 生成的假的数据
gen_logits_fake
=
D_net(fake_images)
g_error
=
generator_loss(gen_logits_fake)
# 生成网络的 loss
G_optimizer.zero_grad()
g_error.backward()
G_optimizer.step()
# 优化生成网络
if
(iter_count
%
show_every
=
=
0
):
print
(
'Iter: {}, D: {:.4}, G:{:.4}'
.
format
(iter_count, d_total_error.item(), g_error.item()))
imgs_numpy
=
deprocess_img(fake_images.data.cpu().numpy())
show_images(imgs_numpy[
0
:
16
])
plt.show()
print
()
iter_count
+
=
1
D
=
discriminator()
G
=
generator()
D_optim
=
get_optimizer(D)
G_optim
=
get_optimizer(G)
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
|
以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我.
原文链接:https://blog.csdn.net/xckkcxxck/article/details/83037025 。
最后此篇关于pytorch:实现简单的GAN示例(MNIST数据集)的文章就讲到这里了,如果你想了解更多关于pytorch:实现简单的GAN示例(MNIST数据集)的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
这个问题在这里已经有了答案: 关闭 11 年前。 Possible Duplicate: Sample data for IPv6? 除了 wireshark 在其网站上提供的内容之外,是否有可以下
我正在寻找可以集成到现有应用程序中并使用多拖放功能的示例或任何现成的解决方案。我在互联网上找到的大多数解决方案在将多个项目从 ListBox 等控件拖放到另一个 ListBox 时效果不佳。谁能指出我
我是 GATE Embedded 的新手,我尝试了简单的示例并得到了 NoClassDefFoundError。首先我会解释我尝试了什么 在 D:\project\gate-7.0 中下载并提取 Ga
是否有像 Eclipse 中的 SWT 示例那样的多合一 JFace 控件示例?搜索(在 stackoverflow.com 上使用谷歌搜索和搜索)对我没有帮助。 如果它是一个独立的应用程序或 ecl
我找不到任何可以清楚地解释如何通过 .net API(特别是 c#)使用谷歌计算引擎的内容。有没有人可以指点我什么? 附言我知道 API 引用 ( https://developers.google.
最近在做公司的一个项目时,客户需要我们定时获取他们矩阵系统的数据。在与客户进行对接时,提到他们的接口使用的目前不常用的BASIC 认证。天呢,它好不安全,容易被不法人监听,咋还在使用呀。但是没办法呀,
最近在做公司的一个项目时,客户需要我们定时获取他们矩阵系统的数据。在与客户进行对接时,提到他们的接口使用的目前不常用的BASIC 认证。天呢,它好不安全,容易被不法人监听,咋还在使用呀。但是没办法呀,
我正在尝试为我的应用程序设计配置文件格式并选择了 YAML。但是,这(显然)意味着我需要能够定义、解析和验证正确的 YAML 语法! 在配置文件中,必须有一个名为 widgets 的集合/序列。 .这
你能给我一个使用 pysmb 库连接到一些 samba 服务器的例子吗?我读过有类 smb.SMBConnection.SMBConnection(用户名、密码、my_name、remote_name
linux服务器默认通过22端口用ssh协议登录,这种不安全。今天想做限制,即允许部分来源ip连接服务器。 案例目标:通过iptables规则限制对linux服务器的登录。 处理方法:编
我一直在寻找任何 PostProjectAnalysisTask 工作代码示例,但没有看。 This页面指出 HipChat plugin使用这个钩子(Hook),但在我看来它仍然使用遗留的 Po
我发现了 GWT 的 CustomScrollPanel 以及如何自定义滚动条,但我找不到任何示例或如何设置它。是否有任何示例显示正在使用的自定义滚动条? 最佳答案 这是自定义 native 滚动条的
我正在尝试开发一个 Backbone Marionette 应用程序,我需要知道如何以最佳方式执行 CRUD(创建、读取、更新和销毁)操作。我找不到任何解释这一点的资源(仅适用于 Backbone)。
关闭。这个问题需要details or clarity .它目前不接受答案。 想改进这个问题?通过 editing this post 添加详细信息并澄清问题. 去年关闭。 Improve this
我需要一个提交多个单独请求的 django 表单,如果没有大量定制,我找不到如何做到这一点的示例。即,假设有一个汽车维修店使用的表格。该表格将列出商店能够进行的所有可能的维修,并且用户将选择他们想要进
我有一个 Multi-Tenancy 应用程序。然而,这个相同的应用程序有 liquibase。我需要在我的所有数据源中运行 liquibase,但是我不能使用这个 Bean。 我的应用程序.yml
我了解有关单元测试的一般思想,并已在系统中发生复杂交互的场景中使用它,但我仍然对所有这些原则结合在一起有疑问。 我们被警告不要测试框架或数据库。好的 UI 设计不适合非人工测试。 MVC 框架不包括一
我正在使用 docjure并且它的 select-columns 函数需要一个列映射。我想获取所有列而无需手动指定。 如何将以下内容生成为惰性无限向量序列 [:A :B :C :D :E ... :A
$condition使用说明和 $param在 findByAttributes在 Yii 在大多数情况下,这就是我使用 findByAttributes 的方式 Person::model()->f
我在 Ubuntu 11.10 上安装了 qtcreator sudo apt-get install qtcreator 安装的版本有:QT Creator 2.2.1、QT 4.7.3 当我启动
我是一名优秀的程序员,十分优秀!