{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# VGG\n", "\n", "计算机视觉是一直深度学习的主战场,从这里将学习近几年非常流行的卷积网络结构,网络结构由浅变深,参数越来越多,网络有着更多的跨层链接。\n", "\n", "VGG卷积神经网络是牛津大学在2014年提出来的模型。当这个模型被提出时,由于它的简洁性和实用性,马上成为了当时最流行的卷积神经网络模型。它在图像分类和目标检测任务中都表现出非常好的结果。在2014年的ILSVRC比赛中,VGG 在Top-5中取得了92.3%的正确率。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## CIFAR 10\n", "\n", "首先介绍一个数据集 CIFAR10,后续以此数据集为例介绍各种卷积网络的结构。\n", "\n", "CIFAR10 这个数据集一共有 50000 张训练集,10000 张测试集,两个数据集里面的图片都是 png 彩色图片,图片大小是 32 x 32 x 3,一共是 10 分类问题,分别为飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。这个数据集是对网络性能测试一个非常重要的指标,可以说如果一个网络在这个数据集上超过另外一个网络,那么这个网络性能上一定要比另外一个网络好,目前这个数据集最好的结果是 95% 左右的测试集准确率。\n", "\n", "![](images/CIFAR10.png)\n", "\n", "\n", "CIFAR10 已经被 PyTorch 内置了,使用非常方便,只需要调用 `torchvision.datasets.CIFAR10` 就可以了" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## VGGNet\n", "VGGNet 是第一个真正意义上的深层网络结构,其是 ImageNet2014年的冠军,得益于 Python 的函数和循环,我们能够非常方便地构建重复结构的深层网络。\n", "\n", "VGG 的网络结构非常简单,就是不断地堆叠卷积层和池化层,下面是网络结构图:\n", "\n", "![](images/VGG_network.png)\n", "\n", "VGG整个结构只有3×3的卷积层,连续的卷积层后使用池化层隔开。虽然层数很多,但是很简洁。几乎全部使用 3 x 3 的卷积核以及 2 x 2 的池化层,使用小的卷积核进行多层的堆叠和一个大的卷积核的感受野是相同的,同时小的卷积核还能减少参数,同时可以有更深的结构。\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "VGG网络的特点:\n", "* 小卷积核和连续的卷积层: VGG中使用的都是3×3卷积核,并且使用了连续多个卷积层。这样做的好处主要有,\n", " - 使用连续的的多个小卷积核(3×3),来代替一个大的卷积核(例如(5×5)。使用小的卷积核的问题是,其感受野必然变小。所以,VGG中就使用连续的3×3卷积核,来增大感受野。VGG认为2个连续的3×3卷积核能够替代一个5×5卷积核,三个连续的3×3能够代替一个7×7。\n", " - 小卷积核的参数较少。3个3×3的卷积核参数为3×3×=27,而一个7×7的卷积核参数为7×7=49\n", " - 由于每个卷积层都有一个非线性的激活函数,多个卷积层增加了非线性映射。\n", "* 小池化核,使用的是2×2\n", "* 通道数更多,特征度更宽: 每个通道代表着一个FeatureMap,更多的通道数表示更丰富的图像特征。VGG网络第一层的通道数为64,后面每层都进行了翻倍,最多到512个通道,通道数的增加,使得更多的信息可以被提取出来。\n", "* 层数更深: 使用连续的小卷积核代替大的卷积核,网络的深度更深,并且对边缘进行填充,卷积的过程并不会降低图像尺寸。仅使用小的池化单元,降低图像的尺寸。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "VGG 的一个关键就是使用很多层 3 x 3 的卷积然后再使用一个最大池化层,这个模块被使用了很多次,下面照着这个结构把网络用PyTorch实现出来:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T09:01:51.296457Z", "start_time": "2017-12-22T09:01:50.883050Z" }, "collapsed": true }, "outputs": [], "source": [ "import sys\n", "sys.path.append('..')\n", "\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "from torch.autograd import Variable\n", "from torchvision.datasets import CIFAR10" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "为了代码的简洁和复用,可以定义一个 VGG 的 block,传入三个参数:\n", "* 第一个是模型层数\n", "* 第二个是输入的通道数\n", "* 第三个是输出的通道数\n", "\n", "第一层卷积接受的输入通道就是图片输入的通道数,然后输出最后的输出通道数,后面的卷积接受的通道数就是最后的输出通道数" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T09:01:51.312500Z", "start_time": "2017-12-22T09:01:51.298777Z" }, "collapsed": true }, "outputs": [], "source": [ "def vgg_block(num_convs, in_channels, out_channels):\n", " net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)] # 定义第一层\n", "\n", " for i in range(num_convs-1): # 定义后面的很多层\n", " net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))\n", " net.append(nn.ReLU(True))\n", " \n", " net.append(nn.MaxPool2d(2, 2)) # 定义池化层\n", " return nn.Sequential(*net)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "将模型打印出来,可以看到网络的具体结构" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T08:20:40.819497Z", "start_time": "2017-12-22T08:20:40.808853Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (5): ReLU(inplace=True)\n", " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", ")\n" ] } ], "source": [ "block_demo = vgg_block(3, 64, 128)\n", "print(block_demo)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T07:52:04.632406Z", "start_time": "2017-12-22T07:52:02.381987Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 128, 150, 150])\n" ] } ], "source": [ "# 首先定义输入为 (1, 64, 300, 300) (batch, channels, imgH, imgW)\n", "input_demo = Variable(torch.zeros(1, 64, 300, 300))\n", "output_demo = block_demo(input_demo)\n", "print(output_demo.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到输出就变为了 `(1, 128, 150, 150)` (batch, channels, imgH, imgW) ,可以看到经过了这一个 VGG block,输入大小被减半,通道数变成了 128\n", "\n", "下面我们定义一个函数对这个 VGG block 进行堆叠" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T09:01:54.497712Z", "start_time": "2017-12-22T09:01:54.489255Z" }, "collapsed": true }, "outputs": [], "source": [ "def vgg_stack(num_convs, channels):\n", " net = []\n", " for n, c in zip(num_convs, channels):\n", " in_c = c[0]\n", " out_c = c[1]\n", " net.append(vgg_block(n, in_c, out_c))\n", " return nn.Sequential(*net)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "作为实例,我们定义一个稍微简单一点的 VGG 结构,其中有 8 个卷积层" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T09:01:55.149378Z", "start_time": "2017-12-22T09:01:55.041923Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sequential(\n", " (0): Sequential(\n", " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (1): Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (2): Sequential(\n", " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (5): ReLU(inplace=True)\n", " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (3): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (5): ReLU(inplace=True)\n", " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", " (4): Sequential(\n", " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " (4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (5): ReLU(inplace=True)\n", " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " )\n", ")\n" ] } ], "source": [ "vgg_net = vgg_stack((2, 2, 3, 3, 3), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))\n", "print(vgg_net)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到网络结构中有个 5 个 最大池化,说明图片的大小会减少 5 倍。可以验证一下,输入一张 256 x 256 的图片看看结果是什么" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T08:52:44.049650Z", "start_time": "2017-12-22T08:52:43.431478Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 512, 8, 8])\n" ] } ], "source": [ "test_x = Variable(torch.zeros(1, 3, 256, 256))\n", "test_y = vgg_net(test_x)\n", "print(test_y.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到图片减小了 $2^5$ 倍,最后再加上几层全连接,就能够得到我们想要的分类输出" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T09:01:57.323034Z", "start_time": "2017-12-22T09:01:57.306864Z" }, "collapsed": true }, "outputs": [], "source": [ "class vgg(nn.Module):\n", " def __init__(self):\n", " super(vgg, self).__init__()\n", " self.feature = vgg_net\n", " self.fc = nn.Sequential(\n", " nn.Linear(512, 100),\n", " nn.ReLU(True),\n", " nn.Linear(100, 10)\n", " )\n", " def forward(self, x):\n", " x = self.feature(x)\n", " x = x.view(x.shape[0], -1)\n", " x = self.fc(x)\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "然后我们可以训练我们的模型看看在 CIFAR10 上的效果" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T09:01:59.921373Z", "start_time": "2017-12-22T09:01:58.709531Z" }, "collapsed": true }, "outputs": [], "source": [ "from utils import train\n", "\n", "def data_tf(x):\n", " x = np.array(x, dtype='float32') / 255\n", " x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到\n", " x = x.transpose((2, 0, 1)) # 将 channel 放到第一维,只是 pytorch 要求的输入方式\n", " x = torch.from_numpy(x)\n", " return x\n", " \n", "train_set = CIFAR10('../../data', train=True, transform=data_tf)\n", "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", "test_set = CIFAR10('../../data', train=False, transform=data_tf)\n", "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", "\n", "net = vgg()\n", "optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)\n", "criterion = nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2017-12-22T09:12:46.868967Z", "start_time": "2017-12-22T09:01:59.924086Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0. Train Loss: 2.303118, Train Acc: 0.098186, Valid Loss: 2.302944, Valid Acc: 0.099585, Time 00:00:32\n", "Epoch 1. Train Loss: 2.303085, Train Acc: 0.096907, Valid Loss: 2.302762, Valid Acc: 0.100969, Time 00:00:33\n", "Epoch 2. Train Loss: 2.302916, Train Acc: 0.097287, Valid Loss: 2.302740, Valid Acc: 0.099585, Time 00:00:33\n", "Epoch 3. Train Loss: 2.302395, Train Acc: 0.102042, Valid Loss: 2.297652, Valid Acc: 0.108782, Time 00:00:32\n", "Epoch 4. Train Loss: 2.079523, Train Acc: 0.202026, Valid Loss: 1.868179, Valid Acc: 0.255736, Time 00:00:31\n", "Epoch 5. Train Loss: 1.781262, Train Acc: 0.307625, Valid Loss: 1.735122, Valid Acc: 0.323279, Time 00:00:31\n", "Epoch 6. Train Loss: 1.565095, Train Acc: 0.400975, Valid Loss: 1.463914, Valid Acc: 0.449565, Time 00:00:31\n", "Epoch 7. Train Loss: 1.360450, Train Acc: 0.495225, Valid Loss: 1.374488, Valid Acc: 0.490803, Time 00:00:31\n", "Epoch 8. Train Loss: 1.144470, Train Acc: 0.585758, Valid Loss: 1.384803, Valid Acc: 0.524624, Time 00:00:31\n", "Epoch 9. Train Loss: 0.954556, Train Acc: 0.659287, Valid Loss: 1.113850, Valid Acc: 0.609968, Time 00:00:32\n", "Epoch 10. Train Loss: 0.801952, Train Acc: 0.718131, Valid Loss: 1.080254, Valid Acc: 0.639933, Time 00:00:31\n", "Epoch 11. Train Loss: 0.665018, Train Acc: 0.765945, Valid Loss: 0.916277, Valid Acc: 0.698972, Time 00:00:31\n", "Epoch 12. Train Loss: 0.547411, Train Acc: 0.811241, Valid Loss: 1.030948, Valid Acc: 0.678896, Time 00:00:32\n", "Epoch 13. Train Loss: 0.442779, Train Acc: 0.846228, Valid Loss: 0.869791, Valid Acc: 0.732496, Time 00:00:32\n", "Epoch 14. Train Loss: 0.357279, Train Acc: 0.875440, Valid Loss: 1.233777, Valid Acc: 0.671677, Time 00:00:31\n", "Epoch 15. Train Loss: 0.285171, Train Acc: 0.900096, Valid Loss: 0.852879, Valid Acc: 0.765131, Time 00:00:32\n", "Epoch 16. Train Loss: 0.222431, Train Acc: 0.923374, Valid Loss: 1.848096, Valid Acc: 0.614023, Time 00:00:31\n", "Epoch 17. Train Loss: 0.174834, Train Acc: 0.939478, Valid Loss: 1.137286, Valid Acc: 0.728639, Time 00:00:31\n", "Epoch 18. Train Loss: 0.144375, Train Acc: 0.950587, Valid Loss: 0.907310, Valid Acc: 0.776800, Time 00:00:31\n", "Epoch 19. Train Loss: 0.115332, Train Acc: 0.960878, Valid Loss: 1.009886, Valid Acc: 0.761175, Time 00:00:31\n" ] } ], "source": [ "train(net, train_data, test_data, 20, optimizer, criterion)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到,跑完 20 次,VGG 能在 CIFAR10 上取得 76% 左右的测试准确率" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.4" } }, "nbformat": 4, "nbformat_minor": 2 }