大蟒蛇python教程共享Python实战小项目之Mnist手写数字识别

目录

    程序流程分析图:

    Python实战小项目之Mnist手写数字识别

    传播过程:

    Python实战小项目之Mnist手写数字识别

    Python实战小项目之Mnist手写数字识别

    代码展示:

    创建环境

    使用<pip install+包名>来下载torch,torchvision包

    准备数据集

    设置一次训练所选取的样本数batch_sized的值为512,训练此时epochs的值为8

      batch_size = 512  epochs = 8  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    下载数据集

    normalize()数字归一化,转换使用的值0.1307和0.3081是mnist数据集的全局平均值和标准偏差,这里我们将它们作为给定值。model

      train_loader = torch.utils.data.dataloader(      datasets.mnist('data', train=true, download=true,                     transform=transforms.compose([.                         transforms.totensor(),                         transforms.normalize((0.1307,), (0.3081,))                     ])),      batch_size=batch_size, shuffle=true)

    下载测试集

      test_loader = torch.utils.data.dataloader(      datasets.mnist('data', train=false,                     transform=transforms.compose([                         transforms.totensor(),                         transforms.normalize((0.1307,), (0.3081,))                     ])),      batch_size=batch_size, shuffle=true)

    绘制图像

    我们可以使用matplotlib来绘制其中的一些图像

      examples = enumerate(test_loader)  batch_idx, (example_data, example_targets) = next(examples)  print(example_targets)  print(example_data.shape)  print(example_data)     import matplotlib.pyplot as plt  fig = plt.figure()  for i in range(6):    plt.subplot(2,3,i+1)    plt.tight_layout()    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')    plt.title("ground truth: {}".format(example_targets[i]))    plt.xticks([])    plt.yticks([])  plt.show()

    Python实战小项目之Mnist手写数字识别

    搭建神经网络

    这里我们构建全连接神经网络,我们使用三个全连接(或线性)层进行前向传播。

      class linearnet(nn.module):      def __init__(self):          super().__init__()          self.fc1 = nn.linear(784, 128)          self.fc2 = nn.linear(128, 64)          self.fc3 = nn.linear(64, 10)      def forward(self, x):          x = x.view(-1, 784)          x = self.fc1(x)          x = f.relu(x)          x = self.fc2(x)          x = f.relu(x)          x = self.fc3(x)          x = f.log_softmax(x, dim=1)          return x

    训练模型

    首先,我们需要使用optimizer.zero_grad()手动将梯度设置为零,因为pytorch在默认情况下会累积梯度。然后,我们生成网络的输出(前向传递),并计算输出与真值标签之间的负对数概率损失。现在,我们收集一组新的梯度,并使用optimizer.step()将其传播回每个网络参数。

      def train(model, device, train_loader, optimizer, epoch):      model.train()      for batch_idx, (data, target) in enumerate(train_loader):             data, target = data.to(device), target.to(device)          optimizer.zero_grad()          output = model(data)          loss = f.nll_loss(output, target)          loss.backward()          optimizer.step()          if (batch_idx) % 30 == 0:              print('train epoch: {} [{}/{} ({:.0f}%)]tloss: {:.6f}'.format(                  epoch, batch_idx * len(data), len(train_loader.dataset),                         100. * batch_idx / len(train_loader), loss.item()))

    测试模型

      def test(model, device, test_loader):      model.eval()      test_loss = 0      correct = 0      with torch.no_grad():          for data, target in test_loader:              data, target = data.to(device), target.to(device)              output = model(data)              test_loss += f.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加              pred = output.max(1, keepdim=true)[1] # 找到概率最大的下标              correct += pred.eq(target.view_as(pred)).sum().item()         test_loss /= len(test_loader.dataset)      print('ntest set: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)n'.format(          test_loss, correct, len(test_loader.dataset),          100. * correct / len(test_loader.dataset)))

    将训练次数进行循环

      if __name__ == '__main__':      model = linearnet()      optimizer = optim.adam(model.parameters())         for epoch in range(1, epochs + 1):          train(model, device, train_loader, optimizer, epoch)          test(model, device, test_loader)

    保存训练模型

      torch.save(model, 'mnist.pth')

    运行结果展示:

    Python实战小项目之Mnist手写数字识别

    Python实战小项目之Mnist手写数字识别

    Python实战小项目之Mnist手写数字识别

    分享人:苏云云

    到此这篇关于python实战小项目之mnist手写数字识别的文章就介绍到这了,更多相关python mnist手写数字识别内容请搜索<计算机技术网(www.ctvol.com)!!>以前的文章或继续浏览下面的相关文章希望大家以后多多支持<计算机技术网(www.ctvol.com)!!>!

    需要了解更多python教程分享Python实战小项目之Mnist手写数字识别,都可以关注python教程分享栏目—计算机技术网(www.ctvol.com)!

    本文来自网络收集,不代表计算机技术网立场,如涉及侵权请联系管理员删除。

    ctvol管理联系方式QQ:251552304

    本文章地址:https://www.ctvol.com/pythontutorial/886801.html

    (0)
    上一篇 2021年10月22日
    下一篇 2021年10月22日

    精彩推荐