You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

02-LeNet5.ipynb 27 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# LeNet5\n",
  8. "\n",
  9. "LeNet 诞生于 1994 年,是最早的卷积神经网络之一,并且推动了深度学习领域的发展。自从 1988 年开始,在多次迭代后这个开拓性成果被命名为 LeNet5。LeNet5 架构的提出基于如下观点:图像的特征分布在整张图像上,因此可以用带有可学习参数的卷积核在图像的各个位置上共享同一组权重,从而有效地减少了参数数量,并能够在多个位置上提取相似特征。\n",
  10. "\n",
  11. "在 LeNet5 提出的时候,没有 GPU 帮助训练,甚至 CPU 的速度也很慢,因此 LeNet5 的规模并不大。其包含七个处理层,每一层都包含可训练参数(权重),当时使用的输入数据是 $32 \\times 32$ 像素的图像。LeNet5 这个网络虽然很小,但是它包含了深度学习的基本模块:卷积层、池化层、全连接层。它是其他深度学习模型的基础,这里对 LeNet5 进行深入分析和讲解,通过实例分析,加深对卷积层和池化层的理解。"
  12. ]
  13. },
  14. {
  15. "cell_type": "markdown",
  16. "metadata": {},
  17. "source": [
  18. "定义网络为:"
  19. ]
  20. },
  21. {
  22. "cell_type": "code",
  23. "execution_count": 1,
  24. "metadata": {},
  25. "outputs": [],
  26. "source": [
  27. "import torch\n",
  28. "from torch import nn\n",
  29. "import torch.nn.functional as F\n",
  30. "\n",
  31. "class LeNet5(nn.Module):\n",
  32. " def __init__(self):\n",
  33. " super(LeNet5, self).__init__()\n",
  34. " # 1-input channel, 6-output channels, 5x5-conv\n",
  35. " self.conv1 = nn.Conv2d(1, 6, 5)\n",
  36. " # 6-input channel, 16-output channels, 5x5-conv\n",
  37. " self.conv2 = nn.Conv2d(6, 16, 5)\n",
  38. " # 16x5x5-input, 120-output\n",
  39. " self.fc1 = nn.Linear(16 * 5 * 5, 120) \n",
  40. " # 120-input, 84-output\n",
  41. " self.fc2 = nn.Linear(120, 84)\n",
  42. " # 84-input, 10-output\n",
  43. " self.fc3 = nn.Linear(84, 10)\n",
  44. "\n",
  45. " def forward(self, x):\n",
  46. " x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
  47. " x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))\n",
  48. " x = torch.flatten(x, 1) # 将结果拉升成1维向量,除了批次的维度\n",
  49. " x = F.relu(self.fc1(x))\n",
  50. " x = F.relu(self.fc2(x))\n",
  51. " x = self.fc3(x)\n",
  52. " return x"
  53. ]
  54. },
  55. {
  56. "cell_type": "code",
  57. "execution_count": 2,
  58. "metadata": {},
  59. "outputs": [
  60. {
  61. "name": "stdout",
  62. "output_type": "stream",
  63. "text": [
  64. "LeNet5(\n",
  65. " (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
  66. " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
  67. " (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
  68. " (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
  69. " (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
  70. ")\n"
  71. ]
  72. }
  73. ],
  74. "source": [
  75. "net = LeNet5()\n",
  76. "print(net)"
  77. ]
  78. },
  79. {
  80. "cell_type": "code",
  81. "execution_count": 3,
  82. "metadata": {},
  83. "outputs": [
  84. {
  85. "name": "stdout",
  86. "output_type": "stream",
  87. "text": [
  88. "tensor([[ 0.0124, 0.1326, 0.1647, 0.0728, 0.0722, 0.0113, 0.0829, -0.0055,\n",
  89. " 0.1749, -0.0581]], grad_fn=<AddmmBackward>)\n"
  90. ]
  91. }
  92. ],
  93. "source": [
  94. "input = torch.randn(1, 1, 32, 32)\n",
  95. "out = net(input)\n",
  96. "print(out)"
  97. ]
  98. },
  99. {
  100. "cell_type": "code",
  101. "execution_count": 4,
  102. "metadata": {},
  103. "outputs": [],
  104. "source": [
  105. "import numpy as np\n",
  106. "from torchvision.datasets import mnist\n",
  107. "from torch.utils.data import DataLoader\n",
  108. "from torchvision.datasets import mnist \n",
  109. "from torchvision import transforms as tfs\n",
  110. "from utils import train\n",
  111. "\n",
  112. "# 使用数据增强\n",
  113. "def data_tf(x):\n",
  114. " im_aug = tfs.Compose([\n",
  115. " tfs.Resize(32),\n",
  116. " tfs.ToTensor() #,\n",
  117. " #tfs.Normalize([0.5], [0.5])\n",
  118. " ])\n",
  119. " x = im_aug(x)\n",
  120. " return x\n",
  121. " \n",
  122. "train_set = mnist.MNIST('../../data/mnist', train=True, transform=data_tf, download=True) \n",
  123. "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
  124. "test_set = mnist.MNIST('../../data/mnist', train=False, transform=data_tf, download=True) \n",
  125. "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
  126. "\n"
  127. ]
  128. },
  129. {
  130. "cell_type": "code",
  131. "execution_count": 5,
  132. "metadata": {},
  133. "outputs": [
  134. {
  135. "data": {
  136. "image/png": "iVBORw0KGgoAAAANSUhEUgAAATEAAAEICAYAAAA3EMMNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWHUlEQVR4nO3df6xU5Z3H8fdH1DZF02qsSBFFDavFpl6VYlONxVgrbWyQthrJxmDWiH9AVhNjVk2T0uzimi3irqs1xUr9EayS+Iu4TdW1WtdtigKlyo+yorIWuYGiUtBWDfDdP+bc7tw7M8/MvTP3znkun1cyuXPP95w5T6feD895zjnPUURgZparg7rdADOzdjjEzCxrDjEzy5pDzMyy5hAzs6w5xMwsaw4xM8uaQ8zqkvS8pA8lvV+8NnW7TWb1OMQsZX5EHFa8Tu52Y8zqcYiZWdYcYpbyz5J2SvpvSdO73RizeuR7J60eSWcBG4CPgcuAO4CeiHi9qw0zG8AhZi2R9AvgPyLi37vdFrNqPpy0VgWgbjfCbCCHmNWQ9BlJF0r6pKSDJf0tcC7wVLfbZjbQwd1ugJXSIcA/AacA+4DfAxdHhK8Vs9LxmJiZZc2Hk2aWNYeYmWXNIWZmWXOImVnWRvTspCSfRTAbZhHR1vV8M2bMiJ07d7a07urVq5+KiBnt7K9tETHkFzAD2ARsBm5oYf3wyy+/hvfVzt90RHDmmWdGq4BVTf7mJwLPARuB9cA1xfIFwNvA2uL1zaptbqSSKZuAC5u1d8g9MUljgDuBC4CtwMuSVkTEhqF+ppmVQwcvvdoLXBcRayQdDqyW9ExRuy0iFlWvLGkKlXt1TwU+B/ynpL+JiH2NdtDOmNg0YHNEvBERHwMPATPb+DwzK4n9+/e39GomInojYk3xfg+VHtmExCYzgYci4qOIeJNKj2xaah/thNgE4A9Vv2+t1zhJcyWtkrSqjX2Z2QgZ5JBSyyRNAk4HVhaL5kt6RdJSSUcUy1rKlWrthFi9wcOa/1URsSQipkbE1Db2ZWYjaBAhdlRfJ6V4za33eZIOAx4Bro2I3cBdwElAD9AL3Nq3ar3mpNraztnJrVQG7focC2xr4/PMrCQG0cva2ayDIukQKgG2LCIeLT5/e1X9buDJ4tdB50o7PbGXgcmSTpB0KJXBuBVtfJ6ZlUSnDiclCbgH2BgRi6uWj69abRawrni/ArhM0icknQBMBl5K7WPIPbGI2CtpPpXpWcYASyNi/VA/z8zKo4NnJ88GLgdelbS2WHYTMFtSD5VDxS3A1cV+10taTmVW4b3AvNSZSRjhWSx8savZ8Is2L3Y944wz4sUXX2xp3bFjx67u9ni35xMzsxoj2blpl0PMzGo4xMwsaw4xM8vWUC5k7SaHmJnVaOWWorJwiJlZDffEzCxbPpw0s+w5xMwsaw4xM8uaQ8zMshURPjtpZnlzT8zMsuYQM7OsOcTMLGsOMTPLlgf2zSx77omZWdYcYmaWNYeYmWXLN4CbWfYcYmaWNZ+dNLOsuSdmZtnymJiZZc8hZmZZc4iZWdYcYmaWLd87aWbZc0/MSmPMmDHJ+qc//elh3f/8+fMb1j71qU8ltz355JOT9Xnz5iXrixYtalibPXt2ctsPP/wwWb/llluS9R/84AfJetkdMCEmaQuwB9gH7I2IqZ1olJl11wETYoXzImJnBz7HzEriQAsxMxtFchvYP6jN7QN4WtJqSXPrrSBprqRVkla1uS8zGyF9V+03e5VBuyF2dkScAXwDmCfp3IErRMSSiJjq8TKzfHQqxCRNlPScpI2S1ku6plh+pKRnJL1W/DyiapsbJW2WtEnShc320VaIRcS24ucO4DFgWjufZ2bl0MGe2F7guoj4PPBlKp2dKcANwLMRMRl4tvidonYZcCowA/iRpOQp9iGHmKSxkg7vew98HVg31M8zs3JoNcBaCbGI6I2INcX7PcBGYAIwE7ivWO0
+4OLi/UzgoYj4KCLeBDbTpHPUzsD+OOAxSX2f82BE/KKNzxu1jjvuuGT90EMPTda/8pWvJOvnnHNOw9pnPvOZ5Lbf+c53kvVu2rp1a7J+++23J+uzZs1qWNuzZ09y29/97nfJ+q9+9atkPXeDGO86asB495KIWFJvRUmTgNOBlcC4iOgt9tUr6ehitQnAb6o221osa2jIIRYRbwCnDXV7MyuvQZyd3NnKeLekw4BHgGsjYnfR+am7ap1lyURtd2DfzEahTp6dlHQIlQBbFhGPFou3Sxpf1McDO4rlW4GJVZsfC2xLfb5DzMz66eSYmCpdrnuAjRGxuKq0AphTvJ8DPFG1/DJJn5B0AjAZeCm1D1/samY1OngN2NnA5cCrktYWy24CbgGWS7oSeAu4pNjveknLgQ1UzmzOi4h9qR04xMysRqdCLCJepP44F8D5DbZZCCxsdR8OMTOrUZar8VvhEOuAnp6eZP2Xv/xlsj7c0+GUVbMzYN/73veS9ffffz9ZX7ZsWcNab29vctv33nsvWd+0aVOynrPc7p10iJlZDffEzCxrDjEzy5pDzMyy5hAzs2x5YN/MsueemJllzSF2gHnrrbeS9XfeeSdZL/N1YitXrkzWd+3alayfd955DWsff/xxctsHHnggWbfh4xAzs2yVaf78VjjEzKyGQ8zMsuazk2aWNffEzCxbHhMzs+w5xMwsaw6xA8y7776brF9//fXJ+kUXXZSs//a3v03Wmz26LGXt2rXJ+gUXXJCsf/DBB8n6qaee2rB2zTXXJLe17nGImVm2fO+kmWXPPTEzy5pDzMyy5hAzs6w5xMwsWx7YN7PsuSdm/Tz++OPJerPnUu7ZsydZP+200xrWrrzyyuS2ixYtStabXQfWzPr16xvW5s6d29Zn2/DJKcQOaraCpKWSdkhaV7XsSEnPSHqt+HnE8DbTzEZS3/2TzV5l0DTEgHuBGQOW3QA8GxGTgWeL381sFGg1wLIJsYh4ARh4X81M4L7i/X3AxZ1tlpl1U04hNtQxsXER0QsQEb2Sjm60oqS5gAc/zDLis5NVImIJsARAUjmi28waKlMvqxWtjInVs13SeIDi547ONcnMui2nw8mhhtgKYE7xfg7wRGeaY2ZlkFOINT2clPQzYDpwlKStwPeBW4Dlkq4E3gIuGc5Gjna7d+9ua/s//elPQ972qquuStYffvjhZD2nsRNrXVkCqhVNQywiZjcond/htphZCXTytiNJS4GLgB0R8YVi2QLgKuCPxWo3RcTPi9qNwJXAPuDvI+KpZvsY6uGkmY1iHTycvJfa60wBbouInuLVF2BTgMuAU4ttfiRpTLMdOMTMrEanQqzBdaaNzAQeioiPIuJNYDMwrdlGDjEzqzGIEDtK0qqqV6vXhM6X9EpxW2PfbYsTgD9UrbO1WJbkG8DNrMYgBvZ3RsTUQX78XcA/AlH8vBX4O0D1mtLswxxiZtbPcF8+ERHb+95Luht4svh1KzCxatVjgW3NPs8hNgosWLCgYe3MM89MbvvVr341Wf/a176WrD/99NPJuuVpOC+dkTS+77ZFYBbQN0POCuBBSYuBzwGTgZeafZ5DzMxqdKon1uA60+mSeqgcKm4Bri72uV7ScmADsBeYFxH7mu3DIWZmNToVYg2uM70nsf5CYOFg9uEQM7N+ynRLUSscYmZWwyFmZllziJlZ1nK6sd8hZmb9eEzMRlzqsWrNptpZs2ZNsn733Xcn688991yyvmrVqoa1O++8M7ltTn9Io01O371DzMxqOMTMLGsOMTPLVicnRRwJDjEzq+GemJllzSFmZllziJlZ1hxiVhqvv/56sn7FFVck6z/96U+T9csvv3zI9bFjxya3vf/++5P13t7eZN2Gxhe7mln2fHbSzLLmnpiZZc0hZmbZ8piYmWXPIWZmWXOImVnWfHbSsvHYY48l66+99lqyvnjx4mT9/PPPb1i7+eabk9sef/zxyfrChemH4rz99tvJutWX25j
YQc1WkLRU0g5J66qWLZD0tqS1xeubw9tMMxtJfUHW7FUGTUMMuBeYUWf5bRHRU7x+3tlmmVk35RRiTQ8nI+IFSZNGoC1mVhJlCahWtNITa2S+pFeKw80jGq0kaa6kVZIaT7ZuZqXRNyliK68yGGqI3QWcBPQAvcCtjVaMiCURMTUipg5xX2Y2wkbV4WQ9EbG9772ku4EnO9YiM+u6sgRUK4bUE5M0vurXWcC6RuuaWX5GVU9M0s+A6cBRkrYC3wemS+oBAtgCXD18TbRuWrcu/e/TpZdemqx/61vfalhrNlfZ1Ven/7OaPHlysn7BBRck69ZYWQKqFa2cnZxdZ/E9w9AWMyuBMvWyWuEr9s2sRlnOPLbCIWZmNXLqibVznZiZjVKdGthvcNvikZKekfRa8fOIqtqNkjZL2iTpwlba6hAzs35aDbAWe2v3Unvb4g3AsxExGXi2+B1JU4DLgFOLbX4kaUyzHTjEzKxGp0IsIl4A3h2weCZwX/H+PuDiquUPRcRHEfEmsBmY1mwfHhOztuzatStZf+CBBxrWfvKTnyS3Pfjg9H+e5557brI+ffr0hrXnn38+ue2BbpjHxMZFRG+xn15JRxfLJwC/qVpva7EsySFmZjUGcXbyqAH3RS+JiCVD3K3qLGuapg4xM+tnkNeJ7RzCfdHbJY0vemHjgR3F8q3AxKr1jgW2Nfswj4mZWY1hvu1oBTCneD8HeKJq+WWSPiHpBGAy8FKzD3NPzMxqdGpMrMFti7cAyyVdCbwFXFLsc72k5cAGYC8wLyL2NduHQ8zManQqxBrctghQ9+ELEbEQSD88YQCHmJn10zcpYi4cYmZWI6fbjhxilvTFL34xWf/ud7+brH/pS19qWGt2HVgzGzZsSNZfeOGFtj7/QOYQM7OsOcTMLGsOMTPLlidFNLPs+eykmWXNPTEzy5pDzMyy5TExK5WTTz45WZ8/f36y/u1vfztZP+aYYwbdplbt25e+ba63tzdZz2lcp2wcYmaWtZz+AXCImVk/Ppw0s+w5xMwsaw4xM8uaQ8zMsuYQM7NsjbpJESVNBO4HjgH2U3kk079JOhJ4GJgEbAEujYj3hq+pB65m12LNnt1oBuDm14FNmjRpKE3qiFWrViXrCxemZylesWJFJ5tjVXLqibXytKO9wHUR8Xngy8C84nHjdR9Fbmb5G+anHXVU0xCLiN6IWFO83wNspPJU3kaPIjezzOUUYoMaE5M0CTgdWEnjR5GbWcbKFFCtaDnEJB0GPAJcGxG7pXpPHK+73Vxg7tCaZ2bdMOpCTNIhVAJsWUQ8Wixu9CjyfiJiCbCk+Jx8vhmzA1hOZyebjomp0uW6B9gYEYurSo0eRW5mmRttY2JnA5cDr0paWyy7iQaPIrda48aNS9anTJmSrN9xxx3J+imnnDLoNnXKypUrk/Uf/vCHDWtPPJH+dy+n3sBoUqaAakXTEIuIF4FGA2B1H0VuZnkbVSFmZgceh5iZZS2nQ3mHmJn1M+rGxMzswOMQM7OsOcTMLGsOsVHoyCOPbFj78Y9/nNy2p6cnWT/xxBOH0qSO+PWvf52s33rrrcn6U089laz/5S9/GXSbrPscYmaWrU5PiihpC7AH2AfsjYipnZyPsJX5xMzsADMMtx2dFxE9ETG1+L1j8xE6xMysxgjcO9mx+QgdYmZWYxAhdpSkVVWvetNuBfC0pNVV9X7zEQJDno/QY2Jm1s8ge1k7qw4RGzk7IrYVE6c+I+n37bWwP/fEzKxGJw8nI2Jb8XMH8BgwjWI+QoDUfIStcIiZWY39+/e39GpG0lhJh/e9B74OrKOD8xEeMIeTZ511VrJ+/fXXJ+vTpk1rWJswYcKQ2tQpf/7znxvWbr/99uS2N998c7L+wQcfDKlNlrcOXic2DnismM7+YODBiPiFpJfp0HyEB0yImVlrOnkDeES8AZxWZ/k7dGg+QoeYmdXwFftmljWHmJllzZMimlm2PCmimWXPIWZmWXOIldCsWbP
aqrdjw4YNyfqTTz6ZrO/duzdZT835tWvXruS2ZvU4xMwsaw4xM8tWpydFHG4OMTOr4Z6YmWXNIWZmWXOImVm2fLGrmWUvpxBTs8ZKmgjcDxwD7AeWRMS/SVoAXAX8sVj1poj4eZPPyuebMctURKid7Q899ND47Gc/29K627ZtW93C9NTDqpWe2F7guohYU8zQuFrSM0XttohYNHzNM7NuyKkn1jTEiieR9D2VZI+kjUB3pzI1s2GT25jYoObYlzQJOB1YWSyaL+kVSUslHdFgm7l9j3Nqr6lmNlJG4LmTHdNyiEk6DHgEuDYidgN3AScBPVR6anVv4IuIJRExtdvHzWbWupxCrKWzk5IOoRJgyyLiUYCI2F5VvxtI38VsZtnI6bajpj0xVR5Tcg+wMSIWVy0fX7XaLCqPYTKzzLXaC8upJ3Y2cDnwqqS1xbKbgNmSeqg8onwLcPUwtM/MuqAsAdWKVs5OvgjUu+4keU2YmeVrVIWYmR14HGJmljWHmJlly5Mimln23BMzs6w5xMwsaw4xM8tWmS5kbYVDzMxqOMTMLGs+O2lmWXNPzMyylduY2KAmRTSzA0MnZ7GQNEPSJkmbJd3Q6bY6xMysRqdCTNIY4E7gG8AUKrPfTOlkW304aWY1OjiwPw3YHBFvAEh6CJgJbOjUDkY6xHYC/1v1+1HFsjIqa9vK2i5w24aqk207vgOf8RSVNrXikwOen7EkIpZU/T4B+EPV71uBs9psXz8jGmIR0e9hdpJWlXXu/bK2raztArdtqMrWtoiY0cGPqzcXYUfPGnhMzMyG01ZgYtXvxwLbOrkDh5iZDaeXgcmSTpB0KHAZsKKTO+j2wP6S5qt0TVnbVtZ2gds2VGVuW1siYq+k+VTG2cYASyNifSf3oZwuajMzG8iHk2aWNYeYmWWtKyE23LchtEPSFkmvSlo74PqXbrRlqaQdktZVLTtS0jOSXit+HlGiti2Q9Hbx3a2V9M0utW2ipOckbZS0XtI1xfKufneJdpXie8vViI+JFbch/A9wAZXTry8DsyOiY1fwtkPSFmBqRHT9wkhJ5wLvA/dHxBeKZf8CvBsRtxT/ABwREf9QkrYtAN6PiEUj3Z4BbRsPjI+INZIOB1YDFwNX0MXvLtGuSynB95arbvTE/nobQkR8DPTdhmADRMQLwLsDFs8E7ive30flj2DENWhbKUREb0SsKd7vATZSuXK8q99dol3Whm6EWL3bEMr0f2QAT0taLWlutxtTx7iI6IXKHwVwdJfbM9B8Sa8Uh5tdOdStJmkScDqwkhJ9dwPaBSX73nLSjRAb9tsQ2nR2RJxB5a77ecVhk7XmLuAkoAfoBW7tZmMkHQY8AlwbEbu72ZZqddpVqu8tN90IsWG/DaEdEbGt+LkDeIzK4W+ZbC/GVvrGWHZ0uT1/FRHbI2JfROwH7qaL352kQ6gExbKIeLRY3PXvrl67yvS95agbITbstyEMlaSxxYArksYCXwfWpbcacSuAOcX7OcATXWxLP30BUZhFl747SQLuATZGxOKqUle/u0btKsv3lquuXLFfnEL+V/7/NoSFI96IOiSdSKX3BZVbsh7sZtsk/QyYTmValO3A94HHgeXAccBbwCURMeID7A3aNp3KIVEAW4Cr+8agRrht5wD/BbwK9E2MdROV8aeufXeJds2mBN9brnzbkZllzVfsm1nWHGJmljWHmJllzSFmZllziJlZ1hxiZpY1h5iZZe3/APpBMI71CUMRAAAAAElFTkSuQmCC\n",
  137. "text/plain": [
  138. "<Figure size 432x288 with 2 Axes>"
  139. ]
  140. },
  141. "metadata": {
  142. "needs_background": "light"
  143. },
  144. "output_type": "display_data"
  145. }
  146. ],
  147. "source": [
  148. "# 显示其中一个数据\n",
  149. "import matplotlib.pyplot as plt\n",
  150. "plt.imshow(train_set.data[0], cmap='gray')\n",
  151. "plt.title('%i' % train_set.targets[0])\n",
  152. "plt.colorbar()\n",
  153. "plt.show()"
  154. ]
  155. },
  156. {
  157. "cell_type": "code",
  158. "execution_count": 6,
  159. "metadata": {},
  160. "outputs": [
  161. {
  162. "name": "stdout",
  163. "output_type": "stream",
  164. "text": [
  165. "torch.Size([64, 1, 32, 32])\n",
  166. "torch.Size([64])\n"
  167. ]
  168. },
  169. {
  170. "data": {
  171. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS4AAAEICAYAAADhtRloAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZkElEQVR4nO3df5BdZZ3n8fcnTSLBiAHyqxMiiRgDWQoSDAFKFzO4cTpsWcDWZgQWyaRwM2wZC635Q3SrZty1LGBnddZxEarVILosrC5RI5WdjIWyaM2YTQItJESwZSA0iYltEgNJTNLhu3/c03rTfZ9zb3ffvvee7s+r6lb3fb7nnPvkdOfbz3nO8zxHEYGZWZFMaHYFzMyGyonLzArHicvMCseJy8wKx4nLzArHicvMCseJy8wKx4nLKpL0PyTtlXRY0ouSPtrsOpn1kwegWiWS/gXQHRHHJV0EPAn864jY3tyambnFZQkRsTMijve/zV4XNrFKZn/gxGVJkr4i6SjwC2AvsKnJVTIDfKloVUhqA64GlgP3RsTJ5tbIzC0uqyIiTkXET4Hzgf/Q7PqYgROX1e4M3MdlLcKJywaRNEPSTZKmSGqT9KfAzcCPml03M3Afl1UgaTrwv4HLKP1xewX4u4j4alMrZpZx4jKzwvGlopkVjhOXmRWOE5eZFY4Tl5kVzhmN/DBJvhNgNsoiQiPZv6OjI3p7e2vadvv27ZsjomMknzccI0pckjqALwFtwNci4p661MrMmqa3t5dt27bVtK2kaaNcnYqGfamYzWG7D1gJLAJulrSoXhUzs+aJiJpezTKSFtcySus1vQQg6VHgeuD5elTMzJrnzTffbHYVco0kcc0BXi173wNcOXAjSWuBtSP4HDNroGa3pmoxksRVqQNw0L82IjqBTnDnvFlRjOXE1QPMLXt/PrBnZNUxs1bQ6olrJOO4tgILJM2XNAm4CdhYn2qZWTON2c75iOiTtA7YTGk4xPqI2Fm3mplZ07R6i2tE47giYhNeh9xsTImIMX1X0czGqDHd4jKzscmJy8wKx4nLzAql2XcMa+HEZWaDuHPezArHLS4zKxRfKppZITlxmVnhOHGZWeE4cZlZoXjKj5kVkltcZlY4TlxmVjhOXGZWOE5cZlYo7pw3s0Jyi8tawhlnpH/U06dPT8Zmz56djC1alH7+79SpUyuWT5gwvMcc5LUADh8+nIx1dXVVLN+xY0dyn1OnTtVcr7HKicvMCseJy8wKxZOszayQnLjMrHB8V9HMCsctLjMrFPdxWa63v/3tydicOXOSsZkzZyZjZ511VsXyM888M7nPvHnzkrFLL700Gbv22muTsVmzZlUszxuWkSdviMK+ffuSsR/84AcVy7/85S8n9+nu7k7Gjh8/noyNJWM6cUl6GXgdOAX0RcTSelTKzJprTCeuzJ9ERG8djmNmLWI8JC4zG0OKMFdxePMv/iiAf5C0XdLaShtIWitpm6RtI/wsM2uQ/g76aq9mGWmL670RsUfSDOCHkn4REU+VbxARnUAngKTWbn+aGdD6l4ojanFFxJ7s637gu8CyelTKzJqrni0uSR2SXpDULemuCvG3S/qBpJ9L2ilpTbVjDrvFJemtwISIeD37/oPAfx7u8caqSZMmJWNXX311MrZq1apk7JprrknGUkMl8oYTTJ48ORnLq39fX18yduTIkSHvk/dZecM52tvbk7HVq1dXLD9x4kRyn3vuuScZ27NnTzI2ltSrxSWpDbgPWAH0AFslbYyI58s2+xjwfER8SNJ04AVJD0dE8oc0kkvFmcB3JfUf539GxN+P4Hhm1gLq3Dm/DOiOiJcAJD0KXA+UJ64A3qZSMpkCHADSf+EYQeLKKnLZcPc3s9Y1hBbXtAE33jqzfu1+c4BXy973AFcOOMZ/BzYCe4C3AR+OiNzM6eEQZjbIEBJXb5WB56p0+AHv/xToAq4FLqR0o+8nEZFcIXKkwyHMbAy
qY+d8DzC37P35lFpW5dYAG6KkG/hn4KK8gzpxmdlpak1aNSaurcACSfMlTQJuonRZWG438AEASTOBhcBLeQf1paKZDVKvu4oR0SdpHbAZaAPWR8ROSXdk8QeAzwHfkPQcpUvLT1WbRujENcpSqyQAfPzjH0/GVqxYkYy1tbUlYwcPHqxYvnv37uQ+F1xwQTKWtxpC3qoMzzzzTMXy3/72t8Oqx3ve855kbMaMGclYaojFbbfdltyns7MzGRsvwyHqOeUnIjYBmwaUPVD2/R5Kw6lq5sRlZoO0+sh5Jy4zO02z5yHWwonLzAZx4jKzwnHiMrPCceIa5/IeOf+Wt7wlGctbmz115xDSa6x//vOfT+6Tt/Z9Nhe1oqNHjyZjqTuOx44dS+6zcuXKZGz69OnJWN5dxdTdsbw7onkTwceDIiwk6MRlZoO4xWVmhePEZWaF48RlZoXjxGVmheLOeTMrJLe4xrne3vQk97vvvjsZe/DBB5Oxw4eT66vxq1/9qmL5K6+8ktwnb8hGnry/yqnJzbNnz07uM2/evGQsb138vP9kv//97yuW553fvIng44UTl5kVjhOXmRWKJ1mbWSE5cZlZ4fiuopkVjltcZlYo7uOy3BUUtmzZkoxNnDgxGctbvSB1+//kyZPJfYarvb09Gevo6KhY/sEPppcWv+yy9POF84ZR5A05+d73vlex/Dvf+U5yn0OHDiVj40WrJ66qA3gkrZe0X9KOsrJzJf1Q0i+zr+eMbjXNrJHq+HiyUVHLyMNvAAP/fN4FPBERC4AnsvdmNkYUPnFFxFPAgQHF1wMPZd8/BNxQ32qZWbP0z1Ws5dUsw+3jmhkRewEiYq+k5BKUktYCa4f5OWbWBK3exzXqnfMR0Ql0Akhq7bNhZkDrJ67hza6FfZLaAbKv++tXJTNrtlbv4xpui2sjsBq4J/v6/brVaIzJ6wd444036v55bW1tFcunTZuW3Ofd7353MrZgwYJk7PLLL0/GrrrqqorlCxcuTO6T9/CQnp6eZOzHP/5xMnb//fdXLH/55ZeT+5w6dSoZGy9avcVVNXFJegRYDkyT1AP8NaWE9W1JtwO7gVWjWUkza5wxsZBgRNycCH2gznUxsxZR+BaXmY0/TlxmVjhOXGZWKM2+Y1gLJy4zG8SJy5IkJWN5Ky+84x3vSMZmzZpVsXz+/PnJfa6++upkLG84RN4xp0yZUrH8xIkTyX22bduWjG3cuDEZe+qpp5Kxrq6uZMzSCn9X0czGlyJcKg535LyZjWH1HDkvqUPSC5K6JVVcSUbSckldknZK+r/VjukWl5kNUq8Wl6Q24D5gBdADbJW0MSKeL9tmKvAVoCMiduct2tDPLS4zG6SOLa5lQHdEvBQRJ4BHKS2LVe4WYENE7M4+u+rcZycuMzvNENfjmiZpW9lr4BJWc4BXy973ZGXl3g2cI+lJSdsl3Vatjr5UNLNBhnCp2BsRS3PilW6dDzz4GcB7KE0jnAz8k6SfRcSLqYM6cY2ySZMmJWPz5s1LxlasWJGMXXPNNcnYBRdcULF8zpyBf+Rqi9XbsWPHkrG84RAPP/xwMrZ3794R1ckGq+NdxR5gbtn784E9FbbpjYgjwBFJTwGXAcnE5UtFMxukjn1cW4EFkuZLmgTcRGlZrHLfB/6lpDMknQVcCezKO6hbXGY2SL1aXBHRJ2kdsBloA9ZHxE5Jd2TxByJil6S/B54F3gS+FhE70kd14jKzAeo9ADUiNgGbBpQ9MOD93wB/U+sxnbjMbBBP+TGzwmn1KT9OXKNs6tSpyditt96ajN1yyy3J2Dvf+c4h1yNvQncjf0nz7rKm7ogCLFmyJBmbMCF9jylvrXpLc+Iys0IpwiRrJy4zG8SJy8wKx4nLzArHdxXNrFDcx2VmheTENc6dddZZydgVV1yRjM2dOzcZy1u3PfX4+Lym/8mTJ5OxPBMnThxybPLkycl9brjhhmQs71w98sgjydinP/3piuWp82QlrZ64qk6
ylrRe0n5JO8rKPivptWyp1S5J141uNc2skeq5dPNoqGV1iG8AHRXK/zYiFmevTRXiZlZAQ1xIsCmqXipGxFOS5jWgLmbWIgp/qZhjnaRns0vJc1IbSVrbv6zrCD7LzBpoLFwqVnI/cCGwGNgLfCG1YUR0RsTSKsu7mlkLafXENay7ihGxr/97SV8FHq9bjcys6Vr9UnFYiUtSe0T0L/R9I5C7WuF49vrrrydjmzdvTsaOHj2ajPX19SVjzz33XMXyrVu3JvfZsmVLMpbnwx/+cDK2bNmyiuV56+W/613vSsZmzEg/au/9739/MnbxxRdXLN+1K70y8HgfKtHs1lQtqiYuSY8Ayyk9hqgH+GtguaTFlJ7W8TLwF6NXRTNrtMJP+YmImysUf30U6mJmLaLwLS4zG3+cuMysUMZEH5eZjT9OXGZWOK2euNTICkpq7bMxCtra2pKxc889Nxk7++yzk7G82/WpYRRHjhxJ7pMXy3POOckJE8n6X3TRRcl91qxZk4ytWrUqGdu/f38y9uCDD1Ys/9znPpfc59ixY8lYEURE+skoNWhvb4/Vq1fXtO299967vRmDy93iMrPTuI/LzArJicvMCseJy8wKx4nLzAqlfyHBVubEZWaDuMU1zuUNXfjNb36TjPX29g7r8xr5C3fw4MFk7NChQ0PeZ968ecnY0qXpO+5z5sxJxlasWFGx/O67707uY05cZlZATlxmVjhOXGZWKEUYgDqSh2WY2RhVz8eTSeqQ9IKkbkl35Wx3haRTkv5ttWM6cZnZIPV6WIakNuA+YCWwCLhZ0qLEdvcC6fXMy/hSsUW1elO9mlT989bg//Wvf52MHThwIBnLuxs5c+bMiuUTJvhvdp46/v4tA7oj4iUASY8C1wPPD9ju48BjwBW1HNQ/PTM7Ta2trSy5Tet/bmr2WjvgcHOAV8ve92RlfyBpDqWH7jxQax3d4jKzQYbQ4uqtsqxNpSV2Bh78vwGfiohTUm0r8jhxmdkgdbxU7AHmlr0/H9gzYJulwKNZ0poGXCepLyK+lzqoE5eZDVLHuYpbgQWS5gOvATcBt5RvEBHz+7+X9A3g8bykBU5cZjZAPcdxRUSfpHWU7ha2AesjYqekO7J4zf1a5Zy4zGyQet7VjohNwKYBZRUTVkT8eS3HrOVJ1nOBbwKzgDeBzoj4kqRzgf8FzKP0NOs/i4j0DFobVyZOnFixfMaMGcl95s+fn4zlrW9v9dfqw3FqGQ7RB/xlRFwMXAV8LBtAdhfwREQsAJ7I3pvZGFCvAaijpWqLKyL2Anuz71+XtIvSOIzrgeXZZg8BTwKfGpVamlnDjLmFBCXNA5YAW4CZWVIjIvZKSl8DmFmhtPqlYs2JS9IUSkPyPxERh2sdKJaNpB04mtbMWtiYSFySJlJKWg9HxIaseJ+k9qy11Q5UfCpnRHQCndlxWvtsmBnQ+omraue8Sk2rrwO7IuKLZaGNQP/jblcD369/9cysGQrfOQ+8F/gI8JykrqzsM8A9wLcl3Q7sBtLPSB8lkydPTsZSt+MBjh8/PqyYnW7KlCnJ2MKFCyuWv+9970vu86EPfSgZy1tX/uTJk8nY7373u4rlrd6iaKZmJ6Va1HJX8adUnigJ8IH6VsfMWsGYuqtoZuND4VtcZjb+OHGZWaGMiT4uMxt/nLjMrHDcOV8HZ599dsXySy65JLlP3ioE+/dXHCsLwKuvvpqMpR4f39fXl9xnNEyaNCkZa2trq1ieN2Rg2rRpw6rHxRdfnIytWlV5dMyKFSuS+8yePTsZy6v/nj0DF9T8ox/96EcVyxv9MysSXyqaWSE5cZlZ4ThxmVnhOHGZWeE4cZlZoYy5hQTNbHxwi6sOli9fXrH8k5/8ZHKfJUuWJGN5Qx6efvrpZOyxxx6rWJ43vGI0/nLNnTs3GUs9VGLfvn3JfdasWZOM5a2ysXRp+gHG5513XsXy1HANgBMnTiRjeUMeHn/88WTszjvvTMYszYnLzArHicv
MCsUDUM2skJy4zKxwfFfRzArHLa46uPXWWyuW561fPmFC+jkgF110UTK2YMGCZOzGG2+sWN7oH3Levy0Vy/sLmjdpO0/eHcJTp05VLD906FByn23btiVj3/rWt5KxDRs2JGM2dO7jMrNCcuIys8Jx4jKzwnHnvJkVivu4zKyQnLjMrHAKn7gkzQW+CcwC3gQ6I+JLkj4L/HvgN9mmn4mITaNRyV27dlUsX7x4cXKfvEe2500czhsaMNxhA42UGoaQN4E5L5Zn06b0j7urq6ti+TPPPJPc58UXX0zGDhw4kIwdPXo0GbPhafXElR4Q9Ed9wF9GxMXAVcDHJC3KYn8bEYuz16gkLTNrvP5+rmqvWkjqkPSCpG5Jd1WI/ztJz2avf5R0WbVjVm1xRcReYG/2/euSdgHp5oyZFVo9FxKU1AbcB6wAeoCtkjZGxPNlm/0z8P6IOChpJdAJXJl33FpaXOWVmAcsAbZkReuyLLleUuWFoMyscOrY4loGdEfESxFxAngUuH7AZ/1jRPQ/++9nwPnVDlpz4pI0BXgM+EREHAbuBy4EFlNqkX0hsd9aSdskpedzmFlLGULimtb//zt7rR1wqDlA+cqdPeRfsd0O/J9q9avprqKkiZSS1sMRsSH7h+0ri38VqLgMZUR0Umr6Iam1e/zMDBhS53xvRKSXwgVVOnzFDaU/oZS40pOQM7XcVRTwdWBXRHyxrLw96/8CuBHYUe1YZtb66jwAtQcoX2v8fGDQOtySLgW+BqyMiN9WO2gtLa73Ah8BnpPUlZV9BrhZ0mJK2fNl4C9qONawpFYGePLJJ5P7zJo1Kxm78MILk7GFCxfWXK9W9Nprr1Us/8lPflL3z8pbu//gwYMVy/OGNRw5cmTEdbL6qGPi2goskDQfeA24CbilfANJ7wA2AB+JiPSYmDK13FX8KZWbex7+YDZG1euuYkT0SVoHbAbagPURsVPSHVn8AeCvgPOAr5Qu8OircvnpkfNmNlg9B6BmYzw3DSh7oOz7jwIfHcoxnbjM7DSeZG1mheTEZWaF48RlZoXjhQTroLu7e0jlAJMnT07Gpk+fnozNmDGj9oq1oNTDKPLOlVk593GZWSE5cZlZ4ThxmVnhOHGZWeE4cZlZodRzIcHR4sRlZoO4xdUkx44dS8Z27949rJjZeOHEZWaF48RlZoXiAahmVkhOXGZWOL6raGaF4xaXmRWK+7jMrJCcuMyscJy4zKxw3DlvZoXiPi4zKyQnLjMrnFZPXBOqbSDpTEn/T9LPJe2U9J+y8nMl/VDSL7Ov54x+dc2sEfovF6u9mqVq4gKOA9dGxGXAYqBD0lXAXcATEbEAeCJ7b2ZjQOETV5S8kb2dmL0CuB54KCt/CLhhNCpoZo3Vv5BgLa9mqaXFhaQ2SV3AfuCHEbEFmBkRewGyr8V+rpeZ/UGrt7hq6pyPiFPAYklTge9KuqTWD5C0Flg7vOqZWTMUvnO+XEQcAp4EOoB9ktoBsq/7E/t0RsTSiFg6sqqaWaO0eourlruK07OWFpImA/8K+AWwEVidbbYa+P4o1dHMGqjWpNXql4rtwEOS2iglum9HxOOS/gn4tqTbgd3AqlGsp5k1UKtfKqqRFZTU2mfDbAyICI1k/wkTJsTEiRNr2vbEiRPbm9EN5JHzZjZIq7e4nLjM7DTN7r+qxZDuKprZ+FDPznlJHZJekNQtadAMG5X8XRZ/VtLl1Y7pxGVmg9QrcWU39e4DVgKLgJslLRqw2UpgQfZaC9xf7bhOXGY2SB2n/CwDuiPipYg4ATxKabpgueuBb2bTC38GTO0fI5rS6D6uXuCV7Ptp2ftmcz1O53qcrmj1uKAOn7U5+7xanClpW9n7zojoLHs/B3i17H0PcOWAY1TaZg6wN/WhDU1cETG9/3tJ21phNL3r4Xq4HqeLiI46Hq7S0IyB15i1bHMaXyqa2WjqAeaWvT8f2DOMbU7jxGVmo2krsEDSfEmTgJsoTRcstxG4Lbu7eBX
wu/6VZ1KaOY6rs/omDeF6nM71OJ3rMQIR0SdpHaV+szZgfUTslHRHFn8A2ARcB3QDR4E11Y7b0Ck/Zmb14EtFMyscJy4zK5ymJK5qUwAaWI+XJT0nqWvAWJTR/tz1kvZL2lFW1vCnJiXq8VlJr2XnpEvSdQ2ox1xJP5a0K3uS1J1ZeUPPSU49GnpO/GSt6hrex5VNAXgRWEHpNuhW4OaIeL6hFSnV5WVgaUQ0dIChpGuANyiNFr4kK/svwIGIuCdL5udExKeaUI/PAm9ExH8dzc8eUI92oD0inpb0NmA7pYev/DkNPCc59fgzGnhOJAl4a0S8IWki8FPgTuDf0ODfkVbVjBZXLVMAxrSIeAo4MKC44U9NStSj4SJib0Q8nX3/OrCL0sjphp6TnHo0lJ+sVV0zEldqeH8zBPAPkrZnD/VoplZ6atK6bJb++kZfjkiaBywBmvokqQH1gAafE/nJWrmakbiGPLx/FL03Ii6nNDv9Y9ml03h3P3AhpYf/7gW+0KgPljQFeAz4REQcbtTn1lCPhp+TiDgVEYspjSJfpiE8WWs8aEbiGvLw/tESEXuyr/uB71K6jG2Wmp6aNNoiYl/2n+ZN4Ks06JxkfTmPAQ9HxIasuOHnpFI9mnVOss8+xBCfrDUeNCNx1TIFYNRJemvWAYuktwIfBHbk7zWqWuKpSQOWE7mRBpyTrDP668CuiPhiWaih5yRVj0afE/nJWtUN5VFE9XpRGt7/IvAr4D82qQ7vBH6evXY2sh7AI5QuOU5SaoHeDpwHPAH8Mvt6bpPq8S3gOeBZSv9R2htQj/dR6i54FujKXtc1+pzk1KOh5wS4FHgm+7wdwF9l5Q3/HWnVl6f8mFnheOS8mRWOE5eZFY4Tl5kVjhOXmRWOE5eZFY4Tl5kVjhOXmRXO/weOzhS3T9Z1xAAAAABJRU5ErkJggg==\n",
  172. "text/plain": [
  173. "<Figure size 432x288 with 2 Axes>"
  174. ]
  175. },
  176. "metadata": {
  177. "needs_background": "light"
  178. },
  179. "output_type": "display_data"
  180. },
  181. {
  182. "name": "stdout",
  183. "output_type": "stream",
  184. "text": [
  185. "tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
  186. " [0., 0., 0., ..., 0., 0., 0.],\n",
  187. " [0., 0., 0., ..., 0., 0., 0.],\n",
  188. " ...,\n",
  189. " [0., 0., 0., ..., 0., 0., 0.],\n",
  190. " [0., 0., 0., ..., 0., 0., 0.],\n",
  191. " [0., 0., 0., ..., 0., 0., 0.]])\n"
  192. ]
  193. }
  194. ],
  195. "source": [
  196. "import matplotlib.pyplot as plt\n",
  197. "\n",
  198. "# 显示转化后的图像\n",
  199. "for im, label in train_data:\n",
  200. " print(im.shape)\n",
  201. " print(label.shape)\n",
  202. " \n",
  203. " img = im[0,0,:,:]\n",
  204. " lab = label[0]\n",
  205. " plt.imshow(img, cmap='gray')\n",
  206. " plt.title('%i' % lab)\n",
  207. " plt.colorbar()\n",
  208. " plt.show()\n",
  209. "\n",
  210. " print(im[0,0,:,:])\n",
  211. " break"
  212. ]
  213. },
  214. {
  215. "cell_type": "code",
  216. "execution_count": 7,
  217. "metadata": {
  218. "scrolled": false
  219. },
  220. "outputs": [
  221. {
  222. "name": "stdout",
  223. "output_type": "stream",
  224. "text": [
  225. "Epoch 0. Train Loss: 0.292838, Train Acc: 0.908382, Valid Loss: 0.075638, Valid Acc: 0.974684, Time 00:00:13\n",
  226. "Epoch 1. Train Loss: 0.077091, Train Acc: 0.976479, Valid Loss: 0.066128, Valid Acc: 0.978936, Time 00:00:14\n",
  227. "Epoch 2. Train Loss: 0.055866, Train Acc: 0.982759, Valid Loss: 0.042326, Valid Acc: 0.986748, Time 00:00:14\n",
  228. "Epoch 3. Train Loss: 0.043993, Train Acc: 0.986257, Valid Loss: 0.042040, Valid Acc: 0.986847, Time 00:00:14\n",
  229. "Epoch 4. Train Loss: 0.035289, Train Acc: 0.988823, Valid Loss: 0.035118, Valid Acc: 0.988430, Time 00:00:14\n",
  230. "Epoch 5. Train Loss: 0.030174, Train Acc: 0.990572, Valid Loss: 0.036890, Valid Acc: 0.988430, Time 00:00:14\n",
  231. "Epoch 6. Train Loss: 0.025604, Train Acc: 0.991571, Valid Loss: 0.028075, Valid Acc: 0.990803, Time 00:00:14\n",
  232. "Epoch 7. Train Loss: 0.021483, Train Acc: 0.993220, Valid Loss: 0.039955, Valid Acc: 0.988133, Time 00:00:14\n",
  233. "Epoch 8. Train Loss: 0.018553, Train Acc: 0.994020, Valid Loss: 0.031569, Valid Acc: 0.990506, Time 00:00:14\n",
  234. "Epoch 9. Train Loss: 0.016860, Train Acc: 0.994420, Valid Loss: 0.028923, Valid Acc: 0.990803, Time 00:00:14\n",
  235. "Epoch 10. Train Loss: 0.014547, Train Acc: 0.995186, Valid Loss: 0.041005, Valid Acc: 0.987737, Time 00:00:14\n",
  236. "Epoch 11. Train Loss: 0.011832, Train Acc: 0.996085, Valid Loss: 0.039684, Valid Acc: 0.989221, Time 00:00:14\n",
  237. "Epoch 12. Train Loss: 0.012104, Train Acc: 0.996019, Valid Loss: 0.033983, Valid Acc: 0.990012, Time 00:00:14\n",
  238. "Epoch 13. Train Loss: 0.009578, Train Acc: 0.996802, Valid Loss: 0.044510, Valid Acc: 0.989419, Time 00:00:14\n",
  239. "Epoch 14. Train Loss: 0.008961, Train Acc: 0.997018, Valid Loss: 0.033376, Valid Acc: 0.991693, Time 00:00:14\n",
  240. "Epoch 15. Train Loss: 0.008937, Train Acc: 0.997002, Valid Loss: 0.054347, Valid Acc: 0.986847, Time 00:00:15\n",
  241. "Epoch 16. Train Loss: 0.009171, Train Acc: 0.996902, Valid Loss: 0.034495, Valid Acc: 0.991594, Time 00:00:16\n",
  242. "Epoch 17. Train Loss: 0.006915, Train Acc: 0.997818, Valid Loss: 0.046391, Valid Acc: 0.989517, Time 00:00:16\n",
  243. "Epoch 18. Train Loss: 0.007419, Train Acc: 0.997651, Valid Loss: 0.044388, Valid Acc: 0.989419, Time 00:00:16\n",
  244. "Epoch 19. Train Loss: 0.006600, Train Acc: 0.998001, Valid Loss: 0.049959, Valid Acc: 0.987935, Time 00:00:16\n"
  245. ]
  246. }
  247. ],
  248. "source": [
  249. "net = LeNet5()\n",
  250. "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
  251. "criterion = nn.CrossEntropyLoss()\n",
  252. "\n",
  253. "res = train(net, train_data, test_data, 20, \n",
  254. " optimizer, criterion,\n",
  255. " use_cuda=True)"
  256. ]
  257. },
  258. {
  259. "cell_type": "code",
  260. "execution_count": null,
  261. "metadata": {},
  262. "outputs": [],
  263. "source": [
  264. "import matplotlib.pyplot as plt\n",
  265. "%matplotlib inline\n",
  266. "\n",
  267. "plt.plot(res[0], label='train')\n",
  268. "plt.plot(res[2], label='valid')\n",
  269. "plt.xlabel('epoch')\n",
  270. "plt.ylabel('Loss')\n",
  271. "plt.legend(loc='best')\n",
  272. "plt.savefig('fig-res-lenet5-train-validate-loss.pdf')\n",
  273. "plt.show()\n",
  274. "\n",
  275. "plt.plot(res[1], label='train')\n",
  276. "plt.plot(res[3], label='valid')\n",
  277. "plt.xlabel('epoch')\n",
  278. "plt.ylabel('Acc')\n",
  279. "plt.legend(loc='best')\n",
  280. "plt.savefig('fig-res-lenet5-train-validate-acc.pdf')\n",
  281. "plt.show()"
  282. ]
  283. }
  284. ],
  285. "metadata": {
  286. "kernelspec": {
  287. "display_name": "Python 3",
  288. "language": "python",
  289. "name": "python3"
  290. },
  291. "language_info": {
  292. "codemirror_mode": {
  293. "name": "ipython",
  294. "version": 3
  295. },
  296. "file_extension": ".py",
  297. "mimetype": "text/x-python",
  298. "name": "python",
  299. "nbconvert_exporter": "python",
  300. "pygments_lexer": "ipython3",
  301. "version": "3.8.12"
  302. }
  303. },
  304. "nbformat": 4,
  305. "nbformat_minor": 2
  306. }

机器学习越来越多地应用于飞行器、机器人等领域,其目的是利用计算机实现类似人类的智能,从而实现装备的智能化与无人化。本课程旨在引导学生掌握机器学习的基本知识、典型方法与技术,通过具体的应用案例激发学生对该学科的兴趣,鼓励学生从人工智能的角度分析、解决飞行器和机器人所面临的问题与挑战。本课程主要内容包括 Python 编程基础、机器学习模型、无监督学习、监督学习、深度学习的基础知识与实现,并学习如何利用机器学习解决实际问题,从而全面提升自身的综合能力。