|
|
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# DenseNet\n",
- "\n",
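- "ResNet introduced the idea of cross-layer connections, which directly influenced the convolutional architectures that followed. The best known of these is DenseNet, the CVPR 2017 Best Paper. DenseNet differs from ResNet in that ResNet sums features across layers, whereas DenseNet concatenates them along the channel dimension. The two are illustrated below:\n",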
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
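- "The second figure shows ResNet and the third shows DenseNet. Because features are concatenated along the channel dimension, the outputs of early layers are carried forward into every later layer. This helps gradients propagate and lets low-level and high-level features be trained jointly, which can give better results.\n",
- "\n",
- "The main advantages of DenseNet are:\n",
- "1. It alleviates the vanishing-gradient problem\n",
- "2. It strengthens feature propagation\n",
- "3. It makes more effective use of features\n",
- "4. It reduces the number of parameters to some extent\n",
- "\n",
- "As a network gets deeper, the vanishing-gradient problem becomes more pronounced. Many papers propose remedies for it, for example ResNet, Highway Networks, Stochastic Depth, and FractalNets. Their architectures differ, but the core idea is the same: **create short paths from early layers to later layers**. DenseNet takes this idea further: to ensure maximum information flow between layers, it connects all layers directly to each other.\n",
- "\n",
- "Here is the structure of a dense block. A traditional convolutional network with L layers has L connections, whereas DenseNet has **L(L+1)/2**. Put simply, the input of each layer is the output of all preceding layers. In the figure below, x0 is the input; H1 takes x0 as input, H2 takes x0 and x1 (x1 being the output of H1), and so on.\n",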
- "\n",
- ""
- ]
- },
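- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As a quick check of the connection count quoted above (a worked sketch; the index $\\ell$ is notation introduced here and does not come from the original text): layer $H_\\ell$ receives the $\\ell$ feature maps $x_0, \\dots, x_{\\ell-1}$, so summing over all $L$ layers gives\n",
- "\n",
- "$$\\sum_{\\ell=1}^{L} \\ell = \\frac{L(L+1)}{2}.$$\n",
- "\n",
- "For example, a dense block with $L = 4$ layers has $1 + 2 + 3 + 4 = 10$ connections, whereas a plain chain of 4 layers has only 4."
- ]
- },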
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 1. Dense_Block\n",
- "DenseNet is built mainly from dense blocks. Below we implement a dense block."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T15:38:31.113030Z",
- "start_time": "2017-12-22T15:38:30.612922Z"
- },
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import torch\n",
- "from torch import nn\n",
- "from torch.autograd import Variable\n",
- "from torchvision.datasets import CIFAR10\n",
- "from torchvision import transforms as tfs"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "First, define a convolution block whose layers are ordered BN -> ReLU -> conv."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T15:38:31.121249Z",
- "start_time": "2017-12-22T15:38:31.115369Z"
- },
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "def Conv_Block(in_channel, out_channel):\n",
- "    # Pre-activation ordering: BN -> ReLU -> 3x3 conv (padding=1 keeps the spatial size)\n",
- "    layer = nn.Sequential(\n",
- "        nn.BatchNorm2d(in_channel),\n",
- "        nn.ReLU(True),\n",
- "        nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)\n",
- "    )\n",
- "    return layer"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
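- "In a dense block, the number of output channels of each convolution is called the `growth_rate`: if the input has `in_channel` channels and the block has n layers, the output has `in_channel + n * growth_rate` channels."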
- ]
- },
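- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The channel count can be traced layer by layer. As a small worked sketch (the symbols $c_i$, $c_{\\mathrm{in}}$ and $k$ are shorthand introduced here for the channel count after layer $i$, for `in_channel`, and for `growth_rate`), each layer appends $k$ channels to what it received:\n",
- "\n",
- "$$c_0 = c_{\\mathrm{in}}, \\qquad c_i = c_{i-1} + k, \\qquad c_n = c_{\\mathrm{in}} + n k.$$\n",
- "\n",
- "With $c_{\\mathrm{in}} = 3$, $k = 12$ and $n = 3$, the counts go $3 \\to 15 \\to 27 \\to 39$, so a block with these settings should output 39 channels."
- ]
- },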
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T15:38:31.145274Z",
- "start_time": "2017-12-22T15:38:31.123363Z"
- },
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "class Dense_Block(nn.Module):\n",
- "    def __init__(self, in_channel, growth_rate, num_layers):\n",
- "        super(Dense_Block, self).__init__()\n",
- "        block = []\n",
- "        channel = in_channel\n",
- "        for i in range(num_layers):\n",
- "            # Each layer sees all channels accumulated so far and adds growth_rate new ones\n",
- "            block.append(Conv_Block(channel, growth_rate))\n",
- "            channel += growth_rate\n",
- "\n",
- "        self.net = nn.Sequential(*block)\n",
- "\n",
- "    def forward(self, x):\n",
- "        for layer in self.net:\n",
- "            out = layer(x)\n",
- "            # Concatenate the new features with the running input along the channel dimension\n",
- "            x = torch.cat((out, x), dim=1)\n",
- "        return x"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Let's check that the number of output channels is correct."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T15:38:31.213632Z",
- "start_time": "2017-12-22T15:38:31.147196Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input shape: 3 x 96 x 96\n",
- "output shape: 39 x 96 x 96\n"
- ]
- }
- ],
- "source": [
- "test_net = Dense_Block(3, 12, 3)\n",
- "test_x = torch.zeros(1, 3, 96, 96)  # Variable is no longer needed; tensors track gradients directly\n",
- "print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))\n",
- "test_y = test_net(test_x)\n",
- "print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
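- "Besides the dense block, DenseNet has another module called the transition block (transition layer). Because DenseNet keeps concatenating along the channel dimension, the number of output channels keeps growing as layers are added, and the parameter count and computation grow with it. To avoid this, a transition layer is inserted to reduce the number of channels and, at the same time, halve the height and width of the feature map. The channel reduction can be done with a 1 x 1 convolution."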
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T15:38:31.222120Z",
- "start_time": "2017-12-22T15:38:31.215770Z"
- },
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "def Transition_Block(in_channel, out_channel):\n",
- "    # BN -> ReLU -> 1x1 conv (changes the channel count) -> 2x2 average pooling (halves height and width)\n",
- "    trans_layer = nn.Sequential(\n",
- "        nn.BatchNorm2d(in_channel),\n",
- "        nn.ReLU(True),\n",
- "        nn.Conv2d(in_channel, out_channel, 1),\n",
- "        nn.AvgPool2d(2, 2)\n",
- "    )\n",
- "    return trans_layer"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Check that the transition layer behaves as expected."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T15:38:31.234846Z",
- "start_time": "2017-12-22T15:38:31.224078Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "input shape: 3 x 96 x 96\n",
- "output shape: 12 x 48 x 48\n"
- ]
- }
- ],
- "source": [
- "test_net = Transition_Block(3, 12)\n",
- "test_x = torch.zeros(1, 3, 96, 96)  # Variable is no longer needed; tensors track gradients directly\n",
- "print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))\n",
- "test_y = test_net(test_x)\n",
- "print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 2. DenseNet\n",
- "\n",
- "Finally, we define DenseNet."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T15:38:31.318822Z",
- "start_time": "2017-12-22T15:38:31.236857Z"
- },
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "class DenseNet(nn.Module):\n",
- "    def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=(6, 12, 24, 16)):\n",
- "        super(DenseNet, self).__init__()\n",
- "        self.block1 = nn.Sequential(\n",
- "            nn.Conv2d(in_channels=in_channel, out_channels=64,\n",
- "                      kernel_size=7, stride=2, padding=3),\n",
- "            nn.BatchNorm2d(64),\n",
- "            nn.ReLU(True),\n",
- "            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
- "        )\n",
- "\n",
- "        channels = 64\n",
- "        block = []\n",
- "        for i, layers in enumerate(block_layers):\n",
- "            block.append(Dense_Block(channels, growth_rate, layers))\n",
- "            channels += layers * growth_rate\n",
- "            if i != len(block_layers) - 1:\n",
- "                # The transition layer halves both the spatial size and the channel count\n",
- "                block.append(Transition_Block(channels, channels // 2))\n",
- "                channels = channels // 2\n",
- "\n",
- "        self.block2 = nn.Sequential(*block)\n",
- "        self.block2.add_module('bn', nn.BatchNorm2d(channels))\n",
- "        self.block2.add_module('relu', nn.ReLU(True))\n",
- "        self.block2.add_module('avg_pool', nn.AvgPool2d(3))\n",
- "\n",
- "        self.classifier = nn.Linear(channels, num_classes)\n",
- "\n",
- "    def forward(self, x):\n",
- "        x = self.block1(x)\n",
- "        x = self.block2(x)\n",
- "\n",
- "        x = x.view(x.shape[0], -1)\n",
- "        x = self.classifier(x)\n",
- "        return x"
- ]
- },
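- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To see why the classifier ends up with the input size it has, the shapes can be traced by hand. The following is a worked sketch that assumes the default arguments (`growth_rate=32`, four dense blocks of 6, 12, 24 and 16 layers) and a 96 x 96 input; it restates what the code computes and is not an additional specification:\n",
- "\n",
- "$$\\begin{aligned}\n",
- "\\text{stem:}\\quad & 64 \\text{ channels}, & 96 \\to 48 \\to 24 \\\\\n",
- "\\text{dense block 1:}\\quad & 64 + 6 \\cdot 32 = 256, & \\text{transition} \\to 128 \\text{ channels},\\ 12 \\times 12 \\\\\n",
- "\\text{dense block 2:}\\quad & 128 + 12 \\cdot 32 = 512, & \\text{transition} \\to 256 \\text{ channels},\\ 6 \\times 6 \\\\\n",
- "\\text{dense block 3:}\\quad & 256 + 24 \\cdot 32 = 1024, & \\text{transition} \\to 512 \\text{ channels},\\ 3 \\times 3 \\\\\n",
- "\\text{dense block 4:}\\quad & 512 + 16 \\cdot 32 = 1024, & \\text{no transition},\\ 3 \\times 3\n",
- "\\end{aligned}$$\n",
- "\n",
- "The final `AvgPool2d(3)` reduces the 3 x 3 map to 1 x 1, so the flattened feature vector has 1024 entries, which matches the `channels` value passed to `nn.Linear`. With a different input resolution the last feature map would no longer be 3 x 3, and the flattened size would change accordingly."
- ]
- },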
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T15:38:31.654182Z",
- "start_time": "2017-12-22T15:38:31.320788Z"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "output: torch.Size([1, 10])\n"
- ]
- }
- ],
- "source": [
- "test_net = DenseNet(3, 10)\n",
- "test_x = torch.zeros(1, 3, 96, 96)  # Variable is no longer needed; tensors track gradients directly\n",
- "test_y = test_net(test_x)\n",
- "print('output: {}'.format(test_y.shape))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T15:38:32.894729Z",
- "start_time": "2017-12-22T15:38:31.656356Z"
- },
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "from utils import train\n",
- "\n",
- "def data_tf(x):\n",
- "    im_aug = tfs.Compose([\n",
- "        tfs.Resize(96),\n",
- "        tfs.ToTensor(),\n",
- "        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
- "    ])\n",
- "    x = im_aug(x)\n",
- "    return x\n",
- " \n",
- "train_set = CIFAR10('../../data', train=True, transform=data_tf)\n",
- "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
- "test_set = CIFAR10('../../data', train=False, transform=data_tf)\n",
- "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
- "\n",
- "net = DenseNet(3, 10)\n",
- "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n",
- "criterion = nn.CrossEntropyLoss()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "ExecuteTime": {
- "end_time": "2017-12-22T16:15:38.168095Z",
- "start_time": "2017-12-22T15:38:32.896735Z"
- },
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[ 0] Train:(L=1.398217, Acc=0.485654), Valid:(L=1.140475, Acc=0.589102), T: 00:01:27\n",
- "[ 1] Train:(L=0.977232, Acc=0.655091), Valid:(L=0.878000, Acc=0.688192), T: 00:01:27\n",
- "[ 2] Train:(L=0.765218, Acc=0.732537), Valid:(L=0.722094, Acc=0.746242), T: 00:01:27\n",
- "[ 3] Train:(L=0.622742, Acc=0.782649), Valid:(L=0.603473, Acc=0.791337), T: 00:01:27\n",
- "[ 4] Train:(L=0.516785, Acc=0.818095), Valid:(L=0.558635, Acc=0.809434), T: 00:01:27\n",
- "[ 5] Train:(L=0.429744, Acc=0.849984), Valid:(L=0.562023, Acc=0.815961), T: 00:01:27\n",
- "[ 6] Train:(L=0.353802, Acc=0.876559), Valid:(L=0.590441, Acc=0.806369), T: 00:01:27\n",
- "[ 7] Train:(L=0.287238, Acc=0.900715), Valid:(L=0.549039, Acc=0.826642), T: 00:01:27\n",
- "[ 8] Train:(L=0.234364, Acc=0.916980), Valid:(L=0.518194, Acc=0.841278), T: 00:01:27\n",
- "[ 9] Train:(L=0.182421, Acc=0.935422), Valid:(L=0.522031, Acc=0.850376), T: 00:01:27\n",
- "[10] Train:(L=0.147457, Acc=0.948170), Valid:(L=0.577835, Acc=0.840289), T: 00:01:27\n",
- "[11] Train:(L=0.111591, Acc=0.960938), Valid:(L=0.541511, Acc=0.856903), T: 00:01:27\n",
- "[12] Train:(L=0.096760, Acc=0.965933), Valid:(L=0.598660, Acc=0.848991), T: 00:01:27\n",
- "[13] Train:(L=0.085699, Acc=0.968810), Valid:(L=0.690155, Acc=0.839992), T: 00:01:27\n",
- "[14] Train:(L=0.064518, Acc=0.976303), Valid:(L=0.695370, Acc=0.847409), T: 00:01:27\n",
- "[15] Train:(L=0.067802, Acc=0.976063), Valid:(L=0.747289, Acc=0.837421), T: 00:01:27\n",
- "[16] Train:(L=0.060445, Acc=0.978760), Valid:(L=0.713599, Acc=0.842366), T: 00:01:27\n",
- "[17] Train:(L=0.055962, Acc=0.980259), Valid:(L=0.711125, Acc=0.848101), T: 00:01:27\n",
- "[18] Train:(L=0.049886, Acc=0.982816), Valid:(L=0.731038, Acc=0.842662), T: 00:01:28\n",
- "[19] Train:(L=0.037981, Acc=0.986693), Valid:(L=0.734789, Acc=0.855419), T: 00:01:38\n"
- ]
- }
- ],
- "source": [
- "res = train(net, train_data, test_data, 20, optimizer, criterion)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "import matplotlib.pyplot as plt\n",
- "%matplotlib inline\n",
- "\n",
- "plt.plot(res[0], label='train')\n",
- "plt.plot(res[2], label='valid')\n",
- "plt.xlabel('epoch')\n",
- "plt.ylabel('Loss')\n",
- "plt.legend(loc='best')\n",
- "plt.savefig('fig-res-densenet-train-validate-loss.pdf')\n",
- "plt.show()\n",
- "\n",
- "plt.plot(res[1], label='train')\n",
- "plt.plot(res[3], label='valid')\n",
- "plt.xlabel('epoch')\n",
- "plt.ylabel('Acc')\n",
- "plt.legend(loc='best')\n",
- "plt.savefig('fig-res-densenet-train-validate-acc.pdf')\n",
- "plt.show()\n",
- "\n",
- "# Save the raw training curves (numpy is already imported as np in the first code cell)\n",
- "np.save('fig-res-densenet_data.npy', res)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "DenseNet replaces the residual connection with feature concatenation, which gives the network denser connectivity."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## References\n",
- "* [DenseNet算法详解 (DenseNet algorithm explained in detail)](https://blog.csdn.net/u014380165/article/details/75142664)\n",
- "* [DenseNet详解 (DenseNet explained in detail)](https://zhuanlan.zhihu.com/p/43057737)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.4"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|