@@ -0,0 +1,157 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# LeNet5\n", | |||||
"\n", | |||||
"LeNet 诞生于 1994 年,是最早的卷积神经网络之一,并且推动了深度学习领域的发展。自从 1988 年开始,在多次迭代后这个开拓性成果被命名为 LeNet5。LeNet5 的架构的提出是基于如下的观点:图像的特征分布在整张图像上,通过带有可学习参数的卷积,从而有效的减少了参数数量,能够在多个位置上提取相似特征。\n", | |||||
"\n", | |||||
"在LeNet5提出的时候,没有 GPU 帮助训练,甚至 CPU 的速度也很慢,因此,LeNet5的规模并不大。其包含七个处理层,每一层都包含可训练参数(权重),当时使用的输入数据是 $32 \\times 32$ 像素的图像。LeNet-5 这个网络虽然很小,但是它包含了深度学习的基本模块:卷积层,池化层,全连接层。它是其他深度学习模型的基础,这里对LeNet5进行深入分析和讲解,通过实例分析,加深对与卷积层和池化层的理解。" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import sys\n", | |||||
"sys.path.append('..')\n", | |||||
"\n", | |||||
"import numpy as np\n", | |||||
"import torch\n", | |||||
"from torch import nn\n", | |||||
"from torch.autograd import Variable\n", | |||||
"from torchvision.datasets import CIFAR10\n", | |||||
"from torchvision import transforms as tfs" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": 1, | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import torch\n", | |||||
"from torch import nn\n", | |||||
"\n", | |||||
"lenet5 = nn.Sequential(\n", | |||||
    " nn.Conv2d(3, 6, kernel_size=5), nn.Sigmoid(), # CIFAR10 has 3 input channels; no padding: 32x32 -> 28x28\n", | |||||
" nn.AvgPool2d(kernel_size=2, stride=2),\n", | |||||
" nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),\n", | |||||
" nn.AvgPool2d(kernel_size=2, stride=2),\n", | |||||
" nn.Flatten(),\n", | |||||
" nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),\n", | |||||
" nn.Linear(120, 84), nn.Sigmoid(),\n", | |||||
    " nn.Linear(84, 10))" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"from utils import train\n", | |||||
"\n", | |||||
    "# data preprocessing: resize to the network input size and normalize (no random augmentation)\n", | |||||
"def train_tf(x):\n", | |||||
" im_aug = tfs.Compose([\n", | |||||
    " tfs.Resize(32), # CIFAR10 is already 32x32, the input size LeNet5 expects\n", | |||||
" tfs.ToTensor(),\n", | |||||
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n", | |||||
" ])\n", | |||||
" x = im_aug(x)\n", | |||||
" return x\n", | |||||
"\n", | |||||
"def test_tf(x):\n", | |||||
" im_aug = tfs.Compose([\n", | |||||
    " tfs.Resize(32), # CIFAR10 is already 32x32, the input size LeNet5 expects\n", | |||||
" tfs.ToTensor(),\n", | |||||
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n", | |||||
" ])\n", | |||||
" x = im_aug(x)\n", | |||||
" return x\n", | |||||
" \n", | |||||
"train_set = CIFAR10('../../data', train=True, transform=train_tf)\n", | |||||
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", | |||||
"test_set = CIFAR10('../../data', train=False, transform=test_tf)\n", | |||||
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", | |||||
"\n", | |||||
"net = lenet5\n", | |||||
"optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)\n", | |||||
"criterion = nn.CrossEntropyLoss()" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"(l_train_loss, l_train_acc, l_valid_loss, l_valid_acc) = train(net, \n", | |||||
" train_data, test_data, \n", | |||||
" 20, \n", | |||||
" optimizer, criterion,\n", | |||||
" use_cuda=False)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import matplotlib.pyplot as plt\n", | |||||
"%matplotlib inline\n", | |||||
"\n", | |||||
"plt.plot(l_train_loss, label='train')\n", | |||||
"plt.plot(l_valid_loss, label='valid')\n", | |||||
"plt.xlabel('epoch')\n", | |||||
"plt.legend(loc='best')\n", | |||||
"plt.savefig('fig-res-lenet5-train-validate-loss.pdf')\n", | |||||
"plt.show()\n", | |||||
"\n", | |||||
"plt.plot(l_train_acc, label='train')\n", | |||||
"plt.plot(l_valid_acc, label='valid')\n", | |||||
"plt.xlabel('epoch')\n", | |||||
"plt.legend(loc='best')\n", | |||||
"plt.savefig('fig-res-lenet5-train-validate-acc.pdf')\n", | |||||
"plt.show()" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.5.4" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
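A quick sanity check of the LeNet5 cell above: the `16 * 5 * 5` flatten size in the first `nn.Linear` follows from the conv/pool output-size arithmetic. A minimal pure-Python sketch, assuming a 32×32 input and a first convolution without padding (the classic LeNet5 configuration; with the MNIST convention of `padding=2` on a 28×28 input the same 5×5 result falls out):

```python
def conv_out(size, kernel, stride=1, padding=0):
    # standard output-size formula: floor((n + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

size = 32                    # LeNet5 expects 32x32 inputs
size = conv_out(size, 5)     # Conv2d(kernel_size=5)  -> 28
size = conv_out(size, 2, 2)  # AvgPool2d(2, stride=2) -> 14
size = conv_out(size, 5)     # Conv2d(kernel_size=5)  -> 10
size = conv_out(size, 2, 2)  # AvgPool2d(2, stride=2) -> 5
flatten = 16 * size * size   # 16 channels -> 16 * 5 * 5 = 400
print(size, flatten)         # 5 400
```

The same helper is handy whenever a `Linear` layer after `Flatten` raises a shape-mismatch error.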
@@ -0,0 +1,99 @@ | |||||
{ | |||||
"cells": [ | |||||
{ | |||||
"cell_type": "markdown", | |||||
"metadata": {}, | |||||
"source": [ | |||||
"# AlexNet\n", | |||||
"\n", | |||||
"\n", | |||||
    "The first classic convolutional neural network was LeNet5, but the network that opened the deep learning era is AlexNet, which won the ImageNet competition in 2012. It popularized techniques that are now standard in deep learning: the ReLU activation and Dropout. AlexNet resembles LeNet in overall layout, convolutions first and fully connected layers after, but it is far more complex: it has 5 convolutional layers, 3 pooling layers, and 3 fully connected layers, with about $6 \\times 10^7$ parameters and 650,000 neurons, and its output layer is a 1000-way Softmax. AlexNet uses more convolutional layers and a much larger parameter space than LeNet to fit the large-scale ImageNet dataset; it marks the dividing line between shallow and deep neural networks.\n" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import torch.nn as nn\n", | |||||
"import torch\n", | |||||
"\n", | |||||
"class AlexNet(nn.Module):\n", | |||||
" def __init__(self, num_classes=1000, init_weights=False): \n", | |||||
" super(AlexNet, self).__init__()\n", | |||||
" self.features = nn.Sequential( \n", | |||||
" nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2), \n", | |||||
    " nn.ReLU(inplace=True), # inplace=True modifies the tensor in place, saving memory\n", | |||||
" nn.MaxPool2d(kernel_size=3, stride=2), \n", | |||||
"\n", | |||||
" nn.Conv2d(96, 256, kernel_size=5, padding=2),\n", | |||||
" nn.ReLU(inplace=True),\n", | |||||
" nn.MaxPool2d(kernel_size=3, stride=2),\n", | |||||
"\n", | |||||
" nn.Conv2d(256, 384, kernel_size=3, padding=1),\n", | |||||
" nn.ReLU(inplace=True),\n", | |||||
"\n", | |||||
" nn.Conv2d(384, 384, kernel_size=3, padding=1),\n", | |||||
" nn.ReLU(inplace=True),\n", | |||||
"\n", | |||||
" nn.Conv2d(384, 256, kernel_size=3, padding=1),\n", | |||||
" nn.ReLU(inplace=True),\n", | |||||
" nn.MaxPool2d(kernel_size=3, stride=2),\n", | |||||
" )\n", | |||||
" self.classifier = nn.Sequential(\n", | |||||
" nn.Dropout(p=0.5),\n", | |||||
    " nn.Linear(256*6*6, 4096), # fully connected layer\n", | |||||
" nn.ReLU(inplace=True),\n", | |||||
" nn.Dropout(p=0.5),\n", | |||||
" nn.Linear(4096, 4096),\n", | |||||
" nn.ReLU(inplace=True),\n", | |||||
" nn.Linear(4096, num_classes),\n", | |||||
" )\n", | |||||
" if init_weights:\n", | |||||
" self._initialize_weights()\n", | |||||
"\n", | |||||
" def forward(self, x):\n", | |||||
" x = self.features(x)\n", | |||||
    " x = torch.flatten(x, start_dim=1) # flatten; equivalently x.view(x.size(0), -1)\n", | |||||
" x = self.classifier(x)\n", | |||||
" return x\n", | |||||
"\n", | |||||
" def _initialize_weights(self):\n", | |||||
" for m in self.modules():\n", | |||||
" if isinstance(m, nn.Conv2d):\n", | |||||
    " # Kaiming He initialization\n", | |||||
" nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') \n", | |||||
" if m.bias is not None:\n", | |||||
" nn.init.constant_(m.bias, 0)\n", | |||||
" elif isinstance(m, nn.Linear):\n", | |||||
    " # initialize weights from a normal distribution N(0, 0.01)\n", | |||||
" nn.init.normal_(m.weight, 0, 0.01) \n", | |||||
" nn.init.constant_(m.bias, 0)" | |||||
] | |||||
} | |||||
], | |||||
"metadata": { | |||||
"kernelspec": { | |||||
"display_name": "Python 3", | |||||
"language": "python", | |||||
"name": "python3" | |||||
}, | |||||
"language_info": { | |||||
"codemirror_mode": { | |||||
"name": "ipython", | |||||
"version": 3 | |||||
}, | |||||
"file_extension": ".py", | |||||
"mimetype": "text/x-python", | |||||
"name": "python", | |||||
"nbconvert_exporter": "python", | |||||
"pygments_lexer": "ipython3", | |||||
"version": "3.5.4" | |||||
} | |||||
}, | |||||
"nbformat": 4, | |||||
"nbformat_minor": 2 | |||||
} |
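The parameter count quoted above (about $6 \times 10^7$) can be checked directly against the layer shapes in the `AlexNet` class: a conv layer has in×out×k×k weights plus one bias per output channel, a linear layer has in×out weights plus out biases. A small sketch:

```python
def conv_params(in_ch, out_ch, k):
    return in_ch * out_ch * k * k + out_ch   # weights + biases

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

total = (conv_params(3, 96, 11) + conv_params(96, 256, 5)
         + conv_params(256, 384, 3) + conv_params(384, 384, 3)
         + conv_params(384, 256, 3)
         + linear_params(256 * 6 * 6, 4096)  # features flatten to 256x6x6
         + linear_params(4096, 4096)
         + linear_params(4096, 1000))
print(total)  # 62378344, i.e. roughly 6.2e7 parameters
```

Note how the three fully connected layers account for the vast majority of the parameters, which is why Dropout is applied only in the classifier.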
@@ -48,7 +48,7 @@ | |||||
    "Key features of the VGG network:\n", | "Key features of the VGG network:\n", | ||||
    "* Small kernels in consecutive convolutional layers: VGG uses only 3×3 kernels, stacked several layers deep. The main benefits are:\n", | "* Small kernels in consecutive convolutional layers: VGG uses only 3×3 kernels, stacked several layers deep. The main benefits are:\n", | ||||
    "    - Several consecutive small (3×3) kernels replace one large kernel (e.g. 5×5). A small kernel alone has a smaller receptive field, so VGG stacks 3×3 convolutions to enlarge it: two stacked 3×3 convolutions cover the receptive field of one 5×5, and three cover that of one 7×7.\n", | "    - Several consecutive small (3×3) kernels replace one large kernel (e.g. 5×5). A small kernel alone has a smaller receptive field, so VGG stacks 3×3 convolutions to enlarge it: two stacked 3×3 convolutions cover the receptive field of one 5×5, and three cover that of one 7×7.\n", | ||||
    "    - 小卷积核的参数较少。3个3×3的卷积核参数为3×3×=27,而一个7×7的卷积核参数为7×7=49\n", | |||||
    "    - Small kernels also use fewer parameters: three 3×3 kernels have 3×3×3=27 weights, while one 7×7 kernel has 7×7=49.\n", | |||||
    "    - Each convolutional layer is followed by a nonlinear activation, so stacking layers adds more nonlinear mappings.\n", | "    - Each convolutional layer is followed by a nonlinear activation, so stacking layers adds more nonlinear mappings.\n", | ||||
    "* Small pooling kernels: 2×2 throughout.\n", | "* Small pooling kernels: 2×2 throughout.\n", | ||||
    "* More channels, wider features: each channel is one feature map, so more channels capture richer image features. VGG starts with 64 channels and doubles them stage by stage up to 512, allowing more information to be extracted.\n", | "* More channels, wider features: each channel is one feature map, so more channels capture richer image features. VGG starts with 64 channels and doubles them stage by stage up to 512, allowing more information to be extracted.\n", | ||||
@@ -64,7 +64,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 2, | |||||
"execution_count": 1, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T09:01:51.296457Z", | "end_time": "2017-12-22T09:01:51.296457Z", | ||||
@@ -81,7 +81,8 @@ | |||||
"import torch\n", | "import torch\n", | ||||
"from torch import nn\n", | "from torch import nn\n", | ||||
"from torch.autograd import Variable\n", | "from torch.autograd import Variable\n", | ||||
"from torchvision.datasets import CIFAR10" | |||||
"from torchvision.datasets import CIFAR10\n", | |||||
"from torchvision import transforms as tfs" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -98,7 +99,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 3, | |||||
"execution_count": 2, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T09:01:51.312500Z", | "end_time": "2017-12-22T09:01:51.312500Z", | ||||
@@ -108,7 +109,7 @@ | |||||
}, | }, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"def vgg_block(num_convs, in_channels, out_channels):\n", | |||||
"def VGG_Block(num_convs, in_channels, out_channels):\n", | |||||
    " net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)] # first conv layer\n", | " net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)] # first conv layer\n", | ||||
"\n", | "\n", | ||||
    " for i in range(num_convs-1): # the remaining conv layers\n", | " for i in range(num_convs-1): # the remaining conv layers\n", | ||||
@@ -128,7 +129,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 4, | |||||
"execution_count": 3, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T08:20:40.819497Z", | "end_time": "2017-12-22T08:20:40.819497Z", | ||||
@@ -153,13 +154,13 @@ | |||||
} | } | ||||
], | ], | ||||
"source": [ | "source": [ | ||||
"block_demo = vgg_block(3, 64, 128)\n", | |||||
"block_demo = VGG_Block(3, 64, 128)\n", | |||||
"print(block_demo)" | "print(block_demo)" | ||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 5, | |||||
"execution_count": 4, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T07:52:04.632406Z", | "end_time": "2017-12-22T07:52:04.632406Z", | ||||
@@ -193,7 +194,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 6, | |||||
"execution_count": 5, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T09:01:54.497712Z", | "end_time": "2017-12-22T09:01:54.497712Z", | ||||
@@ -203,12 +204,12 @@ | |||||
}, | }, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"def vgg_stack(num_convs, channels):\n", | |||||
"def VGG_Stack(num_convs, channels):\n", | |||||
" net = []\n", | " net = []\n", | ||||
" for n, c in zip(num_convs, channels):\n", | " for n, c in zip(num_convs, channels):\n", | ||||
" in_c = c[0]\n", | " in_c = c[0]\n", | ||||
" out_c = c[1]\n", | " out_c = c[1]\n", | ||||
" net.append(vgg_block(n, in_c, out_c))\n", | |||||
" net.append(VGG_Block(n, in_c, out_c))\n", | |||||
" return nn.Sequential(*net)" | " return nn.Sequential(*net)" | ||||
] | ] | ||||
}, | }, | ||||
@@ -221,7 +222,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 7, | |||||
"execution_count": 6, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T09:01:55.149378Z", | "end_time": "2017-12-22T09:01:55.149378Z", | ||||
@@ -280,7 +281,7 @@ | |||||
} | } | ||||
], | ], | ||||
"source": [ | "source": [ | ||||
"vgg_net = vgg_stack((2, 2, 3, 3, 3), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))\n", | |||||
"vgg_net = VGG_Stack((2, 2, 3, 3, 3), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))\n", | |||||
"print(vgg_net)" | "print(vgg_net)" | ||||
] | ] | ||||
}, | }, | ||||
@@ -288,12 +289,12 @@ | |||||
"cell_type": "markdown", | "cell_type": "markdown", | ||||
"metadata": {}, | "metadata": {}, | ||||
"source": [ | "source": [ | ||||
"可以看到网络结构中有个 5 个 最大池化,说明图片的大小会减少 5 倍。可以验证一下,输入一张 256 x 256 的图片看看结果是什么" | |||||
    "The structure contains 5 max-pooling layers, each halving the spatial size, so the input is reduced by a factor of $2^5 = 32$. We can verify this by feeding a 224 x 224 image through the network and checking the output shape." | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 8, | |||||
"execution_count": 7, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T08:52:44.049650Z", | "end_time": "2017-12-22T08:52:44.049650Z", | ||||
@@ -305,12 +306,12 @@ | |||||
"name": "stdout", | "name": "stdout", | ||||
"output_type": "stream", | "output_type": "stream", | ||||
"text": [ | "text": [ | ||||
"torch.Size([1, 512, 8, 8])\n" | |||||
"torch.Size([1, 512, 7, 7])\n" | |||||
] | ] | ||||
} | } | ||||
], | ], | ||||
"source": [ | "source": [ | ||||
"test_x = Variable(torch.zeros(1, 3, 256, 256))\n", | |||||
"test_x = Variable(torch.zeros(1, 3, 224, 224))\n", | |||||
"test_y = vgg_net(test_x)\n", | "test_y = vgg_net(test_x)\n", | ||||
"print(test_y.shape)" | "print(test_y.shape)" | ||||
] | ] | ||||
@@ -324,7 +325,7 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 9, | |||||
"execution_count": 8, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T09:01:57.323034Z", | "end_time": "2017-12-22T09:01:57.323034Z", | ||||
@@ -334,14 +335,14 @@ | |||||
}, | }, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"class vgg(nn.Module):\n", | |||||
"class VGG_Net(nn.Module):\n", | |||||
" def __init__(self):\n", | " def __init__(self):\n", | ||||
" super(vgg, self).__init__()\n", | |||||
" self.feature = vgg_net\n", | |||||
" super(VGG_Net, self).__init__()\n", | |||||
" self.feature = VGG_Stack((2, 2, 3, 3, 3), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))\n", | |||||
" self.fc = nn.Sequential(\n", | " self.fc = nn.Sequential(\n", | ||||
" nn.Linear(512, 100),\n", | |||||
" nn.Linear(512*7*7, 4096),\n", | |||||
" nn.ReLU(True),\n", | " nn.ReLU(True),\n", | ||||
" nn.Linear(100, 10)\n", | |||||
" nn.Linear(4096, 10)\n", | |||||
" )\n", | " )\n", | ||||
" def forward(self, x):\n", | " def forward(self, x):\n", | ||||
" x = self.feature(x)\n", | " x = self.feature(x)\n", | ||||
@@ -359,74 +360,88 @@ | |||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 6, | |||||
"execution_count": 9, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T09:01:59.921373Z", | "end_time": "2017-12-22T09:01:59.921373Z", | ||||
"start_time": "2017-12-22T09:01:58.709531Z" | "start_time": "2017-12-22T09:01:58.709531Z" | ||||
}, | |||||
"collapsed": true | |||||
} | |||||
}, | }, | ||||
"outputs": [], | "outputs": [], | ||||
"source": [ | "source": [ | ||||
"from utils import train\n", | "from utils import train\n", | ||||
"\n", | "\n", | ||||
"def data_tf(x):\n", | |||||
" x = np.array(x, dtype='float32') / 255\n", | |||||
" x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到\n", | |||||
" x = x.transpose((2, 0, 1)) # 将 channel 放到第一维,只是 pytorch 要求的输入方式\n", | |||||
" x = torch.from_numpy(x)\n", | |||||
    "# data preprocessing: resize to 224 and normalize (no random augmentation)\n", | |||||
"def train_tf(x):\n", | |||||
" im_aug = tfs.Compose([\n", | |||||
" tfs.Resize(224),\n", | |||||
" tfs.ToTensor(),\n", | |||||
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n", | |||||
" ])\n", | |||||
" x = im_aug(x)\n", | |||||
" return x\n", | |||||
"\n", | |||||
"def test_tf(x):\n", | |||||
" im_aug = tfs.Compose([\n", | |||||
" tfs.Resize(224),\n", | |||||
" tfs.ToTensor(),\n", | |||||
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n", | |||||
" ])\n", | |||||
" x = im_aug(x)\n", | |||||
" return x\n", | " return x\n", | ||||
" \n", | " \n", | ||||
"train_set = CIFAR10('../../data', train=True, transform=data_tf)\n", | |||||
"train_set = CIFAR10('../../data', train=True, transform=train_tf)\n", | |||||
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", | "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n", | ||||
"test_set = CIFAR10('../../data', train=False, transform=data_tf)\n", | |||||
"test_set = CIFAR10('../../data', train=False, transform=test_tf)\n", | |||||
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", | "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n", | ||||
"\n", | "\n", | ||||
"net = vgg()\n", | |||||
"net = VGG_Net()\n", | |||||
"optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)\n", | "optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)\n", | ||||
"criterion = nn.CrossEntropyLoss()" | "criterion = nn.CrossEntropyLoss()" | ||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
"cell_type": "code", | "cell_type": "code", | ||||
"execution_count": 7, | |||||
"execution_count": null, | |||||
"metadata": { | "metadata": { | ||||
"ExecuteTime": { | "ExecuteTime": { | ||||
"end_time": "2017-12-22T09:12:46.868967Z", | "end_time": "2017-12-22T09:12:46.868967Z", | ||||
"start_time": "2017-12-22T09:01:59.924086Z" | "start_time": "2017-12-22T09:01:59.924086Z" | ||||
} | } | ||||
}, | }, | ||||
"outputs": [ | |||||
{ | |||||
"name": "stdout", | |||||
"output_type": "stream", | |||||
"text": [ | |||||
"Epoch 0. Train Loss: 2.303118, Train Acc: 0.098186, Valid Loss: 2.302944, Valid Acc: 0.099585, Time 00:00:32\n", | |||||
"Epoch 1. Train Loss: 2.303085, Train Acc: 0.096907, Valid Loss: 2.302762, Valid Acc: 0.100969, Time 00:00:33\n", | |||||
"Epoch 2. Train Loss: 2.302916, Train Acc: 0.097287, Valid Loss: 2.302740, Valid Acc: 0.099585, Time 00:00:33\n", | |||||
"Epoch 3. Train Loss: 2.302395, Train Acc: 0.102042, Valid Loss: 2.297652, Valid Acc: 0.108782, Time 00:00:32\n", | |||||
"Epoch 4. Train Loss: 2.079523, Train Acc: 0.202026, Valid Loss: 1.868179, Valid Acc: 0.255736, Time 00:00:31\n", | |||||
"Epoch 5. Train Loss: 1.781262, Train Acc: 0.307625, Valid Loss: 1.735122, Valid Acc: 0.323279, Time 00:00:31\n", | |||||
"Epoch 6. Train Loss: 1.565095, Train Acc: 0.400975, Valid Loss: 1.463914, Valid Acc: 0.449565, Time 00:00:31\n", | |||||
"Epoch 7. Train Loss: 1.360450, Train Acc: 0.495225, Valid Loss: 1.374488, Valid Acc: 0.490803, Time 00:00:31\n", | |||||
"Epoch 8. Train Loss: 1.144470, Train Acc: 0.585758, Valid Loss: 1.384803, Valid Acc: 0.524624, Time 00:00:31\n", | |||||
"Epoch 9. Train Loss: 0.954556, Train Acc: 0.659287, Valid Loss: 1.113850, Valid Acc: 0.609968, Time 00:00:32\n", | |||||
"Epoch 10. Train Loss: 0.801952, Train Acc: 0.718131, Valid Loss: 1.080254, Valid Acc: 0.639933, Time 00:00:31\n", | |||||
"Epoch 11. Train Loss: 0.665018, Train Acc: 0.765945, Valid Loss: 0.916277, Valid Acc: 0.698972, Time 00:00:31\n", | |||||
"Epoch 12. Train Loss: 0.547411, Train Acc: 0.811241, Valid Loss: 1.030948, Valid Acc: 0.678896, Time 00:00:32\n", | |||||
"Epoch 13. Train Loss: 0.442779, Train Acc: 0.846228, Valid Loss: 0.869791, Valid Acc: 0.732496, Time 00:00:32\n", | |||||
"Epoch 14. Train Loss: 0.357279, Train Acc: 0.875440, Valid Loss: 1.233777, Valid Acc: 0.671677, Time 00:00:31\n", | |||||
"Epoch 15. Train Loss: 0.285171, Train Acc: 0.900096, Valid Loss: 0.852879, Valid Acc: 0.765131, Time 00:00:32\n", | |||||
"Epoch 16. Train Loss: 0.222431, Train Acc: 0.923374, Valid Loss: 1.848096, Valid Acc: 0.614023, Time 00:00:31\n", | |||||
"Epoch 17. Train Loss: 0.174834, Train Acc: 0.939478, Valid Loss: 1.137286, Valid Acc: 0.728639, Time 00:00:31\n", | |||||
"Epoch 18. Train Loss: 0.144375, Train Acc: 0.950587, Valid Loss: 0.907310, Valid Acc: 0.776800, Time 00:00:31\n", | |||||
"Epoch 19. Train Loss: 0.115332, Train Acc: 0.960878, Valid Loss: 1.009886, Valid Acc: 0.761175, Time 00:00:31\n" | |||||
] | |||||
} | |||||
], | |||||
"outputs": [], | |||||
"source": [ | "source": [ | ||||
"train(net, train_data, test_data, 20, optimizer, criterion)" | |||||
"(l_train_loss, l_train_acc, l_valid_loss, l_valid_acc) = train(net, \n", | |||||
" train_data, test_data, \n", | |||||
" 20, \n", | |||||
" optimizer, criterion,\n", | |||||
" use_cuda=False)" | |||||
] | |||||
}, | |||||
{ | |||||
"cell_type": "code", | |||||
"execution_count": null, | |||||
"metadata": { | |||||
"collapsed": true | |||||
}, | |||||
"outputs": [], | |||||
"source": [ | |||||
"import matplotlib.pyplot as plt\n", | |||||
"%matplotlib inline\n", | |||||
"\n", | |||||
"plt.plot(l_train_loss, label='train')\n", | |||||
"plt.plot(l_valid_loss, label='valid')\n", | |||||
"plt.xlabel('epoch')\n", | |||||
"plt.legend(loc='best')\n", | |||||
"plt.savefig('fig-res-vgg-train-validate-loss.pdf')\n", | |||||
"plt.show()\n", | |||||
"\n", | |||||
"plt.plot(l_train_acc, label='train')\n", | |||||
"plt.plot(l_valid_acc, label='valid')\n", | |||||
"plt.xlabel('epoch')\n", | |||||
"plt.legend(loc='best')\n", | |||||
"plt.savefig('fig-res-vgg-train-validate-acc.pdf')\n", | |||||
"plt.show()" | |||||
] | ] | ||||
}, | }, | ||||
{ | { |
@@ -403,7 +403,7 @@ | |||||
"name": "python", | "name": "python", | ||||
"nbconvert_exporter": "python", | "nbconvert_exporter": "python", | ||||
"pygments_lexer": "ipython3", | "pygments_lexer": "ipython3", | ||||
"version": "3.7.9" | |||||
"version": "3.5.4" | |||||
} | } | ||||
}, | }, | ||||
"nbformat": 4, | "nbformat": 4, |
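The VGG notebook's claim that two stacked 3×3 convolutions match the receptive field of one 5×5 (and three match a 7×7) can be checked with the standard receptive-field recurrence, rf += (k−1)·jump with jump multiplied by each stride. A minimal sketch for the stride-1 convolutions used in the VGG blocks:

```python
def receptive_field(layers):
    # layers: list of (kernel, stride) pairs, applied in order
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump  # each layer widens the field by (k-1) input steps
        jump *= s             # stride compounds the step size between outputs
    return rf

print(receptive_field([(3, 1), (3, 1)]))          # 5 -> two 3x3 ~ one 5x5
print(receptive_field([(3, 1), (3, 1), (3, 1)]))  # 7 -> three 3x3 ~ one 7x7
```

The same recurrence also explains why the five stride-2 pooling layers shrink a 224×224 input to 7×7 feature maps.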
@@ -13,16 +13,22 @@ def get_acc(output, label): | |||||
return num_correct / total | return num_correct / total | ||||
def train(net, train_data, valid_data, num_epochs, optimizer, criterion): | |||||
if torch.cuda.is_available(): | |||||
def train(net, train_data, valid_data, num_epochs, optimizer, criterion, use_cuda=True): | |||||
if use_cuda and torch.cuda.is_available(): | |||||
net = net.cuda() | net = net.cuda() | ||||
l_train_loss = [] | |||||
l_train_acc = [] | |||||
l_valid_loss = [] | |||||
l_valid_acc = [] | |||||
prev_time = datetime.now() | prev_time = datetime.now() | ||||
for epoch in range(num_epochs): | for epoch in range(num_epochs): | ||||
train_loss = 0 | train_loss = 0 | ||||
train_acc = 0 | train_acc = 0 | ||||
net = net.train() | net = net.train() | ||||
for im, label in train_data: | for im, label in train_data: | ||||
if torch.cuda.is_available(): | |||||
if use_cuda and torch.cuda.is_available(): | |||||
im = Variable(im.cuda()) # (bs, 3, h, w) | im = Variable(im.cuda()) # (bs, 3, h, w) | ||||
label = Variable(label.cuda()) # (bs, h, w) | label = Variable(label.cuda()) # (bs, h, w) | ||||
else: | else: | ||||
@@ -48,7 +54,7 @@ def train(net, train_data, valid_data, num_epochs, optimizer, criterion): | |||||
valid_acc = 0 | valid_acc = 0 | ||||
net = net.eval() | net = net.eval() | ||||
for im, label in valid_data: | for im, label in valid_data: | ||||
if torch.cuda.is_available(): | |||||
if use_cuda and torch.cuda.is_available(): | |||||
im = Variable(im.cuda(), volatile=True) | im = Variable(im.cuda(), volatile=True) | ||||
label = Variable(label.cuda(), volatile=True) | label = Variable(label.cuda(), volatile=True) | ||||
else: | else: | ||||
@@ -63,13 +69,21 @@ def train(net, train_data, valid_data, num_epochs, optimizer, criterion): | |||||
% (epoch, train_loss / len(train_data), | % (epoch, train_loss / len(train_data), | ||||
train_acc / len(train_data), valid_loss / len(valid_data), | train_acc / len(train_data), valid_loss / len(valid_data), | ||||
valid_acc / len(valid_data))) | valid_acc / len(valid_data))) | ||||
l_valid_acc.append(valid_acc / len(valid_data)) | |||||
l_valid_loss.append(valid_loss / len(valid_data)) | |||||
else: | else: | ||||
epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " % | epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " % | ||||
(epoch, train_loss / len(train_data), | (epoch, train_loss / len(train_data), | ||||
train_acc / len(train_data))) | train_acc / len(train_data))) | ||||
l_train_acc.append(train_acc / len(train_data)) | |||||
l_train_loss.append(train_loss / len(train_data)) | |||||
prev_time = cur_time | prev_time = cur_time | ||||
print(epoch_str + time_str) | print(epoch_str + time_str) | ||||
return (l_train_loss, l_train_acc, l_valid_loss, l_valid_acc) | |||||
def conv3x3(in_channel, out_channel, stride=1): | def conv3x3(in_channel, out_channel, stride=1): | ||||
return nn.Conv2d( | return nn.Conv2d( | ||||
@@ -25,15 +25,17 @@ | |||||
- CNN | - CNN | ||||
- [CNN Introduction](1_CNN/CNN_Introduction.pptx) | - [CNN Introduction](1_CNN/CNN_Introduction.pptx) | ||||
- [CNN simple demo](../demo_code/3_CNN_MNIST.py) | - [CNN simple demo](../demo_code/3_CNN_MNIST.py) | ||||
- [Basic of Conv](1_CNN/1-basic_conv.ipynb) | |||||
- [VGG Network](1_CNN/2-vgg.ipynb) | |||||
- [GoogleNet](1_CNN/3-googlenet.ipynb) | |||||
- [ResNet](1_CNN/4-resnet.ipynb) | |||||
- [DenseNet](1_CNN/5-densenet.ipynb) | |||||
- [Batch Normalization](1_CNN/6-batch-normalization.ipynb) | |||||
- [Learning Rate Decay](1_CNN/7-lr-decay.ipynb) | |||||
- [Regularization](1_CNN/8-regularization.ipynb) | |||||
- [Data Augumentation](1_CNN/9-data-augumentation.ipynb) | |||||
- [Basic of Conv](1_CNN/01-basic_conv.ipynb) | |||||
- [LeNet5](1_CNN/02-LeNet5.ipynb) | |||||
- [AlexNet](1_CNN/03-AlexNet.ipynb) | |||||
- [VGG Network](1_CNN/04-vgg.ipynb) | |||||
- [GoogleNet](1_CNN/05-googlenet.ipynb) | |||||
- [ResNet](1_CNN/06-resnet.ipynb) | |||||
- [DenseNet](1_CNN/07-densenet.ipynb) | |||||
- [Batch Normalization](1_CNN/08-batch-normalization.ipynb) | |||||
- [Learning Rate Decay](1_CNN/09-lr-decay.ipynb) | |||||
- [Regularization](1_CNN/10-regularization.ipynb) | |||||
    - [Data Augmentation](1_CNN/11-data-augumentation.ipynb) | |||||
- RNN | - RNN | ||||
- [rnn/pytorch-rnn](2_RNN/pytorch-rnn.ipynb) | - [rnn/pytorch-rnn](2_RNN/pytorch-rnn.ipynb) | ||||
- [rnn/rnn-for-image](2_RNN/rnn-for-image.ipynb) | - [rnn/rnn-for-image](2_RNN/rnn-for-image.ipynb) | ||||
@@ -52,15 +52,17 @@ | |||||
- CNN | - CNN | ||||
- [CNN Introduction](7_deep_learning/1_CNN/CNN_Introduction.pptx) | - [CNN Introduction](7_deep_learning/1_CNN/CNN_Introduction.pptx) | ||||
- [CNN simple demo](demo_code/3_CNN_MNIST.py) | - [CNN simple demo](demo_code/3_CNN_MNIST.py) | ||||
- [Basic of Conv](7_deep_learning/1_CNN/1-basic_conv.ipynb) | |||||
- [VGG Network](7_deep_learning/1_CNN/2-vgg.ipynb) | |||||
- [GoogleNet](7_deep_learning/1_CNN/3-googlenet.ipynb) | |||||
- [ResNet](7_deep_learning/1_CNN/4-resnet.ipynb) | |||||
- [DenseNet](7_deep_learning/1_CNN/5-densenet.ipynb) | |||||
- [Batch Normalization](7_deep_learning/1_CNN/6-batch-normalization.ipynb) | |||||
- [Learning Rate Decay](7_deep_learning/1_CNN/7-lr-decay.ipynb) | |||||
- [Regularization](7_deep_learning/1_CNN/8-regularization.ipynb) | |||||
- [Data Augumentation](7_deep_learning/1_CNN/9-data-augumentation.ipynb) | |||||
- [Basic of Conv](7_deep_learning/1_CNN/01-basic_conv.ipynb) | |||||
- [LeNet5](7_deep_learning/1_CNN/02-LeNet5.ipynb) | |||||
- [AlexNet](7_deep_learning/1_CNN/03-AlexNet.ipynb) | |||||
- [VGG Network](7_deep_learning/1_CNN/04-vgg.ipynb) | |||||
- [GoogleNet](7_deep_learning/1_CNN/05-googlenet.ipynb) | |||||
- [ResNet](7_deep_learning/1_CNN/06-resnet.ipynb) | |||||
- [DenseNet](7_deep_learning/1_CNN/07-densenet.ipynb) | |||||
- [Batch Normalization](7_deep_learning/1_CNN/08-batch-normalization.ipynb) | |||||
- [Learning Rate Decay](7_deep_learning/1_CNN/09-lr-decay.ipynb) | |||||
- [Regularization](7_deep_learning/1_CNN/10-regularization.ipynb) | |||||
    - [Data Augmentation](7_deep_learning/1_CNN/11-data-augumentation.ipynb) | |||||
- RNN | - RNN | ||||
- [rnn/pytorch-rnn](7_deep_learning/2_RNN/pytorch-rnn.ipynb) | - [rnn/pytorch-rnn](7_deep_learning/2_RNN/pytorch-rnn.ipynb) | ||||
- [rnn/rnn-for-image](7_deep_learning/2_RNN/rnn-for-image.ipynb) | - [rnn/rnn-for-image](7_deep_learning/2_RNN/rnn-for-image.ipynb) | ||||