You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

02-LeNet5.ipynb 27 kB


  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# LeNet5\n",
  8. "\n",
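  9. "LeNet 诞生于 1994 年,是最早的卷积神经网络之一,并且推动了深度学习领域的发展。自 1988 年开始,经过多次迭代,这一开拓性成果最终被命名为 LeNet5(其完整描述见 LeCun 等人 1998 年的论文 *Gradient-Based Learning Applied to Document Recognition*)。LeNet5 架构的提出基于如下观点:图像的特征分布在整张图像上,而带有可学习参数的卷积核在图像的各个位置共享同一组参数,因此既有效地减少了参数数量,又能在多个位置上提取相似的特征。\n",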
  10. "\n",
  11. "在 LeNet5 提出的时候,没有 GPU 帮助训练,甚至 CPU 的速度也很慢,因此 LeNet5 的规模并不大。它包含七个处理层,每一层都包含可训练参数(权重),当时使用的输入数据是 $32 \\times 32$ 像素的图像。LeNet5 这个网络虽然很小,但它包含了深度学习的基本模块:卷积层、池化层和全连接层,是其他深度学习模型的基础。这里对 LeNet5 进行深入分析和讲解,通过实例分析加深对卷积层和池化层的理解。"
  12. ]
  13. },
  14. {
  15. "cell_type": "markdown",
  16. "metadata": {},
  17. "source": [
  18. "定义网络如下。注意:这里采用了现代常用的 ReLU 激活函数和最大池化,而原始论文使用的是 tanh 类激活函数和带可训练系数的平均下采样,因此本实现中的池化层没有可训练参数:"
  19. ]
  20. },
  21. {
  22. "cell_type": "code",
  23. "execution_count": 1,
  24. "metadata": {},
  25. "outputs": [],
  26. "source": [
  27. "import torch\n",
  28. "from torch import nn\n",
  29. "import torch.nn.functional as F\n",
  30. "\n",
  31. "class LeNet5(nn.Module):\n",
  32. "    \"\"\"LeNet-5 style CNN: two conv/ReLU/max-pool stages, then three fully connected layers.\n",
  33. "\n",
  34. "    in_channels: number of input image channels (1 by default, e.g. grayscale MNIST).\n",
  35. "    num_classes: length of the output logit vector (10 by default).\n",
  36. "    The fc1 input size (16 * 5 * 5) assumes 32x32 input images.\n",
  37. "    \"\"\"\n",
  38. "    def __init__(self, in_channels=1, num_classes=10):\n",
  39. "        super().__init__()\n",
  40. "        self.conv1 = nn.Conv2d(in_channels, 6, 5)  # 5x5 conv: in_channels -> 6 channels\n",
  41. "        self.conv2 = nn.Conv2d(6, 16, 5)  # 5x5 conv: 6 -> 16 channels\n",
  42. "        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # flattened 16x5x5 feature map -> 120\n",
  43. "        self.fc2 = nn.Linear(120, 84)\n",
  44. "        self.fc3 = nn.Linear(84, num_classes)\n",
  45. "\n",
  46. "    def forward(self, x):\n",
  47. "        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))  # 32x32 -> 28x28 -> 14x14\n",
  48. "        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))  # 14x14 -> 10x10 -> 5x5\n",
  49. "        x = torch.flatten(x, 1)  # flatten every dimension except the batch dimension\n",
  50. "        x = F.relu(self.fc1(x))\n",
  51. "        x = F.relu(self.fc2(x))\n",
  52. "        return self.fc3(x)  # raw logits; nn.CrossEntropyLoss applies log-softmax itself"
  53. ]
  54. },
  55. {
  56. "cell_type": "code",
  57. "execution_count": 2,
  58. "metadata": {},
  59. "outputs": [
  60. {
  61. "name": "stdout",
  62. "output_type": "stream",
  63. "text": [
  64. "LeNet5(\n",
  65. " (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
  66. " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
  67. " (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
  68. " (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
  69. " (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
  70. ")\n"
  71. ]
  72. }
  73. ],
  74. "source": [
  75. "# Build the model with its default settings and print the layer summary\n",
  76. "print(net := LeNet5())"
  77. ]
  78. },
  79. {
  80. "cell_type": "code",
  81. "execution_count": 3,
  82. "metadata": {},
  83. "outputs": [
  84. {
  85. "name": "stdout",
  86. "output_type": "stream",
  87. "text": [
  88. "tensor([[ 0.0124, 0.1326, 0.1647, 0.0728, 0.0722, 0.0113, 0.0829, -0.0055,\n",
  89. " 0.1749, -0.0581]], grad_fn=<AddmmBackward>)\n"
  90. ]
  91. }
  92. ],
  93. "source": [
  94. "dummy_input = torch.randn(1, 1, 32, 32)  # batch of one 1-channel 32x32 image\n",
  95. "out = net(dummy_input)  # logits, shape (1, 10)\n",
  96. "print(out)"
  97. ]
  98. },
  99. {
  100. "cell_type": "code",
  101. "execution_count": 4,
  102. "metadata": {},
  103. "outputs": [],
  104. "source": [
  105. "import numpy as np\n",
  106. "from torchvision.datasets import mnist\n",
  107. "from torch.utils.data import DataLoader\n",
  108. "from torchvision.datasets import mnist \n",
  109. "from torchvision import transforms as tfs\n",
  110. "from utils import train\n",
  111. "\n",
  112. "# Preprocessing only (no augmentation): resize to 32x32 and convert to a tensor.\n",
  113. "# The pipeline is built once here instead of being rebuilt on every data_tf call.\n",
  114. "_mnist_tf = tfs.Compose([\n",
  115. "    tfs.Resize(32),\n",
  116. "    tfs.ToTensor(),  # image -> float tensor with values in [0, 1]\n",
  117. "    # tfs.Normalize([0.5], [0.5]),  # optional: rescale [0, 1] to [-1, 1]\n",
  118. "])\n",
  119. "def data_tf(x):\n",
  120. "    \"\"\"Apply the MNIST preprocessing pipeline to one sample image and return a tensor.\"\"\"\n",
  121. "    return _mnist_tf(x)\n",
  122. "DATA_ROOT = '../../data/mnist'  # MNIST download/cache directory, relative to the working directory\n",
  123. "train_set = mnist.MNIST(DATA_ROOT, train=True, transform=data_tf, download=True)\n",
  124. "train_data = DataLoader(train_set, batch_size=64, shuffle=True)\n",
  125. "test_set = mnist.MNIST(DATA_ROOT, train=False, transform=data_tf, download=True)\n",
  126. "test_data = DataLoader(test_set, batch_size=128, shuffle=False)"
  127. ]
  128. },
  129. {
  130. "cell_type": "code",
  131. "execution_count": 5,
  132. "metadata": {},
  133. "outputs": [
  134. {
  135. "data": {
  136. "image/png": "iVBORw0KGgoAAAANSUhEUgAAATEAAAEICAYAAAA3EMMNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWHUlEQVR4nO3df6xU5Z3H8fdH1DZF02qsSBFFDavFpl6VYlONxVgrbWyQthrJxmDWiH9AVhNjVk2T0uzimi3irqs1xUr9EayS+Iu4TdW1WtdtigKlyo+yorIWuYGiUtBWDfDdP+bc7tw7M8/MvTP3znkun1cyuXPP95w5T6feD895zjnPUURgZparg7rdADOzdjjEzCxrDjEzy5pDzMyy5hAzs6w5xMwsaw4xM8uaQ8zqkvS8pA8lvV+8NnW7TWb1OMQsZX5EHFa8Tu52Y8zqcYiZWdYcYpbyz5J2SvpvSdO73RizeuR7J60eSWcBG4CPgcuAO4CeiHi9qw0zG8AhZi2R9AvgPyLi37vdFrNqPpy0VgWgbjfCbCCHmNWQ9BlJF0r6pKSDJf0tcC7wVLfbZjbQwd1ugJXSIcA/AacA+4DfAxdHhK8Vs9LxmJiZZc2Hk2aWNYeYmWXNIWZmWXOImVnWRvTspCSfRTAbZhHR1vV8M2bMiJ07d7a07urVq5+KiBnt7K9tETHkFzAD2ARsBm5oYf3wyy+/hvfVzt90RHDmmWdGq4BVTf7mJwLPARuB9cA1xfIFwNvA2uL1zaptbqSSKZuAC5u1d8g9MUljgDuBC4CtwMuSVkTEhqF+ppmVQwcvvdoLXBcRayQdDqyW9ExRuy0iFlWvLGkKlXt1TwU+B/ynpL+JiH2NdtDOmNg0YHNEvBERHwMPATPb+DwzK4n9+/e39GomInojYk3xfg+VHtmExCYzgYci4qOIeJNKj2xaah/thNgE4A9Vv2+t1zhJcyWtkrSqjX2Z2QgZ5JBSyyRNAk4HVhaL5kt6RdJSSUcUy1rKlWrthFi9wcOa/1URsSQipkbE1Db2ZWYjaBAhdlRfJ6V4za33eZIOAx4Bro2I3cBdwElAD9AL3Nq3ar3mpNraztnJrVQG7focC2xr4/PMrCQG0cva2ayDIukQKgG2LCIeLT5/e1X9buDJ4tdB50o7PbGXgcmSTpB0KJXBuBVtfJ6ZlUSnDiclCbgH2BgRi6uWj69abRawrni/ArhM0icknQBMBl5K7WPIPbGI2CtpPpXpWcYASyNi/VA/z8zKo4NnJ88GLgdelbS2WHYTMFtSD5VDxS3A1cV+10taTmVW4b3AvNSZSRjhWSx8savZ8Is2L3Y944wz4sUXX2xp3bFjx67u9ni35xMzsxoj2blpl0PMzGo4xMwsaw4xM8vWUC5k7SaHmJnVaOWWorJwiJlZDffEzCxbPpw0s+w5xMwsaw4xM8uaQ8zMshURPjtpZnlzT8zMsuYQM7OsOcTMLGsOMTPLlgf2zSx77omZWdYcYmaWNYeYmWXLN4CbWfYcYmaWNZ+dNLOsuSdmZtnymJiZZc8hZmZZc4iZWdYcYmaWLd87aWbZc0/MSmPMmDHJ+qc//elh3f/8+fMb1j71qU8ltz355JOT9Xnz5iXrixYtalibPXt2ctsPP/wwWb/llluS9R/84AfJetkdMCEmaQuwB9gH7I2IqZ1olJl11wETYoXzImJnBz7HzEriQAsxMxtFchvYP6jN7QN4WtJqSXPrrSBprqRVkla1uS8zGyF9V+03e5VBuyF2dkScAXwDmCfp3IErRMSSiJjq8TKzfHQqxCRNlPScpI2S1ku6plh+pKRnJL1W/DyiapsbJW2WtEnShc320VaIRcS24ucO4DFgWjufZ2bl0MGe2F7guoj4PPBlKp2dKcANwLMRMRl4tvidonYZcCowA/iRpOQp9iGHmKSxkg7vew98HVg31M8zs3JoNcBaCbGI6I2INcX7PcBGYAIwE7ivWO0
+4OLi/UzgoYj4KCLeBDbTpHPUzsD+OOAxSX2f82BE/KKNzxu1jjvuuGT90EMPTda/8pWvJOvnnHNOw9pnPvOZ5Lbf+c53kvVu2rp1a7J+++23J+uzZs1qWNuzZ09y29/97nfJ+q9+9atkPXeDGO86asB495KIWFJvRUmTgNOBlcC4iOgt9tUr6ehitQnAb6o221osa2jIIRYRbwCnDXV7MyuvQZyd3NnKeLekw4BHgGsjYnfR+am7ap1lyURtd2DfzEahTp6dlHQIlQBbFhGPFou3Sxpf1McDO4rlW4GJVZsfC2xLfb5DzMz66eSYmCpdrnuAjRGxuKq0AphTvJ8DPFG1/DJJn5B0AjAZeCm1D1/samY1OngN2NnA5cCrktYWy24CbgGWS7oSeAu4pNjveknLgQ1UzmzOi4h9qR04xMysRqdCLCJepP44F8D5DbZZCCxsdR8OMTOrUZar8VvhEOuAnp6eZP2Xv/xlsj7c0+GUVbMzYN/73veS9ffffz9ZX7ZsWcNab29vctv33nsvWd+0aVOynrPc7p10iJlZDffEzCxrDjEzy5pDzMyy5hAzs2x5YN/MsueemJllzSF2gHnrrbeS9XfeeSdZL/N1YitXrkzWd+3alayfd955DWsff/xxctsHHnggWbfh4xAzs2yVaf78VjjEzKyGQ8zMsuazk2aWNffEzCxbHhMzs+w5xMwsaw6xA8y7776brF9//fXJ+kUXXZSs//a3v03Wmz26LGXt2rXJ+gUXXJCsf/DBB8n6qaee2rB2zTXXJLe17nGImVm2fO+kmWXPPTEzy5pDzMyy5hAzs6w5xMwsWx7YN7PsuSdm/Tz++OPJerPnUu7ZsydZP+200xrWrrzyyuS2ixYtStabXQfWzPr16xvW5s6d29Zn2/DJKcQOaraCpKWSdkhaV7XsSEnPSHqt+HnE8DbTzEZS3/2TzV5l0DTEgHuBGQOW3QA8GxGTgWeL381sFGg1wLIJsYh4ARh4X81M4L7i/X3AxZ1tlpl1U04hNtQxsXER0QsQEb2Sjm60oqS5gAc/zDLis5NVImIJsARAUjmi28waKlMvqxWtjInVs13SeIDi547ONcnMui2nw8mhhtgKYE7xfg7wRGeaY2ZlkFOINT2clPQzYDpwlKStwPeBW4Dlkq4E3gIuGc5Gjna7d+9ua/s//elPQ972qquuStYffvjhZD2nsRNrXVkCqhVNQywiZjcond/htphZCXTytiNJS4GLgB0R8YVi2QLgKuCPxWo3RcTPi9qNwJXAPuDvI+KpZvsY6uGkmY1iHTycvJfa60wBbouInuLVF2BTgMuAU4ttfiRpTLMdOMTMrEanQqzBdaaNzAQeioiPIuJNYDMwrdlGDjEzqzGIEDtK0qqqV6vXhM6X9EpxW2PfbYsTgD9UrbO1WJbkG8DNrMYgBvZ3RsTUQX78XcA/AlH8vBX4O0D1mtLswxxiZtbPcF8+ERHb+95Luht4svh1KzCxatVjgW3NPs8hNgosWLCgYe3MM89MbvvVr341Wf/a176WrD/99NPJuuVpOC+dkTS+77ZFYBbQN0POCuBBSYuBzwGTgZeafZ5DzMxqdKon1uA60+mSeqgcKm4Bri72uV7ScmADsBeYFxH7mu3DIWZmNToVYg2uM70nsf5CYOFg9uEQM7N+ynRLUSscYmZWwyFmZllziJlZ1nK6sd8hZmb9eEzMRlzqsWrNptpZs2ZNsn733Xcn688991yyvmrVqoa1O++8M7ltTn9Io01O371DzMxqOMTMLGsOMTPLVicnRRwJDjEzq+GemJllzSFmZllziJlZ1hxiVhqvv/56sn7FFVck6z/96U+T9csvv3zI9bFjxya3vf/++5P13t7eZN2Gxhe7mln2fHbSzLLmnpiZZc0hZmbZ8piYmWXPIWZmWXOImVnWfHbSsvHYY48l66+99lqyvnjx4mT9/PPPb1i7+eabk9sef/zxyfrChemH4rz99tvJutWX25j
YQc1WkLRU0g5J66qWLZD0tqS1xeubw9tMMxtJfUHW7FUGTUMMuBeYUWf5bRHRU7x+3tlmmVk35RRiTQ8nI+IFSZNGoC1mVhJlCahWtNITa2S+pFeKw80jGq0kaa6kVZIaT7ZuZqXRNyliK68yGGqI3QWcBPQAvcCtjVaMiCURMTUipg5xX2Y2wkbV4WQ9EbG9772ku4EnO9YiM+u6sgRUK4bUE5M0vurXWcC6RuuaWX5GVU9M0s+A6cBRkrYC3wemS+oBAtgCXD18TbRuWrcu/e/TpZdemqx/61vfalhrNlfZ1Ven/7OaPHlysn7BBRck69ZYWQKqFa2cnZxdZ/E9w9AWMyuBMvWyWuEr9s2sRlnOPLbCIWZmNXLqibVznZiZjVKdGthvcNvikZKekfRa8fOIqtqNkjZL2iTpwlba6hAzs35aDbAWe2v3Unvb4g3AsxExGXi2+B1JU4DLgFOLbX4kaUyzHTjEzKxGp0IsIl4A3h2weCZwX/H+PuDiquUPRcRHEfEmsBmY1mwfHhOztuzatStZf+CBBxrWfvKTnyS3Pfjg9H+e5557brI+ffr0hrXnn38+ue2BbpjHxMZFRG+xn15JRxfLJwC/qVpva7EsySFmZjUGcXbyqAH3RS+JiCVD3K3qLGuapg4xM+tnkNeJ7RzCfdHbJY0vemHjgR3F8q3AxKr1jgW2Nfswj4mZWY1hvu1oBTCneD8HeKJq+WWSPiHpBGAy8FKzD3NPzMxqdGpMrMFti7cAyyVdCbwFXFLsc72k5cAGYC8wLyL2NduHQ8zManQqxBrctghQ9+ELEbEQSD88YQCHmJn10zcpYi4cYmZWI6fbjhxilvTFL34xWf/ud7+brH/pS19qWGt2HVgzGzZsSNZfeOGFtj7/QOYQM7OsOcTMLGsOMTPLlidFNLPs+eykmWXNPTEzy5pDzMyy5TExK5WTTz45WZ8/f36y/u1vfztZP+aYYwbdplbt25e+ba63tzdZz2lcp2wcYmaWtZz+AXCImVk/Ppw0s+w5xMwsaw4xM8uaQ8zMsuYQM7NsjbpJESVNBO4HjgH2U3kk079JOhJ4GJgEbAEujYj3hq+pB65m12LNnt1oBuDm14FNmjRpKE3qiFWrViXrCxemZylesWJFJ5tjVXLqibXytKO9wHUR8Xngy8C84nHjdR9Fbmb5G+anHXVU0xCLiN6IWFO83wNspPJU3kaPIjezzOUUYoMaE5M0CTgdWEnjR5GbWcbKFFCtaDnEJB0GPAJcGxG7pXpPHK+73Vxg7tCaZ2bdMOpCTNIhVAJsWUQ8Wixu9CjyfiJiCbCk+Jx8vhmzA1hOZyebjomp0uW6B9gYEYurSo0eRW5mmRttY2JnA5cDr0paWyy7iQaPIrda48aNS9anTJmSrN9xxx3J+imnnDLoNnXKypUrk/Uf/vCHDWtPPJH+dy+n3sBoUqaAakXTEIuIF4FGA2B1H0VuZnkbVSFmZgceh5iZZS2nQ3mHmJn1M+rGxMzswOMQM7OsOcTMLGsOsVHoyCOPbFj78Y9/nNy2p6cnWT/xxBOH0qSO+PWvf52s33rrrcn6U089laz/5S9/GXSbrPscYmaWrU5PiihpC7AH2AfsjYipnZyPsJX5xMzsADMMtx2dFxE9ETG1+L1j8xE6xMysxgjcO9mx+QgdYmZWYxAhdpSkVVWvetNuBfC0pNVV9X7zEQJDno/QY2Jm1s8ge1k7qw4RGzk7IrYVE6c+I+n37bWwP/fEzKxGJw8nI2Jb8XMH8BgwjWI+QoDUfIStcIiZWY39+/e39GpG0lhJh/e9B74OrKOD8xEeMIeTZ511VrJ+/fXXJ+vTpk1rWJswYcKQ2tQpf/7znxvWbr/99uS2N998c7L+wQcfDKlNlrcOXic2DnismM7+YODBiPiFpJfp0HyEB0yImVlrOnkDeES8AZxWZ/k7dGg+QoeYmdXwFftmljWHmJllzZMimlm2PCmimWXPIWZmWXOIldCsWbP
aqrdjw4YNyfqTTz6ZrO/duzdZT835tWvXruS2ZvU4xMwsaw4xM8tWpydFHG4OMTOr4Z6YmWXNIWZmWXOImVm2fLGrmWUvpxBTs8ZKmgjcDxwD7AeWRMS/SVoAXAX8sVj1poj4eZPPyuebMctURKid7Q899ND47Gc/29K627ZtW93C9NTDqpWe2F7guohYU8zQuFrSM0XttohYNHzNM7NuyKkn1jTEiieR9D2VZI+kjUB3pzI1s2GT25jYoObYlzQJOB1YWSyaL+kVSUslHdFgm7l9j3Nqr6lmNlJG4LmTHdNyiEk6DHgEuDYidgN3AScBPVR6anVv4IuIJRExtdvHzWbWupxCrKWzk5IOoRJgyyLiUYCI2F5VvxtI38VsZtnI6bajpj0xVR5Tcg+wMSIWVy0fX7XaLCqPYTKzzLXaC8upJ3Y2cDnwqqS1xbKbgNmSeqg8onwLcPUwtM/MuqAsAdWKVs5OvgjUu+4keU2YmeVrVIWYmR14HGJmljWHmJlly5Mimln23BMzs6w5xMwsaw4xM8tWmS5kbYVDzMxqOMTMLGs+O2lmWXNPzMyylduY2KAmRTSzA0MnZ7GQNEPSJkmbJd3Q6bY6xMysRqdCTNIY4E7gG8AUKrPfTOlkW304aWY1OjiwPw3YHBFvAEh6CJgJbOjUDkY6xHYC/1v1+1HFsjIqa9vK2i5w24aqk207vgOf8RSVNrXikwOen7EkIpZU/T4B+EPV71uBs9psXz8jGmIR0e9hdpJWlXXu/bK2raztArdtqMrWtoiY0cGPqzcXYUfPGnhMzMyG01ZgYtXvxwLbOrkDh5iZDaeXgcmSTpB0KHAZsKKTO+j2wP6S5qt0TVnbVtZ2gds2VGVuW1siYq+k+VTG2cYASyNifSf3oZwuajMzG8iHk2aWNYeYmWWtKyE23LchtEPSFkmvSlo74PqXbrRlqaQdktZVLTtS0jOSXit+HlGiti2Q9Hbx3a2V9M0utW2ipOckbZS0XtI1xfKufneJdpXie8vViI+JFbch/A9wAZXTry8DsyOiY1fwtkPSFmBqRHT9wkhJ5wLvA/dHxBeKZf8CvBsRtxT/ABwREf9QkrYtAN6PiEUj3Z4BbRsPjI+INZIOB1YDFwNX0MXvLtGuSynB95arbvTE/nobQkR8DPTdhmADRMQLwLsDFs8E7ive30flj2DENWhbKUREb0SsKd7vATZSuXK8q99dol3Whm6EWL3bEMr0f2QAT0taLWlutxtTx7iI6IXKHwVwdJfbM9B8Sa8Uh5tdOdStJmkScDqwkhJ9dwPaBSX73nLSjRAb9tsQ2nR2RJxB5a77ecVhk7XmLuAkoAfoBW7tZmMkHQY8AlwbEbu72ZZqddpVqu8tN90IsWG/DaEdEbGt+LkDeIzK4W+ZbC/GVvrGWHZ0uT1/FRHbI2JfROwH7qaL352kQ6gExbKIeLRY3PXvrl67yvS95agbITbstyEMlaSxxYArksYCXwfWpbcacSuAOcX7OcATXWxLP30BUZhFl747SQLuATZGxOKqUle/u0btKsv3lquuXLFfnEL+V/7/NoSFI96IOiSdSKX3BZVbsh7sZtsk/QyYTmValO3A94HHgeXAccBbwCURMeID7A3aNp3KIVEAW4Cr+8agRrht5wD/BbwK9E2MdROV8aeufXeJds2mBN9brnzbkZllzVfsm1nWHGJmljWHmJllzSFmZllziJlZ1hxiZpY1h5iZZe3/APpBMI71CUMRAAAAAElFTkSuQmCC\n",
  137. "text/plain": [
  138. "<Figure size 432x288 with 2 Axes>"
  139. ]
  140. },
  141. "metadata": {
  142. "needs_background": "light"
  143. },
  144. "output_type": "display_data"
  145. }
  146. ],
  147. "source": [
  148. "# Show the first raw training image (train_set.data is not passed through data_tf)\n",
  149. "import matplotlib.pyplot as plt\n",
  150. "fig, ax = plt.subplots()\n",
  151. "ax.set_title('%i' % train_set.targets[0])\n",
  152. "fig.colorbar(ax.imshow(train_set.data[0], cmap='gray'), ax=ax)\n",
  153. "plt.show()"
  154. ]
  155. },
  156. {
  157. "cell_type": "code",
  158. "execution_count": 6,
  159. "metadata": {},
  160. "outputs": [
  161. {
  162. "name": "stdout",
  163. "output_type": "stream",
  164. "text": [
  165. "torch.Size([64, 1, 32, 32])\n",
  166. "torch.Size([64])\n"
  167. ]
  168. },
  169. {
  170. "data": {
  171. "image/png": "\n",
  172. "text/plain": [
  173. "<Figure size 432x288 with 2 Axes>"
  174. ]
  175. },
  176. "metadata": {
  177. "needs_background": "light"
  178. },
  179. "output_type": "display_data"
  180. },
  181. {
  182. "name": "stdout",
  183. "output_type": "stream",
  184. "text": [
  185. "tensor([[0., 0., 0., ..., 0., 0., 0.],\n",
  186. " [0., 0., 0., ..., 0., 0., 0.],\n",
  187. " [0., 0., 0., ..., 0., 0., 0.],\n",
  188. " ...,\n",
  189. " [0., 0., 0., ..., 0., 0., 0.],\n",
  190. " [0., 0., 0., ..., 0., 0., 0.],\n",
  191. " [0., 0., 0., ..., 0., 0., 0.]])\n"
  192. ]
  193. }
  194. ],
  195. "source": [
  196. "import matplotlib.pyplot as plt\n",
  197. "\n",
  198. "# Show one image after data_tf has been applied (first sample of one shuffled batch)\n",
  199. "im, label = next(iter(train_data))  # fetch a single batch instead of loop-and-break\n",
  200. "print(im.shape)     # (batch, channels, height, width)\n",
  201. "print(label.shape)  # (batch,)\n",
  202. "\n",
  203. "img = im[0, 0]  # first image, first (only) channel\n",
  204. "lab = label[0]\n",
  205. "plt.imshow(img, cmap='gray')\n",
  206. "plt.title('%i' % lab)\n",
  207. "plt.colorbar()\n",
  208. "plt.show()\n",
  209. "\n",
  210. "# Pixel values of the transformed image\n",
  211. "print(img)"
  212. ]
  213. },
  214. {
  215. "cell_type": "code",
  216. "execution_count": 7,
  217. "metadata": {
  218. "scrolled": false
  219. },
  220. "outputs": [
  221. {
  222. "name": "stdout",
  223. "output_type": "stream",
  224. "text": [
  225. "Epoch 0. Train Loss: 0.292838, Train Acc: 0.908382, Valid Loss: 0.075638, Valid Acc: 0.974684, Time 00:00:13\n",
  226. "Epoch 1. Train Loss: 0.077091, Train Acc: 0.976479, Valid Loss: 0.066128, Valid Acc: 0.978936, Time 00:00:14\n",
  227. "Epoch 2. Train Loss: 0.055866, Train Acc: 0.982759, Valid Loss: 0.042326, Valid Acc: 0.986748, Time 00:00:14\n",
  228. "Epoch 3. Train Loss: 0.043993, Train Acc: 0.986257, Valid Loss: 0.042040, Valid Acc: 0.986847, Time 00:00:14\n",
  229. "Epoch 4. Train Loss: 0.035289, Train Acc: 0.988823, Valid Loss: 0.035118, Valid Acc: 0.988430, Time 00:00:14\n",
  230. "Epoch 5. Train Loss: 0.030174, Train Acc: 0.990572, Valid Loss: 0.036890, Valid Acc: 0.988430, Time 00:00:14\n",
  231. "Epoch 6. Train Loss: 0.025604, Train Acc: 0.991571, Valid Loss: 0.028075, Valid Acc: 0.990803, Time 00:00:14\n",
  232. "Epoch 7. Train Loss: 0.021483, Train Acc: 0.993220, Valid Loss: 0.039955, Valid Acc: 0.988133, Time 00:00:14\n",
  233. "Epoch 8. Train Loss: 0.018553, Train Acc: 0.994020, Valid Loss: 0.031569, Valid Acc: 0.990506, Time 00:00:14\n",
  234. "Epoch 9. Train Loss: 0.016860, Train Acc: 0.994420, Valid Loss: 0.028923, Valid Acc: 0.990803, Time 00:00:14\n",
  235. "Epoch 10. Train Loss: 0.014547, Train Acc: 0.995186, Valid Loss: 0.041005, Valid Acc: 0.987737, Time 00:00:14\n",
  236. "Epoch 11. Train Loss: 0.011832, Train Acc: 0.996085, Valid Loss: 0.039684, Valid Acc: 0.989221, Time 00:00:14\n",
  237. "Epoch 12. Train Loss: 0.012104, Train Acc: 0.996019, Valid Loss: 0.033983, Valid Acc: 0.990012, Time 00:00:14\n",
  238. "Epoch 13. Train Loss: 0.009578, Train Acc: 0.996802, Valid Loss: 0.044510, Valid Acc: 0.989419, Time 00:00:14\n",
  239. "Epoch 14. Train Loss: 0.008961, Train Acc: 0.997018, Valid Loss: 0.033376, Valid Acc: 0.991693, Time 00:00:14\n",
  240. "Epoch 15. Train Loss: 0.008937, Train Acc: 0.997002, Valid Loss: 0.054347, Valid Acc: 0.986847, Time 00:00:15\n",
  241. "Epoch 16. Train Loss: 0.009171, Train Acc: 0.996902, Valid Loss: 0.034495, Valid Acc: 0.991594, Time 00:00:16\n",
  242. "Epoch 17. Train Loss: 0.006915, Train Acc: 0.997818, Valid Loss: 0.046391, Valid Acc: 0.989517, Time 00:00:16\n",
  243. "Epoch 18. Train Loss: 0.007419, Train Acc: 0.997651, Valid Loss: 0.044388, Valid Acc: 0.989419, Time 00:00:16\n",
  244. "Epoch 19. Train Loss: 0.006600, Train Acc: 0.998001, Valid Loss: 0.049959, Valid Acc: 0.987935, Time 00:00:16\n"
  245. ]
  246. }
  247. ],
  248. "source": [
  249. "torch.manual_seed(0)  # seed weight init and DataLoader shuffling so reruns are repeatable (GPU kernels may still vary)\n",
  250. "net = LeNet5()\n",
  251. "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
  252. "criterion = nn.CrossEntropyLoss()\n",
  253. "NUM_EPOCHS = 20\n",
  254. "# Use CUDA only when available (assumes utils.train runs on CPU when use_cuda=False)\n",
  255. "res = train(net, train_data, test_data, NUM_EPOCHS, optimizer, criterion, use_cuda=torch.cuda.is_available())"
  256. ]
  257. },
  258. {
  259. "cell_type": "code",
  260. "execution_count": null,
  261. "metadata": {},
  262. "outputs": [],
  263. "source": [
  264. "import matplotlib.pyplot as plt\n",
  265. "%matplotlib inline\n",
  266. "\n",
  267. "def plot_curves(train_vals, valid_vals, ylabel, pdf_path):\n",
  268. "    \"\"\"Plot per-epoch train/valid curves, save the figure to pdf_path, then show it.\"\"\"\n",
  269. "    fig, ax = plt.subplots()\n",
  270. "    ax.plot(train_vals, label='train')\n",
  271. "    ax.plot(valid_vals, label='valid')\n",
  272. "    ax.set_xlabel('epoch')\n",
  273. "    ax.set_ylabel(ylabel)\n",
  274. "    ax.legend(loc='best')\n",
  275. "    fig.savefig(pdf_path)\n",
  276. "    plt.show()\n",
  277. "\n",
  278. "# res comes from train(): res[0]/res[2] are the train/valid loss histories,\n",
  279. "# res[1]/res[3] the train/valid accuracy histories (one value per epoch).\n",
  280. "plot_curves(res[0], res[2], 'Loss', 'fig-res-lenet5-train-validate-loss.pdf')\n",
  281. "plot_curves(res[1], res[3], 'Acc', 'fig-res-lenet5-train-validate-acc.pdf')"
  282. ]
  283. }
  284. ],
  285. "metadata": {
  286. "kernelspec": {
  287. "display_name": "Python 3",
  288. "language": "python",
  289. "name": "python3"
  290. },
  291. "language_info": {
  292. "codemirror_mode": {
  293. "name": "ipython",
  294. "version": 3
  295. },
  296. "file_extension": ".py",
  297. "mimetype": "text/x-python",
  298. "name": "python",
  299. "nbconvert_exporter": "python",
  300. "pygments_lexer": "ipython3",
  301. "version": "3.8.12"
  302. }
  303. },
  304. "nbformat": 4,
  305. "nbformat_minor": 2
  306. }

机器学习越来越多地应用到飞行器、机器人等领域,其目的是利用计算机实现类似人类的智能,从而实现装备的智能化与无人化。本课程旨在引导学生掌握机器学习的基本知识、典型方法与技术,通过具体的应用案例激发学生对该学科的兴趣,鼓励学生从人工智能的角度分析、解决飞行器和机器人所面临的问题与挑战。本课程的主要内容包括 Python 编程基础、机器学习模型,以及无监督学习、监督学习和深度学习的基础知识与实现,并学习如何利用机器学习解决实际问题,从而全面提升综合能力。