gpt4 book ai didi

pytorch GAN伪造手写体mnist数据集方式

转载 作者:qq735679552 更新时间:2022-09-29 22:32:09 28 4
gpt4 key购买 nike

CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.

这篇CFSDN的博客文章pytorch GAN伪造手写体mnist数据集方式由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.

一,mnist数据集 。

pytorch GAN伪造手写体mnist数据集方式

形如上图的数字手写体就是mnist数据集.

二,GAN原理(生成对抗网络) 。

GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D) 。

一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:

pytorch GAN伪造手写体mnist数据集方式

在v1时的G只不过是 一堆噪声,见过数据集(real images)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想.

三,训练代码 。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import argparse
import os
import numpy as np
import math
 
import torchvision.transforms as transforms
from torchvision.utils import save_image
 
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
 
import torch.nn as nn
import torch.nn.functional as F
import torch
 
os.makedirs( "images" , exist_ok = True )
 
parser = argparse.ArgumentParser()
parser.add_argument( "--n_epochs" , type = int , default = 200 , help = "number of epochs of training" )
parser.add_argument( "--batch_size" , type = int , default = 64 , help = "size of the batches" )
parser.add_argument( "--lr" , type = float , default = 0.0002 , help = "adam: learning rate" )
parser.add_argument( "--b1" , type = float , default = 0.5 , help = "adam: decay of first order momentum of gradient" )
parser.add_argument( "--b2" , type = float , default = 0.999 , help = "adam: decay of first order momentum of gradient" )
parser.add_argument( "--n_cpu" , type = int , default = 8 , help = "number of cpu threads to use during batch generation" )
parser.add_argument( "--latent_dim" , type = int , default = 100 , help = "dimensionality of the latent space" )
parser.add_argument( "--img_size" , type = int , default = 28 , help = "size of each image dimension" )
parser.add_argument( "--channels" , type = int , default = 1 , help = "number of image channels" )
parser.add_argument( "--sample_interval" , type = int , default = 400 , help = "interval betwen image samples" )
opt = parser.parse_args()
print (opt)
 
img_shape = (opt.channels, opt.img_size, opt.img_size) # 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda = True if torch.cuda.is_available() else False
 
 
class Generator(nn.Module):
  def __init__( self ):
   super (Generator, self ).__init__()
 
   def block(in_feat, out_feat, normalize = True ):
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
     layers.append(nn.BatchNorm1d(out_feat, 0.8 ))
    layers.append(nn.LeakyReLU( 0.2 , inplace = True ))
    return layers
 
   self .model = nn.Sequential(
    * block(opt.latent_dim, 128 , normalize = False ),
    * block( 128 , 256 ),
    * block( 256 , 512 ),
    * block( 512 , 1024 ),
    nn.Linear( 1024 , int (np.prod(img_shape))),
    nn.Tanh()
   )
 
  def forward( self , z):
   img = self .model(z)
   img = img.view(img.size( 0 ), * img_shape)
   return img
 
 
class Discriminator(nn.Module):
  def __init__( self ):
   super (Discriminator, self ).__init__()
 
   self .model = nn.Sequential(
    nn.Linear( int (np.prod(img_shape)), 512 ),
    nn.LeakyReLU( 0.2 , inplace = True ),
    nn.Linear( 512 , 256 ),
    nn.LeakyReLU( 0.2 , inplace = True ),
    nn.Linear( 256 , 1 ),
    nn.Sigmoid(),
   )
 
  def forward( self , img):
   img_flat = img.view(img.size( 0 ), - 1 )
   validity = self .model(img_flat)
   return validity
 
 
# Loss function
adversarial_loss = torch.nn.BCELoss()
 
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
 
if cuda:
  generator.cuda()
  discriminator.cuda()
  adversarial_loss.cuda()
 
# Configure data loader
os.makedirs( "../../data/mnist" , exist_ok = True )
dataloader = torch.utils.data.DataLoader(
  datasets.MNIST(
   "../../data/mnist" ,
   train = True ,
   download = True ,
   transform = transforms.Compose(
    [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([ 0.5 ], [ 0.5 ])]
   ),
  ),
  batch_size = opt.batch_size,
  shuffle = True ,
)
 
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2))
 
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
# ----------
# Training
# ----------
if __name__ = = '__main__' :
  for epoch in range (opt.n_epochs):
   for i, (imgs, _) in enumerate (dataloader):
    # print(imgs.shape)
    # Adversarial ground truths
    valid = Variable(Tensor(imgs.size( 0 ), 1 ).fill_( 1.0 ), requires_grad = False ) # 全1
    fake = Variable(Tensor(imgs.size( 0 ), 1 ).fill_( 0.0 ), requires_grad = False ) # 全0
    # Configure input
    real_imgs = Variable(imgs. type (Tensor))
 
    # -----------------
    # Train Generator
    # -----------------
 
    optimizer_G.zero_grad() # 清空G网络 上一个batch的梯度
 
    # Sample noise as generator input
    z = Variable(Tensor(np.random.normal( 0 , 1 , (imgs.shape[ 0 ], opt.latent_dim)))) # 生成的噪音,均值为0方差为1维度为(64,100)的噪音
    # Generate a batch of images
    gen_imgs = generator(z)
    # Loss measures generator's ability to fool the discriminator
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
 
    g_loss.backward() # g_loss用于更新G网络的权值,g_loss于D网络的判断结果 有关
    optimizer_G.step()
 
    # ---------------------
    # Train Discriminator
    # ---------------------
 
    optimizer_D.zero_grad() # 清空D网络 上一个batch的梯度
    # Measure discriminator's ability to classify real from generated samples
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2
 
    d_loss.backward() # d_loss用于更新D网络的权值
    optimizer_D.step()
 
    print (
     "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
     % (epoch, opt.n_epochs, i, len (dataloader), d_loss.item(), g_loss.item())
    )
 
    batches_done = epoch * len (dataloader) + i
    if batches_done % opt.sample_interval = = 0 :
     save_image(gen_imgs.data[: 25 ], "images/%d.png" % batches_done, nrow = 5 , normalize = True ) # 保存一个batchsize中的25张
    if (epoch + 1 ) % 2 = = 0 :
     print ( 'save..' )
     torch.save(generator, 'g%d.pth' % epoch)
     torch.save(discriminator, 'd%d.pth' % epoch)

运行结果:

一开始时,G生成的全是杂音:

pytorch GAN伪造手写体mnist数据集方式

然后逐渐呈现数字的雏形:

pytorch GAN伪造手写体mnist数据集方式

最后一次生成的结果:

pytorch GAN伪造手写体mnist数据集方式

四,测试代码:

导入最后保存生成器的模型:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from gan import Generator,Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image
 
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
Tensor = torch.cuda.FloatTensor
g = torch.load( 'g199.pth' ) #导入生成器Generator模型
#d = torch.load('d.pth')
g = g.to(device)
#d = d.to(device)
 
z = Variable(Tensor(np.random.normal( 0 , 1 , ( 64 , 100 )))) #输入的噪音
gen_imgs = g(z) #生产图片
save_image(gen_imgs.data[: 25 ], "images.png" , nrow = 5 , normalize = True )

生成结果:

pytorch GAN伪造手写体mnist数据集方式

以上这篇pytorch GAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我.

原文链接:https://blog.csdn.net/u014453898/article/details/95044228 。

最后此篇关于pytorch GAN伪造手写体mnist数据集方式的文章就讲到这里了,如果你想了解更多关于pytorch GAN伪造手写体mnist数据集方式的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com