{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# LeNet5\n", "\n", "LeNet dates from 1994 and is one of the earliest convolutional neural networks; it helped drive the development of deep learning. The work began in 1988, and after several iterations this pioneering result was named LeNet5. The architecture rests on one observation: image features are distributed across the whole image, so a convolution with learnable parameters can extract similar features at many locations while using far fewer parameters than a fully connected layer.\n", "\n", "When LeNet5 was proposed there were no GPUs to help with training, and even CPUs were slow, so the network is small. The original design has seven processing layers, each with trainable parameters (weights), and took $32 \\times 32$ pixel images as input. Small as it is, LeNet5 contains the basic building blocks of deep learning: convolutional layers, pooling layers, and fully connected layers. Many later deep learning models build on it. This notebook walks through LeNet5 with a worked example to deepen the understanding of convolutional and pooling layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The network is defined as follows:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "\n", "class LeNet5(nn.Module):\n", "    def __init__(self):\n", "        super(LeNet5, self).__init__()\n", "        # 1 input channel, 6 output channels, 5x5 convolution\n", "        self.conv1 = nn.Conv2d(1, 6, 5)\n", "        # 6 input channels, 16 output channels, 5x5 convolution\n", "        self.conv2 = nn.Conv2d(6, 16, 5)\n", "        # 16x5x5 input features, 120 output features\n", "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n", "        # 120 input features, 84 output features\n", "        self.fc2 = nn.Linear(120, 84)\n", "        # 84 input features, 10 output classes\n", "        self.fc3 = nn.Linear(84, 10)\n", "\n", "    def forward(self, x):\n", "        # 1x32x32 -> conv 5x5 -> 6x28x28 -> 2x2 max pool -> 6x14x14\n", "        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n", "        # 6x14x14 -> conv 5x5 -> 16x10x10 -> 2x2 max pool -> 16x5x5\n", "        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))\n", "        x = torch.flatten(x, 1)  # flatten all dimensions except the batch dimension\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LeNet5(\n", "  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n", "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n", "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n", "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n", "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ "net = 
LeNet5()\n", "print(net)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.0124, 0.1326, 0.1647, 0.0728, 0.0722, 0.0113, 0.0829, -0.0055,\n", " 0.1749, -0.0581]], grad_fn=)\n" ] } ], "source": [ "# A batch containing one random single-channel 32x32 image\n", "x = torch.randn(1, 1, 32, 32)\n", "out = net(x)\n", "print(out)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from torchvision.datasets import mnist\n", "from torch.utils.data import DataLoader\n", "from torchvision import transforms as tfs\n", "from utils import train\n", "\n", "# Preprocessing: resize the 28x28 MNIST images to 32x32 and convert them to tensors\n", "# (no augmentation is applied; normalization is left commented out)\n", "def data_tf(x):\n", "    im_aug = tfs.Compose([\n", "        tfs.Resize(32),\n", "        tfs.ToTensor() #,\n", "        #tfs.Normalize([0.5], [0.5])\n", "    ])\n", "    x = im_aug(x)\n", "    return x\n", "\n", "train_set = mnist.MNIST('../../data/mnist', train=True, transform=data_tf, download=True)\n", "train_data = DataLoader(train_set, batch_size=64, shuffle=True)\n", "test_set = mnist.MNIST('../../data/mnist', train=False, transform=data_tf, download=True)\n", "test_data = DataLoader(test_set, batch_size=128, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAATEAAAEICAYAAAA3EMMNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWHUlEQVR4nO3df6xU5Z3H8fdH1DZF02qsSBFFDavFpl6VYlONxVgrbWyQthrJxmDWiH9AVhNjVk2T0uzimi3irqs1xUr9EayS+Iu4TdW1WtdtigKlyo+yorIWuYGiUtBWDfDdP+bc7tw7M8/MvTP3znkun1cyuXPP95w5T6feD895zjnPUURgZparg7rdADOzdjjEzCxrDjEzy5pDzMyy5hAzs6w5xMwsaw4xM8uaQ8zqkvS8pA8lvV+8NnW7TWb1OMQsZX5EHFa8Tu52Y8zqcYiZWdYcYpbyz5J2SvpvSdO73RizeuR7J60eSWcBG4CPgcuAO4CeiHi9qw0zG8AhZi2R9AvgPyLi37vdFrNqPpy0VgWgbjfCbCCHmNWQ9BlJF0r6pKSDJf0tcC7wVLfbZjbQwd1ugJXSIcA/AacA+4DfAxdHhK8Vs9LxmJiZZc2Hk2aWNYeYmWXNIWZmWXOImVnWRvTspCSfRTAbZhHR1vV8M2bMiJ07d7a07urVq5+KiBnt7K9tETHkFzAD2ARsBm5oYf3wyy+/hvfVzt90RHDmmWdGq4BVTf7mJwLPARuB9cA1xfIFwNvA2uL1zaptbqSSKZuAC5u1d8g9MUljgDuBC4CtwMuSVkTEhqF+ppmVQwcvvdoLXBcRayQdDqyW9ExRuy0iFlWvLGkKlXt1TwU+B/ynpL+JiH2NdtDOmNg0YHNEvBERHwMPATPb+DwzK4n9+/e39GomInojYk3xfg+VHtmExCYzgYci4qOIeJNKj2xaah/thNgE4A9Vv2+t1zhJcyWtkrSqjX2Z2QgZ5JBSyyRNAk4HVhaL5kt6RdJSSUcUy1rKlWrthFi9wcOa/1URsSQipkbE1Db2ZWYjaBAhdlRfJ6V4za33eZIOAx4Bro2I3cBdwElAD9AL3Nq3ar3mpNraztnJrVQG7focC2xr4/PMrCQG0cva2ayDIukQKgG2LCIeLT5/e1X9buDJ4tdB50o7PbGXgcmSTpB0KJXBuBVtfJ6ZlUSnDiclCbgH2BgRi6uWj69abRawrni/ArhM0icknQBMBl5K7WPIPbGI2CtpPpXpWcYASyNi/VA/z8zKo4NnJ88GLgdelbS2WHYTMFtSD5VDxS3A1cV+10taTmVW4b3AvNSZSRjhWSx8savZ8Is2L3Y944wz4sUXX2xp3bFjx67u9ni35xMzsxoj2blpl0PMzGo4xMwsaw4xM8vWUC5k7SaHmJnVaOWWorJwiJlZDffEzCxbPpw0s+w5xMwsaw4xM8uaQ8zMshURPjtpZnlzT8zMsuYQM7OsOcTMLGsOMTPLlgf2zSx77omZWdYcYmaWNYeYmWXLN4CbWfYcYmaWNZ+dNLOsuSdmZtnymJiZZc8hZmZZc4iZWdYcYmaWLd87aWbZc0/MSmPMmDHJ+qc//elh3f/8+fMb1j71qU8ltz355JOT9Xnz5iXrixYtalibPXt2ctsPP/wwWb/llluS9R/84AfJetkdMCEmaQuwB9gH7I2IqZ1olJl11wETYoXzImJnBz7HzEriQAsxMxtFchvYP6jN7QN4WtJqSXPrrSBprqRVkla1uS8zGyF9V+03e5VBuyF2dkScAXwDmCfp3IErRMSSiJjq8TKzfHQqxCRNlPScpI2S1ku6plh+pKRnJL1W/DyiapsbJW2WtEnShc320VaIRcS24ucO4DFgWjufZ2bl0MGe2F7guoj4PPBlKp2dKcANwLMRMRl4tvidonYZcCowA/iRpOQp9iGHmKSxkg7vew98HVg31M8zs3JoNcBaCbGI6I2INcX7PcBGYAIwE7ivWO0+4OLi/UzgoYj4KCLeBDb
TpHPUzsD+OOAxSX2f82BE/KKNzxu1jjvuuGT90EMPTda/8pWvJOvnnHNOw9pnPvOZ5Lbf+c53kvVu2rp1a7J+++23J+uzZs1qWNuzZ09y29/97nfJ+q9+9atkPXeDGO86asB495KIWFJvRUmTgNOBlcC4iOgt9tUr6ehitQnAb6o221osa2jIIRYRbwCnDXV7MyuvQZyd3NnKeLekw4BHgGsjYnfR+am7ap1lyURtd2DfzEahTp6dlHQIlQBbFhGPFou3Sxpf1McDO4rlW4GJVZsfC2xLfb5DzMz66eSYmCpdrnuAjRGxuKq0AphTvJ8DPFG1/DJJn5B0AjAZeCm1D1/samY1OngN2NnA5cCrktYWy24CbgGWS7oSeAu4pNjveknLgQ1UzmzOi4h9qR04xMysRqdCLCJepP44F8D5DbZZCCxsdR8OMTOrUZar8VvhEOuAnp6eZP2Xv/xlsj7c0+GUVbMzYN/73veS9ffffz9ZX7ZsWcNab29vctv33nsvWd+0aVOynrPc7p10iJlZDffEzCxrDjEzy5pDzMyy5hAzs2x5YN/MsueemJllzSF2gHnrrbeS9XfeeSdZL/N1YitXrkzWd+3alayfd955DWsff/xxctsHHnggWbfh4xAzs2yVaf78VjjEzKyGQ8zMsuazk2aWNffEzCxbHhMzs+w5xMwsaw6xA8y7776brF9//fXJ+kUXXZSs//a3v03Wmz26LGXt2rXJ+gUXXJCsf/DBB8n6qaee2rB2zTXXJLe17nGImVm2fO+kmWXPPTEzy5pDzMyy5hAzs6w5xMwsWx7YN7PsuSdm/Tz++OPJerPnUu7ZsydZP+200xrWrrzyyuS2ixYtStabXQfWzPr16xvW5s6d29Zn2/DJKcQOaraCpKWSdkhaV7XsSEnPSHqt+HnE8DbTzEZS3/2TzV5l0DTEgHuBGQOW3QA8GxGTgWeL381sFGg1wLIJsYh4ARh4X81M4L7i/X3AxZ1tlpl1U04hNtQxsXER0QsQEb2Sjm60oqS5gAc/zDLis5NVImIJsARAUjmi28waKlMvqxWtjInVs13SeIDi547ONcnMui2nw8mhhtgKYE7xfg7wRGeaY2ZlkFOINT2clPQzYDpwlKStwPeBW4Dlkq4E3gIuGc5Gjna7d+9ua/s//elPQ972qquuStYffvjhZD2nsRNrXVkCqhVNQywiZjcond/htphZCXTytiNJS4GLgB0R8YVi2QLgKuCPxWo3RcTPi9qNwJXAPuDvI+KpZvsY6uGkmY1iHTycvJfa60wBbouInuLVF2BTgMuAU4ttfiRpTLMdOMTMrEanQqzBdaaNzAQeioiPIuJNYDMwrdlGDjEzqzGIEDtK0qqqV6vXhM6X9EpxW2PfbYsTgD9UrbO1WJbkG8DNrMYgBvZ3RsTUQX78XcA/AlH8vBX4O0D1mtLswxxiZtbPcF8+ERHb+95Luht4svh1KzCxatVjgW3NPs8hNgosWLCgYe3MM89MbvvVr341Wf/a176WrD/99NPJuuVpOC+dkTS+77ZFYBbQN0POCuBBSYuBzwGTgZeafZ5DzMxqdKon1uA60+mSeqgcKm4Bri72uV7ScmADsBeYFxH7mu3DIWZmNToVYg2uM70nsf5CYOFg9uEQM7N+ynRLUSscYmZWwyFmZllziJlZ1nK6sd8hZmb9eEzMRlzqsWrNptpZs2ZNsn733Xcn688991yyvmrVqoa1O++8M7ltTn9Io01O371DzMxqOMTMLGsOMTPLVicnRRwJDjEzq+GemJllzSFmZllziJlZ1hxiVhqvv/56sn7FFVck6z/96U+T9csvv3zI9bFjxya3vf/++5P13t7eZN2Gxhe7mln2fHbSzLLmnpiZZc0hZmbZ8piYmWXPIWZmWXOImVnWfHbSsvHYY48l66+99lqyvnjx4mT9/PPPb1i7+eabk9sef/zxyfrChemH4rz99tvJutWX25jYQc1WkLRU0g5J66qWLZD
0tqS1xeubw9tMMxtJfUHW7FUGTUMMuBeYUWf5bRHRU7x+3tlmmVk35RRiTQ8nI+IFSZNGoC1mVhJlCahWtNITa2S+pFeKw80jGq0kaa6kVZIaT7ZuZqXRNyliK68yGGqI3QWcBPQAvcCtjVaMiCURMTUipg5xX2Y2wkbV4WQ9EbG9772ku4EnO9YiM+u6sgRUK4bUE5M0vurXWcC6RuuaWX5GVU9M0s+A6cBRkrYC3wemS+oBAtgCXD18TbRuWrcu/e/TpZdemqx/61vfalhrNlfZ1Ven/7OaPHlysn7BBRck69ZYWQKqFa2cnZxdZ/E9w9AWMyuBMvWyWuEr9s2sRlnOPLbCIWZmNXLqibVznZiZjVKdGthvcNvikZKekfRa8fOIqtqNkjZL2iTpwlba6hAzs35aDbAWe2v3Unvb4g3AsxExGXi2+B1JU4DLgFOLbX4kaUyzHTjEzKxGp0IsIl4A3h2weCZwX/H+PuDiquUPRcRHEfEmsBmY1mwfHhOztuzatStZf+CBBxrWfvKTnyS3Pfjg9H+e5557brI+ffr0hrXnn38+ue2BbpjHxMZFRG+xn15JRxfLJwC/qVpva7EsySFmZjUGcXbyqAH3RS+JiCVD3K3qLGuapg4xM+tnkNeJ7RzCfdHbJY0vemHjgR3F8q3AxKr1jgW2Nfswj4mZWY1hvu1oBTCneD8HeKJq+WWSPiHpBGAy8FKzD3NPzMxqdGpMrMFti7cAyyVdCbwFXFLsc72k5cAGYC8wLyL2NduHQ8zManQqxBrctghQ9+ELEbEQSD88YQCHmJn10zcpYi4cYmZWI6fbjhxilvTFL34xWf/ud7+brH/pS19qWGt2HVgzGzZsSNZfeOGFtj7/QOYQM7OsOcTMLGsOMTPLlidFNLPs+eykmWXNPTEzy5pDzMyy5TExK5WTTz45WZ8/f36y/u1vfztZP+aYYwbdplbt25e+ba63tzdZz2lcp2wcYmaWtZz+AXCImVk/Ppw0s+w5xMwsaw4xM8uaQ8zMsuYQM7NsjbpJESVNBO4HjgH2U3kk079JOhJ4GJgEbAEujYj3hq+pB65m12LNnt1oBuDm14FNmjRpKE3qiFWrViXrCxemZylesWJFJ5tjVXLqibXytKO9wHUR8Xngy8C84nHjdR9Fbmb5G+anHXVU0xCLiN6IWFO83wNspPJU3kaPIjezzOUUYoMaE5M0CTgdWEnjR5GbWcbKFFCtaDnEJB0GPAJcGxG7pXpPHK+73Vxg7tCaZ2bdMOpCTNIhVAJsWUQ8Wixu9CjyfiJiCbCk+Jx8vhmzA1hOZyebjomp0uW6B9gYEYurSo0eRW5mmRttY2JnA5cDr0paWyy7iQaPIrda48aNS9anTJmSrN9xxx3J+imnnDLoNnXKypUrk/Uf/vCHDWtPPJH+dy+n3sBoUqaAakXTEIuIF4FGA2B1H0VuZnkbVSFmZgceh5iZZS2nQ3mHmJn1M+rGxMzswOMQM7OsOcTMLGsOsVHoyCOPbFj78Y9/nNy2p6cnWT/xxBOH0qSO+PWvf52s33rrrcn6U089laz/5S9/GXSbrPscYmaWrU5PiihpC7AH2AfsjYipnZyPsJX5xMzsADMMtx2dFxE9ETG1+L1j8xE6xMysxgjcO9mx+QgdYmZWYxAhdpSkVVWvetNuBfC0pNVV9X7zEQJDno/QY2Jm1s8ge1k7qw4RGzk7IrYVE6c+I+n37bWwP/fEzKxGJw8nI2Jb8XMH8BgwjWI+QoDUfIStcIiZWY39+/e39GpG0lhJh/e9B74OrKOD8xEeMIeTZ511VrJ+/fXXJ+vTpk1rWJswYcKQ2tQpf/7znxvWbr/99uS2N998c7L+wQcfDKlNlrcOXic2DnismM7+YODBiPiFpJfp0HyEB0yImVlrOnkDeES8AZxWZ/k7dGg+QoeYmdXwFftmljWHmJllzZMimlm2PCmimWXPIWZmWXOIldCsWbPaqrdjw4YNyfqTTz6ZrO/
duzdZT835tWvXruS2ZvU4xMwsaw4xM8tWpydFHG4OMTOr4Z6YmWXNIWZmWXOImVm2fLGrmWUvpxBTs8ZKmgjcDxwD7AeWRMS/SVoAXAX8sVj1poj4eZPPyuebMctURKid7Q899ND47Gc/29K627ZtW93C9NTDqpWe2F7guohYU8zQuFrSM0XttohYNHzNM7NuyKkn1jTEiieR9D2VZI+kjUB3pzI1s2GT25jYoObYlzQJOB1YWSyaL+kVSUslHdFgm7l9j3Nqr6lmNlJG4LmTHdNyiEk6DHgEuDYidgN3AScBPVR6anVv4IuIJRExtdvHzWbWupxCrKWzk5IOoRJgyyLiUYCI2F5VvxtI38VsZtnI6bajpj0xVR5Tcg+wMSIWVy0fX7XaLCqPYTKzzLXaC8upJ3Y2cDnwqqS1xbKbgNmSeqg8onwLcPUwtM/MuqAsAdWKVs5OvgjUu+4keU2YmeVrVIWYmR14HGJmljWHmJlly5Mimln23BMzs6w5xMwsaw4xM8tWmS5kbYVDzMxqOMTMLGs+O2lmWXNPzMyylduY2KAmRTSzA0MnZ7GQNEPSJkmbJd3Q6bY6xMysRqdCTNIY4E7gG8AUKrPfTOlkW304aWY1OjiwPw3YHBFvAEh6CJgJbOjUDkY6xHYC/1v1+1HFsjIqa9vK2i5w24aqk207vgOf8RSVNrXikwOen7EkIpZU/T4B+EPV71uBs9psXz8jGmIR0e9hdpJWlXXu/bK2raztArdtqMrWtoiY0cGPqzcXYUfPGnhMzMyG01ZgYtXvxwLbOrkDh5iZDaeXgcmSTpB0KHAZsKKTO+j2wP6S5qt0TVnbVtZ2gds2VGVuW1siYq+k+VTG2cYASyNifSf3oZwuajMzG8iHk2aWNYeYmWWtKyE23LchtEPSFkmvSlo74PqXbrRlqaQdktZVLTtS0jOSXit+HlGiti2Q9Hbx3a2V9M0utW2ipOckbZS0XtI1xfKufneJdpXie8vViI+JFbch/A9wAZXTry8DsyOiY1fwtkPSFmBqRHT9wkhJ5wLvA/dHxBeKZf8CvBsRtxT/ABwREf9QkrYtAN6PiEUj3Z4BbRsPjI+INZIOB1YDFwNX0MXvLtGuSynB95arbvTE/nobQkR8DPTdhmADRMQLwLsDFs8E7ive30flj2DENWhbKUREb0SsKd7vATZSuXK8q99dol3Whm6EWL3bEMr0f2QAT0taLWlutxtTx7iI6IXKHwVwdJfbM9B8Sa8Uh5tdOdStJmkScDqwkhJ9dwPaBSX73nLSjRAb9tsQ2nR2RJxB5a77ecVhk7XmLuAkoAfoBW7tZmMkHQY8AlwbEbu72ZZqddpVqu8tN90IsWG/DaEdEbGt+LkDeIzK4W+ZbC/GVvrGWHZ0uT1/FRHbI2JfROwH7qaL352kQ6gExbKIeLRY3PXvrl67yvS95agbITbstyEMlaSxxYArksYCXwfWpbcacSuAOcX7OcATXWxLP30BUZhFl747SQLuATZGxOKqUle/u0btKsv3lquuXLFfnEL+V/7/NoSFI96IOiSdSKX3BZVbsh7sZtsk/QyYTmValO3A94HHgeXAccBbwCURMeID7A3aNp3KIVEAW4Cr+8agRrht5wD/BbwK9E2MdROV8aeufXeJds2mBN9brnzbkZllzVfsm1nWHGJmljWHmJllzSFmZllziJlZ1hxiZpY1h5iZZe3/APpBMI71CUMRAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 显示其中一个数据\n", "import matplotlib.pyplot as plt\n", "plt.imshow(train_set.data[0], cmap='gray')\n", "plt.title('%i' % train_set.targets[0])\n", "plt.colorbar()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([64, 1, 32, 32])\n", "torch.Size([64])\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS4AAAEICAYAAADhtRloAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZkElEQVR4nO3df5BdZZ3n8fcnTSLBiAHyqxMiiRgDWQoSDAFKFzO4cTpsWcDWZgQWyaRwM2wZC635Q3SrZty1LGBnddZxEarVILosrC5RI5WdjIWyaM2YTQItJESwZSA0iYltEgNJTNLhu3/c03rTfZ9zb3ffvvee7s+r6lb3fb7nnPvkdOfbz3nO8zxHEYGZWZFMaHYFzMyGyonLzArHicvMCseJy8wKx4nLzArHicvMCseJy8wKx4nLKpL0PyTtlXRY0ouSPtrsOpn1kwegWiWS/gXQHRHHJV0EPAn864jY3tyambnFZQkRsTMijve/zV4XNrFKZn/gxGVJkr4i6SjwC2AvsKnJVTIDfKloVUhqA64GlgP3RsTJ5tbIzC0uqyIiTkXET4Hzgf/Q7PqYgROX1e4M3MdlLcKJywaRNEPSTZKmSGqT9KfAzcCPml03M3Afl1UgaTrwv4HLKP1xewX4u4j4alMrZpZx4jKzwvGlopkVjhOXmRWOE5eZFY4Tl5kVzhmN/DBJvhNgNsoiQiPZv6OjI3p7e2vadvv27ZsjomMknzccI0pckjqALwFtwNci4p661MrMmqa3t5dt27bVtK2kaaNcnYqGfamYzWG7D1gJLAJulrSoXhUzs+aJiJpezTKSFtcySus1vQQg6VHgeuD5elTMzJrnzTffbHYVco0kcc0BXi173wNcOXAjSWuBtSP4HDNroGa3pmoxksRVqQNw0L82IjqBTnDnvFlRjOXE1QPMLXt/PrBnZNUxs1bQ6olrJOO4tgILJM2XNAm4CdhYn2qZWTON2c75iOiTtA7YTGk4xPqI2Fm3mplZ07R6i2tE47giYhNeh9xsTImIMX1X0czGqDHd4jKzscmJy8wKx4nLzAql2XcMa+HEZWaDuHPezArHLS4zKxRfKppZITlxmVnhOHGZWeE4cZlZoXjKj5kVkltcZlY4TlxmVjhOXGZWOE5cZlYo7pw3s0Jyi8tawhlnpH/U06dPT8Zmz56djC1alH7+79SpUyuWT5gwvMcc5LUADh8+nIx1dXVVLN+xY0dyn1OnTtVcr7HKicvMCseJy8wKxZOszayQnLjMrHB8V9HMCsctLjMrFPdxWa63v/3tydicOXOSsZkzZyZjZ511VsXyM888M7nPvHnzkrFLL700Gbv22muTsVmzZlUszxuWkSdviMK+ffuSsR/84AcVy7/85S8n9+nu7k7Gjh8/noyNJWM6cUl6GXgdOAX0RcTSelTKzJprTCeuzJ9ERG8djmNmLWI8JC4zG0OKMFdxePMv/iiAf5C0XdLaShtIWitpm6RtI/wsM2uQ/g76aq9mGWmL
670RsUfSDOCHkn4REU+VbxARnUAngKTWbn+aGdD6l4ojanFFxJ7s637gu8CyelTKzJqrni0uSR2SXpDULemuCvG3S/qBpJ9L2ilpTbVjDrvFJemtwISIeD37/oPAfx7u8caqSZMmJWNXX311MrZq1apk7JprrknGUkMl8oYTTJ48ORnLq39fX18yduTIkSHvk/dZecM52tvbk7HVq1dXLD9x4kRyn3vuuScZ27NnTzI2ltSrxSWpDbgPWAH0AFslbYyI58s2+xjwfER8SNJ04AVJD0dE8oc0kkvFmcB3JfUf539GxN+P4Hhm1gLq3Dm/DOiOiJcAJD0KXA+UJ64A3qZSMpkCHADSf+EYQeLKKnLZcPc3s9Y1hBbXtAE33jqzfu1+c4BXy973AFcOOMZ/BzYCe4C3AR+OiNzM6eEQZjbIEBJXb5WB56p0+AHv/xToAq4FLqR0o+8nEZFcIXKkwyHMbAyqY+d8DzC37P35lFpW5dYAG6KkG/hn4KK8gzpxmdlpak1aNSaurcACSfMlTQJuonRZWG438AEASTOBhcBLeQf1paKZDVKvu4oR0SdpHbAZaAPWR8ROSXdk8QeAzwHfkPQcpUvLT1WbRujENcpSqyQAfPzjH0/GVqxYkYy1tbUlYwcPHqxYvnv37uQ+F1xwQTKWtxpC3qoMzzzzTMXy3/72t8Oqx3ve855kbMaMGclYaojFbbfdltyns7MzGRsvwyHqOeUnIjYBmwaUPVD2/R5Kw6lq5sRlZoO0+sh5Jy4zO02z5yHWwonLzAZx4jKzwnHiMrPCceIa5/IeOf+Wt7wlGctbmz115xDSa6x//vOfT+6Tt/Z9Nhe1oqNHjyZjqTuOx44dS+6zcuXKZGz69OnJWN5dxdTdsbw7onkTwceDIiwk6MRlZoO4xWVmhePEZWaF48RlZoXjxGVmheLOeTMrJLe4xrne3vQk97vvvjsZe/DBB5Oxw4eT66vxq1/9qmL5K6+8ktwnb8hGnry/yqnJzbNnz07uM2/evGQsb138vP9kv//97yuW553fvIng44UTl5kVjhOXmRWKJ1mbWSE5cZlZ4fiuopkVjltcZlYo7uOy3BUUtmzZkoxNnDgxGctbvSB1+//kyZPJfYarvb09Gevo6KhY/sEPppcWv+yy9POF84ZR5A05+d73vlex/Dvf+U5yn0OHDiVj40WrJ66qA3gkrZe0X9KOsrJzJf1Q0i+zr+eMbjXNrJHq+HiyUVHLyMNvAAP/fN4FPBERC4AnsvdmNkYUPnFFxFPAgQHF1wMPZd8/BNxQ32qZWbP0z1Ws5dUsw+3jmhkRewEiYq+k5BKUktYCa4f5OWbWBK3exzXqnfMR0Ql0Akhq7bNhZkDrJ67hza6FfZLaAbKv++tXJTNrtlbv4xpui2sjsBq4J/v6/brVaIzJ6wd444036v55bW1tFcunTZuW3Ofd7353MrZgwYJk7PLLL0/GrrrqqorlCxcuTO6T9/CQnp6eZOzHP/5xMnb//fdXLH/55ZeT+5w6dSoZGy9avcVVNXFJegRYDkyT1AP8NaWE9W1JtwO7gVWjWUkza5wxsZBgRNycCH2gznUxsxZR+BaXmY0/TlxmVjhOXGZWKM2+Y1gLJy4zG8SJy5IkJWN5Ky+84x3vSMZmzZpVsXz+/PnJfa6++upkLG84RN4xp0yZUrH8xIkTyX22bduWjG3cuDEZe+qpp5Kxrq6uZMzSCn9X0czGlyJcKg535LyZjWH1HDkvqUPSC5K6JVVcSUbSckldknZK+r/VjukWl5kNUq8Wl6Q24D5gBdADbJW0MSKeL9tmKvAVoCMiduct2tDPLS4zG6SOLa5lQHdEvBQRJ4BHKS2LVe4WYENE7M4+u+rcZycuMzvNENfjmiZpW9lr4BJWc4BXy973ZGXl3g2cI+lJSdsl3Vatjr5UNLNBhnCp2BsRS3PilW6dDzz4GcB7KE0jnAz8k6SfRcSLqYM6cY2ySZMmJWPz5s1Lxlas
WJGMXXPNNcnYBRdcULF8zpyBf+Rqi9XbsWPHkrG84RAPP/xwMrZ3794R1ckGq+NdxR5gbtn784E9FbbpjYgjwBFJTwGXAcnE5UtFMxukjn1cW4EFkuZLmgTcRGlZrHLfB/6lpDMknQVcCezKO6hbXGY2SL1aXBHRJ2kdsBloA9ZHxE5Jd2TxByJil6S/B54F3gS+FhE70kd14jKzAeo9ADUiNgGbBpQ9MOD93wB/U+sxnbjMbBBP+TGzwmn1KT9OXKNs6tSpyditt96ajN1yyy3J2Dvf+c4h1yNvQncjf0nz7rKm7ogCLFmyJBmbMCF9jylvrXpLc+Iys0IpwiRrJy4zG8SJy8wKx4nLzArHdxXNrFDcx2VmheTENc6dddZZydgVV1yRjM2dOzcZy1u3PfX4+Lym/8mTJ5OxPBMnThxybPLkycl9brjhhmQs71w98sgjydinP/3piuWp82QlrZ64qk6ylrRe0n5JO8rKPivptWyp1S5J141uNc2skeq5dPNoqGV1iG8AHRXK/zYiFmevTRXiZlZAQ1xIsCmqXipGxFOS5jWgLmbWIgp/qZhjnaRns0vJc1IbSVrbv6zrCD7LzBpoLFwqVnI/cCGwGNgLfCG1YUR0RsTSKsu7mlkLafXENay7ihGxr/97SV8FHq9bjcys6Vr9UnFYiUtSe0T0L/R9I5C7WuF49vrrrydjmzdvTsaOHj2ajPX19SVjzz33XMXyrVu3JvfZsmVLMpbnwx/+cDK2bNmyiuV56+W/613vSsZmzEg/au/9739/MnbxxRdXLN+1K70y8HgfKtHs1lQtqiYuSY8Ayyk9hqgH+GtguaTFlJ7W8TLwF6NXRTNrtMJP+YmImysUf30U6mJmLaLwLS4zG3+cuMysUMZEH5eZjT9OXGZWOK2euNTICkpq7bMxCtra2pKxc889Nxk7++yzk7G82/WpYRRHjhxJ7pMXy3POOckJE8n6X3TRRcl91qxZk4ytWrUqGdu/f38y9uCDD1Ys/9znPpfc59ixY8lYEURE+skoNWhvb4/Vq1fXtO299967vRmDy93iMrPTuI/LzArJicvMCseJy8wKx4nLzAqlfyHBVubEZWaDuMU1zuUNXfjNb36TjPX29g7r8xr5C3fw4MFk7NChQ0PeZ968ecnY0qXpO+5z5sxJxlasWFGx/O67707uY05cZlZATlxmVjhOXGZWKEUYgDqSh2WY2RhVz8eTSeqQ9IKkbkl35Wx3haRTkv5ttWM6cZnZIPV6WIakNuA+YCWwCLhZ0qLEdvcC6fXMy/hSsUW1elO9mlT989bg//Wvf52MHThwIBnLuxs5c+bMiuUTJvhvdp46/v4tA7oj4iUASY8C1wPPD9ju48BjwBW1HNQ/PTM7Ta2trSy5Tet/bmr2WjvgcHOAV8ve92RlfyBpDqWH7jxQax3d4jKzQYbQ4uqtsqxNpSV2Bh78vwGfiohTUm0r8jhxmdkgdbxU7AHmlr0/H9gzYJulwKNZ0poGXCepLyK+lzqoE5eZDVLHuYpbgQWS5gOvATcBt5RvEBHz+7+X9A3g8bykBU5cZjZAPcdxRUSfpHWU7ha2AesjYqekO7J4zf1a5Zy4zGyQet7VjohNwKYBZRUTVkT8eS3HrOVJ1nOBbwKzgDeBzoj4kqRzgf8FzKP0NOs/i4j0DFobVyZOnFixfMaMGcl95s+fn4zlrW9v9dfqw3FqGQ7RB/xlRFwMXAV8LBtAdhfwREQsAJ7I3pvZGFCvAaijpWqLKyL2Anuz71+XtIvSOIzrgeXZZg8BTwKfGpVamlnDjLmFBCXNA5YAW4CZWVIjIvZKSl8DmFmhtPqlYs2JS9IUSkPyPxERh2sdKJaNpB04mtbMWtiYSFySJlJKWg9HxIaseJ+k9qy11Q5UfCpnRHQCndlxWvtsmBnQ+omraue8Sk2rrwO7IuKLZaGNQP/jblcD369/9cysGQrfOQ+8F/gI8JykrqzsM8A9wLcl3Q7sBtLP
SB8lkydPTsZSt+MBjh8/PqyYnW7KlCnJ2MKFCyuWv+9970vu86EPfSgZy1tX/uTJk8nY7373u4rlrd6iaKZmJ6Va1HJX8adUnigJ8IH6VsfMWsGYuqtoZuND4VtcZjb+OHGZWaGMiT4uMxt/nLjMrHDcOV8HZ599dsXySy65JLlP3ioE+/dXHCsLwKuvvpqMpR4f39fXl9xnNEyaNCkZa2trq1ieN2Rg2rRpw6rHxRdfnIytWlV5dMyKFSuS+8yePTsZy6v/nj0DF9T8ox/96EcVyxv9MysSXyqaWSE5cZlZ4ThxmVnhOHGZWeE4cZlZoYy5hQTNbHxwi6sOli9fXrH8k5/8ZHKfJUuWJGN5Qx6efvrpZOyxxx6rWJ43vGI0/nLNnTs3GUs9VGLfvn3JfdasWZOM5a2ysXRp+gHG5513XsXy1HANgBMnTiRjeUMeHn/88WTszjvvTMYszYnLzArHicvMCsUDUM2skJy4zKxwfFfRzArHLa46uPXWWyuW561fPmFC+jkgF110UTK2YMGCZOzGG2+sWN7oH3Levy0Vy/sLmjdpO0/eHcJTp05VLD906FByn23btiVj3/rWt5KxDRs2JGM2dO7jMrNCcuIys8Jx4jKzwnHnvJkVivu4zKyQnLjMrHAKn7gkzQW+CcwC3gQ6I+JLkj4L/HvgN9mmn4mITaNRyV27dlUsX7x4cXKfvEe2500czhsaMNxhA42UGoaQN4E5L5Zn06b0j7urq6ti+TPPPJPc58UXX0zGDhw4kIwdPXo0GbPhafXElR4Q9Ed9wF9GxMXAVcDHJC3KYn8bEYuz16gkLTNrvP5+rmqvWkjqkPSCpG5Jd1WI/ztJz2avf5R0WbVjVm1xRcReYG/2/euSdgHp5oyZFVo9FxKU1AbcB6wAeoCtkjZGxPNlm/0z8P6IOChpJdAJXJl33FpaXOWVmAcsAbZkReuyLLleUuWFoMyscOrY4loGdEfESxFxAngUuH7AZ/1jRPQ/++9nwPnVDlpz4pI0BXgM+EREHAbuBy4EFlNqkX0hsd9aSdskpedzmFlLGULimtb//zt7rR1wqDlA+cqdPeRfsd0O/J9q9avprqKkiZSS1sMRsSH7h+0ri38VqLgMZUR0Umr6Iam1e/zMDBhS53xvRKSXwgVVOnzFDaU/oZS40pOQM7XcVRTwdWBXRHyxrLw96/8CuBHYUe1YZtb66jwAtQcoX2v8fGDQOtySLgW+BqyMiN9WO2gtLa73Ah8BnpPUlZV9BrhZ0mJK2fNl4C9qONawpFYGePLJJ5P7zJo1Kxm78MILk7GFCxfWXK9W9Nprr1Us/8lPflL3z8pbu//gwYMVy/OGNRw5cmTEdbL6qGPi2goskDQfeA24CbilfANJ7wA2AB+JiPSYmDK13FX8KZWbex7+YDZG1euuYkT0SVoHbAbagPURsVPSHVn8AeCvgPOAr5Qu8OircvnpkfNmNlg9B6BmYzw3DSh7oOz7jwIfHcoxnbjM7DSeZG1mheTEZWaF48RlZoXjhQTroLu7e0jlAJMnT07Gpk+fnozNmDGj9oq1oNTDKPLOlVk593GZWSE5cZlZ4ThxmVnhOHGZWeE4cZlZodRzIcHR4sRlZoO4xdUkx44dS8Z27949rJjZeOHEZWaF48RlZoXiAahmVkhOXGZWOL6raGaF4xaXmRWK+7jMrJCcuMyscJy4zKxw3DlvZoXiPi4zKyQnLjMrnFZPXBOqbSDpTEn/T9LPJe2U9J+y8nMl/VDSL7Ov54x+dc2sEfovF6u9mqVq4gKOA9dGxGXAYqBD0lXAXcATEbEAeCJ7b2ZjQOETV5S8kb2dmL0CuB54KCt/CLhhNCpoZo3Vv5BgLa9mqaXFhaQ2SV3AfuCHEbEFmBkRewGyr8V+rpeZ/UGrt7hq6pyPiFPAYklTge9KuqTWD5C0Flg7vOqZWTMUvnO+XEQcAp4EOoB9ktoBsq/7E/t0RsTSiFg6sqqaWaO0eourlruK
07OWFpImA/8K+AWwEVidbbYa+P4o1dHMGqjWpNXql4rtwEOS2iglum9HxOOS/gn4tqTbgd3AqlGsp5k1UKtfKqqRFZTU2mfDbAyICI1k/wkTJsTEiRNr2vbEiRPbm9EN5JHzZjZIq7e4nLjM7DTN7r+qxZDuKprZ+FDPznlJHZJekNQtadAMG5X8XRZ/VtLl1Y7pxGVmg9QrcWU39e4DVgKLgJslLRqw2UpgQfZaC9xf7bhOXGY2SB2n/CwDuiPipYg4ATxKabpgueuBb2bTC38GTO0fI5rS6D6uXuCV7Ptp2ftmcz1O53qcrmj1uKAOn7U5+7xanClpW9n7zojoLHs/B3i17H0PcOWAY1TaZg6wN/WhDU1cETG9/3tJ21phNL3r4Xq4HqeLiI46Hq7S0IyB15i1bHMaXyqa2WjqAeaWvT8f2DOMbU7jxGVmo2krsEDSfEmTgJsoTRcstxG4Lbu7eBXwu/6VZ1KaOY6rs/omDeF6nM71OJ3rMQIR0SdpHaV+szZgfUTslHRHFn8A2ARcB3QDR4E11Y7b0Ck/Zmb14EtFMyscJy4zK5ymJK5qUwAaWI+XJT0nqWvAWJTR/tz1kvZL2lFW1vCnJiXq8VlJr2XnpEvSdQ2ox1xJP5a0K3uS1J1ZeUPPSU49GnpO/GSt6hrex5VNAXgRWEHpNuhW4OaIeL6hFSnV5WVgaUQ0dIChpGuANyiNFr4kK/svwIGIuCdL5udExKeaUI/PAm9ExH8dzc8eUI92oD0inpb0NmA7pYev/DkNPCc59fgzGnhOJAl4a0S8IWki8FPgTuDf0ODfkVbVjBZXLVMAxrSIeAo4MKC44U9NStSj4SJib0Q8nX3/OrCL0sjphp6TnHo0lJ+sVV0zEldqeH8zBPAPkrZnD/VoplZ6atK6bJb++kZfjkiaBywBmvokqQH1gAafE/nJWrmakbiGPLx/FL03Ii6nNDv9Y9ml03h3P3AhpYf/7gW+0KgPljQFeAz4REQcbtTn1lCPhp+TiDgVEYspjSJfpiE8WWs8aEbiGvLw/tESEXuyr/uB71K6jG2Wmp6aNNoiYl/2n+ZN4Ks06JxkfTmPAQ9HxIasuOHnpFI9mnVOss8+xBCfrDUeNCNx1TIFYNRJemvWAYuktwIfBHbk7zWqWuKpSQOWE7mRBpyTrDP668CuiPhiWaih5yRVj0afE/nJWtUN5VFE9XpRGt7/IvAr4D82qQ7vBH6evXY2sh7AI5QuOU5SaoHeDpwHPAH8Mvt6bpPq8S3gOeBZSv9R2htQj/dR6i54FujKXtc1+pzk1KOh5wS4FHgm+7wdwF9l5Q3/HWnVl6f8mFnheOS8mRWOE5eZFY4Tl5kVjhOXmRWOE5eZFY4Tl5kVjhOXmRXO/weOzhS3T9Z1xAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]])\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Show the first image of one batch after the transform has been applied\n", "for im, label in train_data:\n", "    print(im.shape)\n", "    print(label.shape)\n", "\n", "    img = im[0,0,:,:]\n", "    lab = label[0]\n", "    plt.imshow(img, cmap='gray')\n", "    plt.title('%i' % lab)\n", "    plt.colorbar()\n", "    plt.show()\n", "\n", "    print(im[0,0,:,:])\n", "    break" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0. Train Loss: 0.292838, Train Acc: 0.908382, Valid Loss: 0.075638, Valid Acc: 0.974684, Time 00:00:13\n", "Epoch 1. Train Loss: 0.077091, Train Acc: 0.976479, Valid Loss: 0.066128, Valid Acc: 0.978936, Time 00:00:14\n", "Epoch 2. Train Loss: 0.055866, Train Acc: 0.982759, Valid Loss: 0.042326, Valid Acc: 0.986748, Time 00:00:14\n", "Epoch 3. Train Loss: 0.043993, Train Acc: 0.986257, Valid Loss: 0.042040, Valid Acc: 0.986847, Time 00:00:14\n", "Epoch 4. Train Loss: 0.035289, Train Acc: 0.988823, Valid Loss: 0.035118, Valid Acc: 0.988430, Time 00:00:14\n", "Epoch 5. Train Loss: 0.030174, Train Acc: 0.990572, Valid Loss: 0.036890, Valid Acc: 0.988430, Time 00:00:14\n", "Epoch 6. Train Loss: 0.025604, Train Acc: 0.991571, Valid Loss: 0.028075, Valid Acc: 0.990803, Time 00:00:14\n", "Epoch 7. Train Loss: 0.021483, Train Acc: 0.993220, Valid Loss: 0.039955, Valid Acc: 0.988133, Time 00:00:14\n", "Epoch 8. Train Loss: 0.018553, Train Acc: 0.994020, Valid Loss: 0.031569, Valid Acc: 0.990506, Time 00:00:14\n", "Epoch 9. 
Train Loss: 0.016860, Train Acc: 0.994420, Valid Loss: 0.028923, Valid Acc: 0.990803, Time 00:00:14\n", "Epoch 10. Train Loss: 0.014547, Train Acc: 0.995186, Valid Loss: 0.041005, Valid Acc: 0.987737, Time 00:00:14\n", "Epoch 11. Train Loss: 0.011832, Train Acc: 0.996085, Valid Loss: 0.039684, Valid Acc: 0.989221, Time 00:00:14\n", "Epoch 12. Train Loss: 0.012104, Train Acc: 0.996019, Valid Loss: 0.033983, Valid Acc: 0.990012, Time 00:00:14\n", "Epoch 13. Train Loss: 0.009578, Train Acc: 0.996802, Valid Loss: 0.044510, Valid Acc: 0.989419, Time 00:00:14\n", "Epoch 14. Train Loss: 0.008961, Train Acc: 0.997018, Valid Loss: 0.033376, Valid Acc: 0.991693, Time 00:00:14\n", "Epoch 15. Train Loss: 0.008937, Train Acc: 0.997002, Valid Loss: 0.054347, Valid Acc: 0.986847, Time 00:00:15\n", "Epoch 16. Train Loss: 0.009171, Train Acc: 0.996902, Valid Loss: 0.034495, Valid Acc: 0.991594, Time 00:00:16\n", "Epoch 17. Train Loss: 0.006915, Train Acc: 0.997818, Valid Loss: 0.046391, Valid Acc: 0.989517, Time 00:00:16\n", "Epoch 18. Train Loss: 0.007419, Train Acc: 0.997651, Valid Loss: 0.044388, Valid Acc: 0.989419, Time 00:00:16\n", "Epoch 19. 
Train Loss: 0.006600, Train Acc: 0.998001, Valid Loss: 0.049959, Valid Acc: 0.987935, Time 00:00:16\n" ] } ], "source": [ "net = LeNet5()\n", "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n", "criterion = nn.CrossEntropyLoss()\n", "\n", "res = train(net, train_data, test_data, 20, optimizer, criterion)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.plot(res[0], label='train')\n", "plt.plot(res[2], label='valid')\n", "plt.xlabel('epoch')\n", "plt.ylabel('Loss')\n", "plt.legend(loc='best')\n", "plt.savefig('fig-res-lenet5-train-validate-loss.pdf')\n", "plt.show()\n", "\n", "plt.plot(res[1], label='train')\n", "plt.plot(res[3], label='valid')\n", "plt.xlabel('epoch')\n", "plt.ylabel('Acc')\n", "plt.legend(loc='best')\n", "plt.savefig('fig-res-lenet5-train-validate-acc.pdf')\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 2 }