gpt4 book ai didi

Python实战小项目之Mnist手写数字识别

转载 作者:qq735679552 更新时间:2022-09-28 22:32:09 25 4
gpt4 key购买 nike

CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.

这篇CFSDN的博客文章Python实战小项目之Mnist手写数字识别由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.

  。

程序流程分析图:

Python实战小项目之Mnist手写数字识别

  。

传播过程:

Python实战小项目之Mnist手写数字识别

Python实战小项目之Mnist手写数字识别

  。

代码展示:

  。

创建环境

使用<pip install+包名>来下载torch,torchvision包 。

  。

准备数据集

设置一次训练所选取的样本数Batch_Sized的值为512,训练此时Epochs的值为8 。

BATCH_SIZE = 512EPOCHS = 8device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  。

下载数据集

Normalize()数字归一化,转换使用的值0.1307和0.3081是MNIST数据集的全局平均值和标准偏差,这里我们将它们作为给定值。model 。

train_loader = torch.utils.data.DataLoader(    datasets.MNIST("data", train=True, download=True,                   transform=transforms.Compose([.                       transforms.ToTensor(),                       transforms.Normalize((0.1307,), (0.3081,))                   ])),    batch_size=BATCH_SIZE, shuffle=True)

  。

下载测试集

test_loader = torch.utils.data.DataLoader(    datasets.MNIST("data", train=False,                   transform=transforms.Compose([                       transforms.ToTensor(),                       transforms.Normalize((0.1307,), (0.3081,))                   ])),    batch_size=BATCH_SIZE, shuffle=True)

  。

绘制图像

我们可以使用matplotlib来绘制其中的一些图像 。

examples = enumerate(test_loader)batch_idx, (example_data, example_targets) = next(examples)print(example_targets)print(example_data.shape)print(example_data) import matplotlib.pyplot as pltfig = plt.figure()for i in range(6):  plt.subplot(2,3,i+1)  plt.tight_layout()  plt.imshow(example_data[i][0], cmap="gray", interpolation="none")  plt.title("Ground Truth: {}".format(example_targets[i]))  plt.xticks([])  plt.yticks([])plt.show()

Python实战小项目之Mnist手写数字识别

  。

搭建神经网络

这里我们构建全连接神经网络,我们使用三个全连接(或线性)层进行前向传播.

class linearNet(nn.Module):    def __init__(self):        super().__init__()        self.fc1 = nn.Linear(784, 128)        self.fc2 = nn.Linear(128, 64)        self.fc3 = nn.Linear(64, 10)    def forward(self, x):        x = x.view(-1, 784)        x = self.fc1(x)        x = F.relu(x)        x = self.fc2(x)        x = F.relu(x)        x = self.fc3(x)        x = F.log_softmax(x, dim=1)        return x

  。

训练模型

首先,我们需要使用optimizer.zero_grad()手动将梯度设置为零,因为PyTorch在默认情况下会累积梯度。然后,我们生成网络的输出(前向传递),并计算输出与真值标签之间的负对数概率损失。现在,我们收集一组新的梯度,并使用optimizer.step()将其传播回每个网络参数.

def train(model, device, train_loader, optimizer, epoch):    model.train()    for batch_idx, (data, target) in enumerate(train_loader):         data, target = data.to(device), target.to(device)        optimizer.zero_grad()        output = model(data)        loss = F.nll_loss(output, target)        loss.backward()        optimizer.step()        if (batch_idx) % 30 == 0:            print("Train Epoch: {} [{}/{} ({:.0f}%)]	Loss: {:.6f}".format(                epoch, batch_idx * len(data), len(train_loader.dataset),                       100. * batch_idx / len(train_loader), loss.item()))

  。

测试模型

def test(model, device, test_loader):    model.eval()    test_loss = 0    correct = 0    with torch.no_grad():        for data, target in test_loader:            data, target = data.to(device), target.to(device)            output = model(data)            test_loss += F.nll_loss(output, target, reduction="sum").item() # 将一批的损失相加            pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标            correct += pred.eq(target.view_as(pred)).sum().item()     test_loss /= len(test_loader.dataset)    print("Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(        test_loss, correct, len(test_loader.dataset),        100. * correct / len(test_loader.dataset)))

将训练次数进行循环 。

if __name__ == "__main__":    model = linearNet()    optimizer = optim.Adam(model.parameters())     for epoch in range(1, EPOCHS + 1):        train(model, device, train_loader, optimizer, epoch)        test(model, device, test_loader)

  。

保存训练模型

torch.save(model, "MNIST.pth")

  。

运行结果展示:

Python实战小项目之Mnist手写数字识别

Python实战小项目之Mnist手写数字识别

Python实战小项目之Mnist手写数字识别

分享人:苏云云 。

到此这篇关于Python实战小项目之Mnist手写数字识别的文章就介绍到这了,更多相关Python Mnist手写数字识别内容请搜索我以前的文章或继续浏览下面的相关文章希望大家以后多多支持我! 。

原文链接:https://blog.csdn.net/weixin_40604528/article/details/120848106 。

最后此篇关于Python实战小项目之Mnist手写数字识别的文章就讲到这里了,如果你想了解更多关于Python实战小项目之Mnist手写数字识别的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com