{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# LeNet5\n", "\n", "LeNet 诞生于 1994 年,是最早的卷积神经网络之一,并且推动了深度学习领域的发展。自从 1988 年开始,在多次迭代后这个开拓性成果被命名为 LeNet5。LeNet5 的架构的提出是基于如下的观点:图像的特征分布在整张图像上,通过带有可学习参数的卷积,从而有效的减少了参数数量,能够在多个位置上提取相似特征。\n", "\n", "在LeNet5提出的时候,没有 GPU 帮助训练,甚至 CPU 的速度也很慢,因此,LeNet5的规模并不大。其包含七个处理层,每一层都包含可训练参数(权重),当时使用的输入数据是 $32 \\times 32$ 像素的图像。LeNet-5 这个网络虽然很小,但是它包含了深度学习的基本模块:卷积层,池化层,全连接层。它是其他深度学习模型的基础,这里对LeNet5进行深入分析和讲解,通过实例分析,加深对与卷积层和池化层的理解。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "定义网络为:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "\n", "class LeNet5(nn.Module):\n", " def __init__(self):\n", " super(LeNet5, self).__init__()\n", " # 1-input channel, 6-output channels, 5x5-conv\n", " self.conv1 = nn.Conv2d(1, 6, 5)\n", " # 6-input channel, 16-output channels, 5x5-conv\n", " self.conv2 = nn.Conv2d(6, 16, 5)\n", " # 16x5x5-input, 120-output\n", " self.fc1 = nn.Linear(16 * 5 * 5, 120) \n", " # 120-input, 84-output\n", " self.fc2 = nn.Linear(120, 84)\n", " # 84-input, 10-output\n", " self.fc3 = nn.Linear(84, 10)\n", "\n", " def forward(self, x):\n", " x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n", " x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))\n", " x = torch.flatten(x, 1) # 将结果拉升成1维向量,除了批次的维度\n", " x = F.relu(self.fc1(x))\n", " x = F.relu(self.fc2(x))\n", " x = self.fc3(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "net = LeNet5()\n", "print(net)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "input = torch.randn(1, 1, 32, 32)\n", "out = net(input)\n", "print(out)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from torchvision.datasets import mnist\n", "from torch.utils.data import DataLoader\n", "from torchvision.datasets import mnist \n", "from torchvision import transforms as tfs\n", "from utils import train\n", "\n", "# 使用数据增强\n", "def data_tf(x):\n", " im_aug = tfs.Compose([\n", " tfs.Resize(32),\n", " tfs.ToTensor() #,\n", " #tfs.Normalize([0.5], [0.5])\n", " ])\n", " x = im_aug(x)\n", " return x\n", " \n", "train_set = mnist.MNIST('../../data/mnist', train=True, transform=data_tf, download=True) \n", "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", "test_set = mnist.MNIST('../../data/mnist', train=False, transform=data_tf, download=True) \n", "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 显示其中一个数据\n", "import matplotlib.pyplot as plt\n", "plt.imshow(train_set.data[0], cmap='gray')\n", "plt.title('%i' % train_set.targets[0])\n", "plt.colorbar()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# 显示转化后的图像\n", "for im, label in train_data:\n", " print(im.shape)\n", " print(label.shape)\n", " \n", " img = im[0,0,:,:]\n", " lab = label[0]\n", " plt.imshow(img, cmap='gray')\n", " plt.title('%i' % lab)\n", " plt.colorbar()\n", " plt.show()\n", "\n", " print(im[0,0,:,:])\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "net = LeNet5()\n", "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n", "criterion = nn.CrossEntropyLoss()\n", "\n", "res = train(net, train_data, test_data, 20, \n", " optimizer, criterion,\n", " use_cuda=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.plot(res[0], label='train')\n", "plt.plot(res[2], label='valid')\n", "plt.xlabel('epoch')\n", "plt.ylabel('Loss')\n", "plt.legend(loc='best')\n", "plt.savefig('fig-res-lenet5-train-validate-loss.pdf')\n", "plt.show()\n", "\n", "plt.plot(res[1], label='train')\n", "plt.plot(res[3], label='valid')\n", "plt.xlabel('epoch')\n", "plt.ylabel('Acc')\n", "plt.legend(loc='best')\n", "plt.savefig('fig-res-lenet5-train-validate-acc.pdf')\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 2 }