- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# ResNet\n",
- "\n",
- "当大家还在惊叹 GoogLeNet 的 Inception 结构的时候,微软亚洲研究院的研究员已经在设计更深但结构更加简单的网络 ResNet,并且凭借这个网络子在 2015 年 ImageNet 比赛上大获全胜。\n",
- "\n",
- "ResNet 有效地解决了深度神经网络难以训练的问题,可以训练高达 1000 层的卷积网络。网络之所以难以训练,是因为存在着梯度消失的问题,离 loss 函数越远的层,在反向传播的时候,梯度越小,就越难以更新,随着层数的增加,这个现象越严重。之前有两种常见的方案来解决这个问题:\n",
- "\n",
- "1. 按层训练,先训练比较浅的层,然后在不断增加层数,但是这种方法效果不是特别好,而且比较麻烦\n",
- "2. 使用更宽的层,或者增加输出通道,而不加深网络的层数,这种结构往往得到的效果又不好\n",
- "\n",
- "ResNet 通过引入了跨层链接解决了梯度回传消失的问题。\n",
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "这就普通的网络连接跟跨层残差连接的对比图,使用普通的连接(左图),上层的梯度必须要一层一层传回来;而是用残差连接(右图),相当于中间有了一条更短的路,梯度能够从这条更短的路传回来,避免了梯度过小的情况。\n",
- "\n",
- "假设某层的输入是 $x$,期望输出是 $H(x)$\n",
- "* 如果我们直接把输入 $x$ 传到输出作为初始结果,这就是一个更浅层的网络,更容易训练\n",
- "* 而这个网络没有学习的部分,我们可以使用更深的网络 $F(x)$ 去训练它,使得训练更加容易\n",
- "* 最后希望拟合的结果就是 $F(x) = H(x) - x$,这就是一个残差的结构\n",
- "\n"
- ]
- },
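- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To make the residual idea concrete, here is a tiny numerical sketch (an addition for illustration, not part of the original notebook): we build $H(x) = F(x) + x$ by hand, using a single 3x3 convolution as a hypothetical stand-in for the branch $F$. With all-zero weights, $F(x) = 0$ and the block reduces to the identity mapping, which is exactly the easy-to-train starting point described above."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# A minimal sketch of the residual identity H(x) = F(x) + x.\n",
- "# `f` is a hypothetical stand-in for the residual branch F,\n",
- "# used only for this illustration.\n",
- "import torch\n",
- "from torch import nn\n",
- "\n",
- "x = torch.randn(1, 8, 4, 4)\n",
- "f = nn.Conv2d(8, 8, 3, padding=1, bias=False)\n",
- "nn.init.zeros_(f.weight)  # make F(x) = 0\n",
- "\n",
- "h = f(x) + x  # the residual block's output\n",
- "print(torch.equal(h, x))  # True: with F(x) = 0 the block is the identity"
- ]
- },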
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 1. ResidualBlock\n",
- "\n",
- "残差网络的结构就是上面这种残差块的堆叠,下面让我们来实现一个 residual block"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T12:56:06.772059Z",
- "start_time": "2017-12-22T12:56:06.766027Z"
- }
- },
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import torch\n",
- "from torch import nn\n",
- "import torch.nn.functional as F\n",
- "from torch.autograd import Variable\n",
- "from torchvision.datasets import CIFAR10\n",
- "from torchvision import transforms as tfs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T12:47:49.222432Z",
- "start_time": "2017-12-22T12:47:49.217940Z"
- }
- },
- "outputs": [],
- "source": [
- "def conv3x3(in_channel, out_channel, stride=1):\n",
- " return nn.Conv2d(in_channel, out_channel, 3, \n",
- " stride=stride, padding=1, bias=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T13:14:02.429145Z",
- "start_time": "2017-12-22T13:14:02.383322Z"
- }
- },
- "outputs": [],
- "source": [
- "class Residual_Block(nn.Module):\n",
- " def __init__(self, in_channel, out_channel, same_shape=True):\n",
- " super(Residual_Block, self).__init__()\n",
- " self.same_shape = same_shape\n",
- " stride=1 if self.same_shape else 2\n",
- " \n",
- " self.conv1 = conv3x3(in_channel, out_channel, stride=stride)\n",
- " self.bn1 = nn.BatchNorm2d(out_channel)\n",
- " \n",
- " self.conv2 = conv3x3(out_channel, out_channel)\n",
- " self.bn2 = nn.BatchNorm2d(out_channel)\n",
- " if not self.same_shape:\n",
- " self.conv3 = nn.Conv2d(in_channel, out_channel, 1, \n",
- " stride=stride)\n",
- " \n",
- " def forward(self, x):\n",
- " out = self.conv1(x)\n",
- " out = F.relu(self.bn1(out), True)\n",
- " out = self.conv2(out)\n",
- " out = F.relu(self.bn2(out), True)\n",
- " \n",
- " if not self.same_shape:\n",
- " x = self.conv3(x)\n",
- " return F.relu(x+out, True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "我们测试一下一个 residual block 的输入和输出"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T13:14:05.793185Z",
- "start_time": "2017-12-22T13:14:05.763382Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input: torch.Size([1, 32, 96, 96])\n",
- "output: torch.Size([1, 32, 96, 96])\n"
- ]
- }
- ],
- "source": [
- "# 输入输出形状相同\n",
- "test_net = Residual_Block(32, 32)\n",
- "test_x = Variable(torch.zeros(1, 32, 96, 96))\n",
- "print('input: {}'.format(test_x.shape))\n",
- "test_y = test_net(test_x)\n",
- "print('output: {}'.format(test_y.shape))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T13:14:11.929120Z",
- "start_time": "2017-12-22T13:14:11.914604Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input: torch.Size([1, 3, 96, 96])\n",
- "output: torch.Size([1, 32, 48, 48])\n"
- ]
- }
- ],
- "source": [
- "# 输入输出形状不同\n",
- "test_net = Residual_Block(3, 32, False)\n",
- "test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
- "print('input: {}'.format(test_x.shape))\n",
- "test_y = test_net(test_x)\n",
- "print('output: {}'.format(test_y.shape))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "一个Residual_Block的结构如下图所示\n",
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 2. ResNet的网络实现\n",
- "\n",
- "下面实现一个 ResNet,它就是 residual block 模块的堆叠"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T13:27:46.099404Z",
- "start_time": "2017-12-22T13:27:45.986235Z"
- }
- },
- "outputs": [],
- "source": [
- "class ResNet(nn.Module):\n",
- " def __init__(self, in_channel, num_classes, verbose=False):\n",
- " super(ResNet, self).__init__()\n",
- " self.verbose = verbose\n",
- " \n",
- " self.block1 = nn.Conv2d(in_channel, 64, 7, 2)\n",
- " \n",
- " self.block2 = nn.Sequential(\n",
- " nn.MaxPool2d(3, 2),\n",
- " Residual_Block(64, 64),\n",
- " Residual_Block(64, 64)\n",
- " )\n",
- " \n",
- " self.block3 = nn.Sequential(\n",
- " Residual_Block(64, 128, False),\n",
- " Residual_Block(128, 128)\n",
- " )\n",
- " \n",
- " self.block4 = nn.Sequential(\n",
- " Residual_Block(128, 256, False),\n",
- " Residual_Block(256, 256)\n",
- " )\n",
- " \n",
- " self.block5 = nn.Sequential(\n",
- " Residual_Block(256, 512, False),\n",
- " Residual_Block(512, 512),\n",
- " nn.AvgPool2d(3)\n",
- " )\n",
- " \n",
- " self.classifier = nn.Linear(512, num_classes)\n",
- " \n",
- " def forward(self, x):\n",
- " x = self.block1(x)\n",
- " if self.verbose:\n",
- " print('block 1 output: {}'.format(x.shape))\n",
- " x = self.block2(x)\n",
- " if self.verbose:\n",
- " print('block 2 output: {}'.format(x.shape))\n",
- " x = self.block3(x)\n",
- " if self.verbose:\n",
- " print('block 3 output: {}'.format(x.shape))\n",
- " x = self.block4(x)\n",
- " if self.verbose:\n",
- " print('block 4 output: {}'.format(x.shape))\n",
- " x = self.block5(x)\n",
- " if self.verbose:\n",
- " print('block 5 output: {}'.format(x.shape))\n",
- " x = x.view(x.shape[0], -1)\n",
- " x = self.classifier(x)\n",
- " return x"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "输出一下每个 block 之后的大小"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T13:28:00.597030Z",
- "start_time": "2017-12-22T13:28:00.417746Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "block 1 output: torch.Size([1, 64, 45, 45])\n",
- "block 2 output: torch.Size([1, 64, 22, 22])\n",
- "block 3 output: torch.Size([1, 128, 11, 11])\n",
- "block 4 output: torch.Size([1, 256, 6, 6])\n",
- "block 5 output: torch.Size([1, 512, 1, 1])\n",
- "output: torch.Size([1, 10])\n"
- ]
- }
- ],
- "source": [
- "test_net = ResNet(3, 10, True)\n",
- "test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
- "test_y = test_net(test_x)\n",
- "print('output: {}'.format(test_y.shape))"
- ]
- },
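- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As an extra sanity check (this cell is an addition, not part of the original notebook), we can count the model's trainable parameters by summing `numel()` over `parameters()`; the exact number is simply whatever this configuration yields."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Count the trainable parameters of the ResNet defined above --\n",
- "# a quick sanity-check sketch, not part of the original notebook.\n",
- "num_params = sum(p.numel() for p in test_net.parameters() if p.requires_grad)\n",
- "print('trainable parameters: {}'.format(num_params))"
- ]
- },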
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T13:29:01.484172Z",
- "start_time": "2017-12-22T13:29:00.095952Z"
- }
- },
- "outputs": [],
- "source": [
- "from utils import train\n",
- "\n",
- "def data_tf(x):\n",
- " im_aug = tfs.Compose([\n",
- " tfs.Resize(96),\n",
- " tfs.ToTensor(),\n",
- " tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
- " ])\n",
- " x = im_aug(x)\n",
- " return x\n",
- " \n",
- "train_set = CIFAR10('../../data', train=True, transform=data_tf)\n",
- "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
- "test_set = CIFAR10('../../data', train=False, transform=data_tf)\n",
- "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
- "\n",
- "net = ResNet(3, 10)\n",
- "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
- "criterion = nn.CrossEntropyLoss()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T13:45:00.783186Z",
- "start_time": "2017-12-22T13:29:09.214453Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[ 0] Train:(L=1.506980, Acc=0.449868), Valid:(L=1.119623, Acc=0.598596), T: 00:00:48\n",
- "[ 1] Train:(L=1.022635, Acc=0.641504), Valid:(L=0.942414, Acc=0.669600), T: 00:00:47\n",
- "[ 2] Train:(L=0.806174, Acc=0.717551), Valid:(L=0.921687, Acc=0.682061), T: 00:00:47\n",
- "[ 3] Train:(L=0.638939, Acc=0.775555), Valid:(L=0.802450, Acc=0.729727), T: 00:00:47\n",
- "[ 4] Train:(L=0.497571, Acc=0.826606), Valid:(L=0.658700, Acc=0.775316), T: 00:00:47\n",
- "[ 5] Train:(L=0.364864, Acc=0.872442), Valid:(L=0.717290, Acc=0.768888), T: 00:00:47\n",
- "[ 6] Train:(L=0.263076, Acc=0.907888), Valid:(L=0.832575, Acc=0.750000), T: 00:00:47\n",
- "[ 7] Train:(L=0.181254, Acc=0.935782), Valid:(L=0.818366, Acc=0.764933), T: 00:00:47\n",
- "[ 8] Train:(L=0.124111, Acc=0.957820), Valid:(L=0.883527, Acc=0.778184), T: 00:00:47\n",
- "[ 9] Train:(L=0.108587, Acc=0.961657), Valid:(L=0.899127, Acc=0.780756), T: 00:00:47\n",
- "[10] Train:(L=0.091386, Acc=0.968670), Valid:(L=0.975022, Acc=0.781448), T: 00:00:47\n",
- "[11] Train:(L=0.079259, Acc=0.972287), Valid:(L=1.061239, Acc=0.770075), T: 00:00:47\n",
- "[12] Train:(L=0.067858, Acc=0.976123), Valid:(L=1.025909, Acc=0.782140), T: 00:00:47\n",
- "[13] Train:(L=0.064745, Acc=0.977701), Valid:(L=0.987410, Acc=0.789062), T: 00:00:47\n",
- "[14] Train:(L=0.056921, Acc=0.979779), Valid:(L=1.165746, Acc=0.773438), T: 00:00:47\n",
- "[15] Train:(L=0.058128, Acc=0.980039), Valid:(L=1.057119, Acc=0.782437), T: 00:00:47\n",
- "[16] Train:(L=0.050794, Acc=0.982257), Valid:(L=1.098127, Acc=0.779074), T: 00:00:47\n",
- "[17] Train:(L=0.046720, Acc=0.984415), Valid:(L=1.066124, Acc=0.787184), T: 00:00:47\n",
- "[18] Train:(L=0.044737, Acc=0.984375), Valid:(L=1.053032, Acc=0.792029), T: 00:00:47\n"
- ]
- }
- ],
- "source": [
- "res = train(net, train_data, test_data, 20, optimizer, criterion)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "%matplotlib inline\n",
- "\n",
- "plt.plot(res[0], label='train')\n",
- "plt.plot(res[2], label='valid')\n",
- "plt.xlabel('epoch')\n",
- "plt.ylabel('Loss')\n",
- "plt.legend(loc='best')\n",
- "plt.savefig('fig-res-resnet-train-validate-loss.pdf')\n",
- "plt.show()\n",
- "\n",
- "plt.plot(res[1], label='train')\n",
- "plt.plot(res[3], label='valid')\n",
- "plt.xlabel('epoch')\n",
- "plt.ylabel('Acc')\n",
- "plt.legend(loc='best')\n",
- "plt.savefig('fig-res-resnet-train-validate-acc.pdf')\n",
- "plt.show()\n",
- "\n",
- "# save raw data\n",
- "import numpy\n",
- "numpy.save('fig-res-resnet_data.npy', res)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "ResNet 使用跨层通道使得训练非常深的卷积神经网络成为可能。同样它使用很简单的卷积层配置,使得其拓展更加简单。\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 练习\n",
- "\n",
- "* 尝试一下论文中提出的 bottleneck 的结构 \n",
- "* 尝试改变 conv -> bn -> relu 的顺序为 bn -> relu -> conv,看看精度会不会提高\n",
- "* 在Residual_Block加入1x1卷积,并尝试结果的差别"
- ]
- },
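- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For the first exercise, here is a minimal starting-point sketch of a bottleneck block (1x1 reduce -> 3x3 -> 1x1 expand), written in the same conv -> bn -> relu style as `Residual_Block` above. The channel reduction factor of 4 follows the original paper; the class name `Bottleneck_Block` and the other details are assumptions of this sketch, not the paper's exact configuration."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# A hypothetical bottleneck residual block -- a sketch for the exercise,\n",
- "# not code from the original notebook.\n",
- "class Bottleneck_Block(nn.Module):\n",
- "    def __init__(self, in_channel, out_channel, same_shape=True):\n",
- "        super(Bottleneck_Block, self).__init__()\n",
- "        self.same_shape = same_shape\n",
- "        stride = 1 if self.same_shape else 2\n",
- "        mid_channel = out_channel // 4  # reduction factor from the paper\n",
- "        \n",
- "        self.conv1 = nn.Conv2d(in_channel, mid_channel, 1, bias=False)\n",
- "        self.bn1 = nn.BatchNorm2d(mid_channel)\n",
- "        self.conv2 = conv3x3(mid_channel, mid_channel, stride=stride)\n",
- "        self.bn2 = nn.BatchNorm2d(mid_channel)\n",
- "        self.conv3 = nn.Conv2d(mid_channel, out_channel, 1, bias=False)\n",
- "        self.bn3 = nn.BatchNorm2d(out_channel)\n",
- "        if not self.same_shape or in_channel != out_channel:\n",
- "            # 1x1 convolution on the shortcut path to match channels and stride\n",
- "            self.shortcut = nn.Conv2d(in_channel, out_channel, 1, stride=stride)\n",
- "        else:\n",
- "            self.shortcut = None\n",
- "        \n",
- "    def forward(self, x):\n",
- "        out = F.relu(self.bn1(self.conv1(x)), True)\n",
- "        out = F.relu(self.bn2(self.conv2(out)), True)\n",
- "        out = self.bn3(self.conv3(out))\n",
- "        if self.shortcut is not None:\n",
- "            x = self.shortcut(x)\n",
- "        return F.relu(x + out, True)\n",
- "\n",
- "# quick shape check, mirroring the tests above\n",
- "test_y = Bottleneck_Block(64, 256, False)(Variable(torch.zeros(1, 64, 96, 96)))\n",
- "print('output: {}'.format(test_y.shape))  # expect [1, 256, 48, 48]"
- ]
- },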
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 参考资料\n",
- "* [Residual Networks (ResNet)](https://d2l.ai/chapter_convolutional-modern/resnet.html)\n",
- "* [An Overview of ResNet and its Variants](https://towardsdatascience.com/an-overview-of-resnet-and-its-variants-5281e2f56035)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }