{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# LeNet5\n", "\n", "LeNet dates back to 1994 and is one of the earliest convolutional neural networks; it helped drive the development of deep learning. After several iterations starting in 1988, this pioneering work was named LeNet5. The LeNet5 architecture rests on the observation that image features are distributed across the entire image, so convolutions with learnable parameters can extract similar features at multiple locations while effectively reducing the number of parameters.\n", "\n", "When LeNet5 was proposed there were no GPUs to help with training, and even CPUs were slow, so the network is small. It contains seven processing layers, each with trainable parameters (weights), and the input data used at the time were $32 \\times 32$ pixel images. Although LeNet-5 is small, it contains the basic building blocks of deep learning: convolutional layers, pooling layers, and fully connected layers. It is the foundation of other deep learning models. This notebook analyzes and explains LeNet5 in depth, using a worked example to deepen the understanding of convolutional and pooling layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the network as follows:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "\n", "class LeNet5(nn.Module):\n", "    def __init__(self):\n", "        super(LeNet5, self).__init__()\n", "        # 1 input channel, 6 output channels, 5x5 convolution\n", "        self.conv1 = nn.Conv2d(1, 6, 5)\n", "        # 6 input channels, 16 output channels, 5x5 convolution\n", "        self.conv2 = nn.Conv2d(6, 16, 5)\n", "        # 16x5x5 inputs, 120 outputs\n", "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n", "        # 120 inputs, 84 outputs\n", "        self.fc2 = nn.Linear(120, 84)\n", "        # 84 inputs, 10 outputs (one per class)\n", "        self.fc3 = nn.Linear(84, 10)\n", "\n", "    def forward(self, x):\n", "        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))  # (N, 1, 32, 32) -> (N, 6, 14, 14)\n", "        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))  # (N, 6, 14, 14) -> (N, 16, 5, 5)\n", "        x = torch.flatten(x, 1)  # flatten every dimension except the batch dimension\n", "        x = F.relu(self.fc1(x))\n", "        x = F.relu(self.fc2(x))\n", "        x = self.fc3(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LeNet5(\n", " (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n", " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n", " (fc1): Linear(in_features=400, out_features=120, bias=True)\n", " (fc2): Linear(in_features=120, out_features=84, bias=True)\n", " (fc3): Linear(in_features=84, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ "net = 
LeNet5()\n", "print(net)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.0124, 0.1326, 0.1647, 0.0728, 0.0722, 0.0113, 0.0829, -0.0055,\n", " 0.1749, -0.0581]], grad_fn=)\n" ] } ], "source": [ "# Forward one random single-channel 32x32 image through the untrained network\n", "x = torch.randn(1, 1, 32, 32)\n", "out = net(x)\n", "print(out)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from torchvision.datasets import mnist\n", "from torch.utils.data import DataLoader\n", "from torchvision import transforms as tfs\n", "from utils import train  # training helper from the local utils module\n", "\n", "# Preprocessing: resize the 28x28 MNIST images to the 32x32 input size, then convert to a tensor\n", "def data_tf(x):\n", "    im_aug = tfs.Compose([\n", "        tfs.Resize(32),\n", "        tfs.ToTensor() #,\n", "        #tfs.Normalize([0.5], [0.5])\n", "    ])\n", "    x = im_aug(x)\n", "    return x\n", "\n", "train_set = mnist.MNIST('../../data/mnist', train=True, transform=data_tf, download=True)\n", "train_data = DataLoader(train_set, batch_size=64, shuffle=True)\n", "test_set = mnist.MNIST('../../data/mnist', train=False, transform=data_tf, download=True)\n", "test_data = DataLoader(test_set, batch_size=128, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAATEAAAEICAYAAAA3EMMNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWHUlEQVR4nO3df6xU5Z3H8fdH1DZF02qsSBFFDavFpl6VYlONxVgrbWyQthrJxmDWiH9AVhNjVk2T0uzimi3irqs1xUr9EayS+Iu4TdW1WtdtigKlyo+yorIWuYGiUtBWDfDdP+bc7tw7M8/MvTP3znkun1cyuXPP95w5T6feD895zjnPUURgZparg7rdADOzdjjEzCxrDjEzy5pDzMyy5hAzs6w5xMwsaw4xM8uaQ8zqkvS8pA8lvV+8NnW7TWb1OMQsZX5EHFa8Tu52Y8zqcYiZWdYcYpbyz5J2SvpvSdO73RizeuR7J60eSWcBG4CPgcuAO4CeiHi9qw0zG8AhZi2R9AvgPyLi37vdFrNqPpy0VgWgbjfCbCCHmNWQ9BlJF0r6pKSDJf0tcC7wVLfbZjbQwd1ugJXSIcA/AacA+4DfAxdHhK8Vs9LxmJiZZc2Hk2aWNYeYmWXNIWZmWXOImVnWRvTspCSfRTAbZhHR1vV8M2bMiJ07d7a07urVq5+KiBnt7K9tETHkFzAD2ARsBm5oYf3wyy+/hvfVzt90RHDmmWdGq4BVTf7mJwLPARuB9cA1xfIFwNvA2uL1zaptbqSSKZuAC5u1d8g9MUljgDuBC4CtwMuSVkTEhqF+ppmVQwcvvdoLXBcRayQdDqyW9ExRuy0iFlWvLGkKlXt1TwU+B/ynpL+JiH2NdtDOmNg0YHNEvBERHwMPATPb+DwzK4n9+/e39GomInojYk3xfg+VHtmExCYzgYci4qOIeJNKj2xaah/thNgE4A9Vv2+t1zhJcyWtkrSqjX2Z2QgZ5JBSyyRNAk4HVhaL5kt6RdJSSUcUy1rKlWrthFi9wcOa/1URsSQipkbE1Db2ZWYjaBAhdlRfJ6V4za33eZIOAx4Bro2I3cBdwElAD9AL3Nq3ar3mpNraztnJrVQG7focC2xr4/PMrCQG0cva2ayDIukQKgG2LCIeLT5/e1X9buDJ4tdB50o7PbGXgcmSTpB0KJXBuBVtfJ6ZlUSnDiclCbgH2BgRi6uWj69abRawrni/ArhM0icknQBMBl5K7WPIPbGI2CtpPpXpWcYASyNi/VA/z8zKo4NnJ88GLgdelbS2WHYTMFtSD5VDxS3A1cV+10taTmVW4b3AvNSZSRjhWSx8savZ8Is2L3Y944wz4sUXX2xp3bFjx67u9ni35xMzsxoj2blpl0PMzGo4xMwsaw4xM8vWUC5k7SaHmJnVaOWWorJwiJlZDffEzCxbPpw0s+w5xMwsaw4xM8uaQ8zMshURPjtpZnlzT8zMsuYQM7OsOcTMLGsOMTPLlgf2zSx77omZWdYcYmaWNYeYmWXLN4CbWfYcYmaWNZ+dNLOsuSdmZtnymJiZZc8hZmZZc4iZWdYcYmaWLd87aWbZc0/MSmPMmDHJ+qc//elh3f/8+fMb1j71qU8ltz355JOT9Xnz5iXrixYtalibPXt2ctsPP/wwWb/llluS9R/84AfJetkdMCEmaQuwB9gH7I2IqZ1olJl11wETYoXzImJnBz7HzEriQAsxMxtFchvYP6jN7QN4WtJqSXPrrSBprqRVkla1uS8zGyF9V+03e5VBuyF2dkScAXwDmCfp3IErRMSSiJjq8TKzfHQqxCRNlPScpI2S1ku6plh+pKRnJL1W/DyiapsbJW2WtEnShc320VaIRcS24ucO4DFgWjufZ2bl0MGe2F7guoj4PPBlKp2dKcANwLMRMRl4tvidonYZcCowA/iRpOQp9iGHmKSxkg7vew98HVg31M8zs3JoNcBaCbGI6I2INcX7PcBGYAIwE7ivWO0+4OLi/UzgoYj4KCLeBDb
TpHPUzsD+OOAxSX2f82BE/KKNzxu1jjvuuGT90EMPTda/8pWvJOvnnHNOw9pnPvOZ5Lbf+c53kvVu2rp1a7J+++23J+uzZs1qWNuzZ09y29/97nfJ+q9+9atkPXeDGO86asB495KIWFJvRUmTgNOBlcC4iOgt9tUr6ehitQnAb6o221osa2jIIRYRbwCnDXV7MyuvQZyd3NnKeLekw4BHgGsjYnfR+am7ap1lyURtd2DfzEahTp6dlHQIlQBbFhGPFou3Sxpf1McDO4rlW4GJVZsfC2xLfb5DzMz66eSYmCpdrnuAjRGxuKq0AphTvJ8DPFG1/DJJn5B0AjAZeCm1D1/samY1OngN2NnA5cCrktYWy24CbgGWS7oSeAu4pNjveknLgQ1UzmzOi4h9qR04xMysRqdCLCJepP44F8D5DbZZCCxsdR8OMTOrUZar8VvhEOuAnp6eZP2Xv/xlsj7c0+GUVbMzYN/73veS9ffffz9ZX7ZsWcNab29vctv33nsvWd+0aVOynrPc7p10iJlZDffEzCxrDjEzy5pDzMyy5hAzs2x5YN/MsueemJllzSF2gHnrrbeS9XfeeSdZL/N1YitXrkzWd+3alayfd955DWsff/xxctsHHnggWbfh4xAzs2yVaf78VjjEzKyGQ8zMsuazk2aWNffEzCxbHhMzs+w5xMwsaw6xA8y7776brF9//fXJ+kUXXZSs//a3v03Wmz26LGXt2rXJ+gUXXJCsf/DBB8n6qaee2rB2zTXXJLe17nGImVm2fO+kmWXPPTEzy5pDzMyy5hAzs6w5xMwsWx7YN7PsuSdm/Tz++OPJerPnUu7ZsydZP+200xrWrrzyyuS2ixYtStabXQfWzPr16xvW5s6d29Zn2/DJKcQOaraCpKWSdkhaV7XsSEnPSHqt+HnE8DbTzEZS3/2TzV5l0DTEgHuBGQOW3QA8GxGTgWeL381sFGg1wLIJsYh4ARh4X81M4L7i/X3AxZ1tlpl1U04hNtQxsXER0QsQEb2Sjm60oqS5gAc/zDLis5NVImIJsARAUjmi28waKlMvqxWtjInVs13SeIDi547ONcnMui2nw8mhhtgKYE7xfg7wRGeaY2ZlkFOINT2clPQzYDpwlKStwPeBW4Dlkq4E3gIuGc5Gjna7d+9ua/s//elPQ972qquuStYffvjhZD2nsRNrXVkCqhVNQywiZjcond/htphZCXTytiNJS4GLgB0R8YVi2QLgKuCPxWo3RcTPi9qNwJXAPuDvI+KpZvsY6uGkmY1iHTycvJfa60wBbouInuLVF2BTgMuAU4ttfiRpTLMdOMTMrEanQqzBdaaNzAQeioiPIuJNYDMwrdlGDjEzqzGIEDtK0qqqV6vXhM6X9EpxW2PfbYsTgD9UrbO1WJbkG8DNrMYgBvZ3RsTUQX78XcA/AlH8vBX4O0D1mtLswxxiZtbPcF8+ERHb+95Luht4svh1KzCxatVjgW3NPs8hNgosWLCgYe3MM89MbvvVr341Wf/a176WrD/99NPJuuVpOC+dkTS+77ZFYBbQN0POCuBBSYuBzwGTgZeafZ5DzMxqdKon1uA60+mSeqgcKm4Bri72uV7ScmADsBeYFxH7mu3DIWZmNToVYg2uM70nsf5CYOFg9uEQM7N+ynRLUSscYmZWwyFmZllziJlZ1nK6sd8hZmb9eEzMRlzqsWrNptpZs2ZNsn733Xcn688991yyvmrVqoa1O++8M7ltTn9Io01O371DzMxqOMTMLGsOMTPLVicnRRwJDjEzq+GemJllzSFmZllziJlZ1hxiVhqvv/56sn7FFVck6z/96U+T9csvv3zI9bFjxya3vf/++5P13t7eZN2Gxhe7mln2fHbSzLLmnpiZZc0hZmbZ8piYmWXPIWZmWXOImVnWfHbSsvHYY48l66+99lqyvnjx4mT9/PPPb1i7+eabk9sef/zxyfrChemH4rz99tvJutWX25jYQc1WkLRU0g5J66qWLZD
0tqS1xeubw9tMMxtJfUHW7FUGTUMMuBeYUWf5bRHRU7x+3tlmmVk35RRiTQ8nI+IFSZNGoC1mVhJlCahWtNITa2S+pFeKw80jGq0kaa6kVZIaT7ZuZqXRNyliK68yGGqI3QWcBPQAvcCtjVaMiCURMTUipg5xX2Y2wkbV4WQ9EbG9772ku4EnO9YiM+u6sgRUK4bUE5M0vurXWcC6RuuaWX5GVU9M0s+A6cBRkrYC3wemS+oBAtgCXD18TbRuWrcu/e/TpZdemqx/61vfalhrNlfZ1Ven/7OaPHlysn7BBRck69ZYWQKqFa2cnZxdZ/E9w9AWMyuBMvWyWuEr9s2sRlnOPLbCIWZmNXLqibVznZiZjVKdGthvcNvikZKekfRa8fOIqtqNkjZL2iTpwlba6hAzs35aDbAWe2v3Unvb4g3AsxExGXi2+B1JU4DLgFOLbX4kaUyzHTjEzKxGp0IsIl4A3h2weCZwX/H+PuDiquUPRcRHEfEmsBmY1mwfHhOztuzatStZf+CBBxrWfvKTnyS3Pfjg9H+e5557brI+ffr0hrXnn38+ue2BbpjHxMZFRG+xn15JRxfLJwC/qVpva7EsySFmZjUGcXbyqAH3RS+JiCVD3K3qLGuapg4xM+tnkNeJ7RzCfdHbJY0vemHjgR3F8q3AxKr1jgW2Nfswj4mZWY1hvu1oBTCneD8HeKJq+WWSPiHpBGAy8FKzD3NPzMxqdGpMrMFti7cAyyVdCbwFXFLsc72k5cAGYC8wLyL2NduHQ8zManQqxBrctghQ9+ELEbEQSD88YQCHmJn10zcpYi4cYmZWI6fbjhxilvTFL34xWf/ud7+brH/pS19qWGt2HVgzGzZsSNZfeOGFtj7/QOYQM7OsOcTMLGsOMTPLlidFNLPs+eykmWXNPTEzy5pDzMyy5TExK5WTTz45WZ8/f36y/u1vfztZP+aYYwbdplbt25e+ba63tzdZz2lcp2wcYmaWtZz+AXCImVk/Ppw0s+w5xMwsaw4xM8uaQ8zMsuYQM7NsjbpJESVNBO4HjgH2U3kk079JOhJ4GJgEbAEujYj3hq+pB65m12LNnt1oBuDm14FNmjRpKE3qiFWrViXrCxemZylesWJFJ5tjVXLqibXytKO9wHUR8Xngy8C84nHjdR9Fbmb5G+anHXVU0xCLiN6IWFO83wNspPJU3kaPIjezzOUUYoMaE5M0CTgdWEnjR5GbWcbKFFCtaDnEJB0GPAJcGxG7pXpPHK+73Vxg7tCaZ2bdMOpCTNIhVAJsWUQ8Wixu9CjyfiJiCbCk+Jx8vhmzA1hOZyebjomp0uW6B9gYEYurSo0eRW5mmRttY2JnA5cDr0paWyy7iQaPIrda48aNS9anTJmSrN9xxx3J+imnnDLoNnXKypUrk/Uf/vCHDWtPPJH+dy+n3sBoUqaAakXTEIuIF4FGA2B1H0VuZnkbVSFmZgceh5iZZS2nQ3mHmJn1M+rGxMzswOMQM7OsOcTMLGsOsVHoyCOPbFj78Y9/nNy2p6cnWT/xxBOH0qSO+PWvf52s33rrrcn6U089laz/5S9/GXSbrPscYmaWrU5PiihpC7AH2AfsjYipnZyPsJX5xMzsADMMtx2dFxE9ETG1+L1j8xE6xMysxgjcO9mx+QgdYmZWYxAhdpSkVVWvetNuBfC0pNVV9X7zEQJDno/QY2Jm1s8ge1k7qw4RGzk7IrYVE6c+I+n37bWwP/fEzKxGJw8nI2Jb8XMH8BgwjWI+QoDUfIStcIiZWY39+/e39GpG0lhJh/e9B74OrKOD8xEeMIeTZ511VrJ+/fXXJ+vTpk1rWJswYcKQ2tQpf/7znxvWbr/99uS2N998c7L+wQcfDKlNlrcOXic2DnismM7+YODBiPiFpJfp0HyEB0yImVlrOnkDeES8AZxWZ/k7dGg+QoeYmdXwFftmljWHmJllzZMimlm2PCmimWXPIWZmWXOIldCsWbPaqrdjw4YNyfqTTz6ZrO/
duzdZT835tWvXruS2ZvU4xMwsaw4xM8tWpydFHG4OMTOr4Z6YmWXNIWZmWXOImVm2fLGrmWUvpxBTs8ZKmgjcDxwD7AeWRMS/SVoAXAX8sVj1poj4eZPPyuebMctURKid7Q899ND47Gc/29K627ZtW93C9NTDqpWe2F7guohYU8zQuFrSM0XttohYNHzNM7NuyKkn1jTEiieR9D2VZI+kjUB3pzI1s2GT25jYoObYlzQJOB1YWSyaL+kVSUslHdFgm7l9j3Nqr6lmNlJG4LmTHdNyiEk6DHgEuDYidgN3AScBPVR6anVv4IuIJRExtdvHzWbWupxCrKWzk5IOoRJgyyLiUYCI2F5VvxtI38VsZtnI6bajpj0xVR5Tcg+wMSIWVy0fX7XaLCqPYTKzzLXaC8upJ3Y2cDnwqqS1xbKbgNmSeqg8onwLcPUwtM/MuqAsAdWKVs5OvgjUu+4keU2YmeVrVIWYmR14HGJmljWHmJlly5Mimln23BMzs6w5xMwsaw4xM8tWmS5kbYVDzMxqOMTMLGs+O2lmWXNPzMyylduY2KAmRTSzA0MnZ7GQNEPSJkmbJd3Q6bY6xMysRqdCTNIY4E7gG8AUKrPfTOlkW304aWY1OjiwPw3YHBFvAEh6CJgJbOjUDkY6xHYC/1v1+1HFsjIqa9vK2i5w24aqk207vgOf8RSVNrXikwOen7EkIpZU/T4B+EPV71uBs9psXz8jGmIR0e9hdpJWlXXu/bK2raztArdtqMrWtoiY0cGPqzcXYUfPGnhMzMyG01ZgYtXvxwLbOrkDh5iZDaeXgcmSTpB0KHAZsKKTO+j2wP6S5qt0TVnbVtZ2gds2VGVuW1siYq+k+VTG2cYASyNifSf3oZwuajMzG8iHk2aWNYeYmWWtKyE23LchtEPSFkmvSlo74PqXbrRlqaQdktZVLTtS0jOSXit+HlGiti2Q9Hbx3a2V9M0utW2ipOckbZS0XtI1xfKufneJdpXie8vViI+JFbch/A9wAZXTry8DsyOiY1fwtkPSFmBqRHT9wkhJ5wLvA/dHxBeKZf8CvBsRtxT/ABwREf9QkrYtAN6PiEUj3Z4BbRsPjI+INZIOB1YDFwNX0MXvLtGuSynB95arbvTE/nobQkR8DPTdhmADRMQLwLsDFs8E7ive30flj2DENWhbKUREb0SsKd7vATZSuXK8q99dol3Whm6EWL3bEMr0f2QAT0taLWlutxtTx7iI6IXKHwVwdJfbM9B8Sa8Uh5tdOdStJmkScDqwkhJ9dwPaBSX73nLSjRAb9tsQ2nR2RJxB5a77ecVhk7XmLuAkoAfoBW7tZmMkHQY8AlwbEbu72ZZqddpVqu8tN90IsWG/DaEdEbGt+LkDeIzK4W+ZbC/GVvrGWHZ0uT1/FRHbI2JfROwH7qaL352kQ6gExbKIeLRY3PXvrl67yvS95agbITbstyEMlaSxxYArksYCXwfWpbcacSuAOcX7OcATXWxLP30BUZhFl747SQLuATZGxOKqUle/u0btKsv3lquuXLFfnEL+V/7/NoSFI96IOiSdSKX3BZVbsh7sZtsk/QyYTmValO3A94HHgeXAccBbwCURMeID7A3aNp3KIVEAW4Cr+8agRrht5wD/BbwK9E2MdROV8aeufXeJds2mBN9brnzbkZllzVfsm1nWHGJmljWHmJllzSFmZllziJlZ1hxiZpY1h5iZZe3/APpBMI71CUMRAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 显示其中一个数据\n", "import matplotlib.pyplot as plt\n", "plt.imshow(train_set.data[0], cmap='gray')\n", "plt.title('%i' % train_set.targets[0])\n", "plt.colorbar()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([64, 1, 32, 32])\n", "torch.Size([64])\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]])\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# 显示转化后的图像\n", "for im, label in train_data:\n", " print(im.shape)\n", " print(label.shape)\n", " \n", " img = im[0,0,:,:]\n", " lab = label[0]\n", " plt.imshow(img, cmap='gray')\n", " plt.title('%i' % lab)\n", " plt.colorbar()\n", " plt.show()\n", "\n", " print(im[0,0,:,:])\n", " break" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0. Train Loss: 0.292838, Train Acc: 0.908382, Valid Loss: 0.075638, Valid Acc: 0.974684, Time 00:00:13\n", "Epoch 1. Train Loss: 0.077091, Train Acc: 0.976479, Valid Loss: 0.066128, Valid Acc: 0.978936, Time 00:00:14\n", "Epoch 2. Train Loss: 0.055866, Train Acc: 0.982759, Valid Loss: 0.042326, Valid Acc: 0.986748, Time 00:00:14\n", "Epoch 3. Train Loss: 0.043993, Train Acc: 0.986257, Valid Loss: 0.042040, Valid Acc: 0.986847, Time 00:00:14\n", "Epoch 4. Train Loss: 0.035289, Train Acc: 0.988823, Valid Loss: 0.035118, Valid Acc: 0.988430, Time 00:00:14\n", "Epoch 5. Train Loss: 0.030174, Train Acc: 0.990572, Valid Loss: 0.036890, Valid Acc: 0.988430, Time 00:00:14\n", "Epoch 6. Train Loss: 0.025604, Train Acc: 0.991571, Valid Loss: 0.028075, Valid Acc: 0.990803, Time 00:00:14\n", "Epoch 7. Train Loss: 0.021483, Train Acc: 0.993220, Valid Loss: 0.039955, Valid Acc: 0.988133, Time 00:00:14\n", "Epoch 8. Train Loss: 0.018553, Train Acc: 0.994020, Valid Loss: 0.031569, Valid Acc: 0.990506, Time 00:00:14\n", "Epoch 9. 
Train Loss: 0.016860, Train Acc: 0.994420, Valid Loss: 0.028923, Valid Acc: 0.990803, Time 00:00:14\n", "Epoch 10. Train Loss: 0.014547, Train Acc: 0.995186, Valid Loss: 0.041005, Valid Acc: 0.987737, Time 00:00:14\n", "Epoch 11. Train Loss: 0.011832, Train Acc: 0.996085, Valid Loss: 0.039684, Valid Acc: 0.989221, Time 00:00:14\n", "Epoch 12. Train Loss: 0.012104, Train Acc: 0.996019, Valid Loss: 0.033983, Valid Acc: 0.990012, Time 00:00:14\n", "Epoch 13. Train Loss: 0.009578, Train Acc: 0.996802, Valid Loss: 0.044510, Valid Acc: 0.989419, Time 00:00:14\n", "Epoch 14. Train Loss: 0.008961, Train Acc: 0.997018, Valid Loss: 0.033376, Valid Acc: 0.991693, Time 00:00:14\n", "Epoch 15. Train Loss: 0.008937, Train Acc: 0.997002, Valid Loss: 0.054347, Valid Acc: 0.986847, Time 00:00:15\n", "Epoch 16. Train Loss: 0.009171, Train Acc: 0.996902, Valid Loss: 0.034495, Valid Acc: 0.991594, Time 00:00:16\n", "Epoch 17. Train Loss: 0.006915, Train Acc: 0.997818, Valid Loss: 0.046391, Valid Acc: 0.989517, Time 00:00:16\n", "Epoch 18. Train Loss: 0.007419, Train Acc: 0.997651, Valid Loss: 0.044388, Valid Acc: 0.989419, Time 00:00:16\n", "Epoch 19. 
Train Loss: 0.006600, Train Acc: 0.998001, Valid Loss: 0.049959, Valid Acc: 0.987935, Time 00:00:16\n" ] } ], "source": [ "net = LeNet5()\n", "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n", "criterion = nn.CrossEntropyLoss()\n", "\n", "res = train(net, train_data, test_data, 20, optimizer, criterion)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.plot(res[0], label='train')\n", "plt.plot(res[2], label='valid')\n", "plt.xlabel('epoch')\n", "plt.ylabel('Loss')\n", "plt.legend(loc='best')\n", "plt.savefig('fig-res-lenet5-train-validate-loss.pdf')\n", "plt.show()\n", "\n", "plt.plot(res[1], label='train')\n", "plt.plot(res[3], label='valid')\n", "plt.xlabel('epoch')\n", "plt.ylabel('Acc')\n", "plt.legend(loc='best')\n", "plt.savefig('fig-res-lenet5-train-validate-acc.pdf')\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 2 }