
07-densenet.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DenseNet\n",
"\n",
  9. "因为 ResNet 提出了跨层链接的思想,这直接影响了随后出现的卷积网络架构,其中最有名的就是 CVPR 2017 的 Best Paper,DenseNet。DenseNet 和 ResNet 不同在于 ResNet 是跨层求和,而 DenseNet 是跨层将特征在通道维度进行拼接,下面可以看看他们两者的图示:\n",
  10. "\n",
  11. "![cnn_vs_resnet_vs_densenet.png](images/cnn_vs_resnet_vs_densenet.png)"
  12. ]
  13. },
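{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the difference concrete, here is a minimal sketch (the names `x` and `f_x` are hypothetical stand-ins for a layer's input and output): a residual connection adds `f_x` to `x` and keeps the channel count, while a dense connection concatenates along the channel dimension and grows it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of the two kinds of cross-layer connection;\n",
"# x and f_x are hypothetical stand-ins for a layer's input and output.\n",
"import torch\n",
"\n",
"x = torch.zeros(1, 8, 4, 4)   # input with 8 channels\n",
"f_x = torch.zeros(1, 8, 4, 4) # layer output; shapes must match for ResNet's sum\n",
"\n",
"res_out = x + f_x                      # ResNet: element-wise sum, still 8 channels\n",
"dense_out = torch.cat((x, f_x), dim=1) # DenseNet: channel concat, now 16 channels\n",
"print(res_out.shape, dense_out.shape)"
]
},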
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The second figure shows ResNet and the third shows DenseNet. Because features are concatenated along the channel dimension, the outputs of lower layers are preserved and passed into every later layer. This makes gradient propagation easier and lets low-level and high-level features be trained jointly, which gives better results.\n",
"\n",
"The main advantages of DenseNet are:\n",
"1. it alleviates the vanishing-gradient problem\n",
"2. it strengthens feature propagation\n",
"3. it makes more effective use of features\n",
"4. it reduces the number of parameters to some extent\n",
"\n",
"In deep networks the vanishing-gradient problem becomes more pronounced as depth grows, and many papers have proposed remedies, e.g. ResNet, Highway Networks, Stochastic Depth, and FractalNets. Although their architectures differ, the core idea is the same: **create short paths from early layers to later layers**. Taking this idea to its limit, while preserving maximal information flow between layers, means connecting every layer directly to all the layers after it.\n",
"\n",
"First, the structure of a dense block. In a traditional convolutional network with L layers there are L connections, but in DenseNet there are **L(L+1)/2** connections (the quick check below confirms this count). Simply put, the input of each layer is the concatenation of the outputs of all preceding layers. In the figure below, x0 is the input, H1 takes x0 as input, H2 takes x0 and x1 (x1 is the output of H1), and so on:\n",
"\n",
"![DesNet_arch.png](images/DesNet_arch.png)"
]
},
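{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sanity check of that count: counting the input as x0, layer i receives i incoming connections (one from each preceding feature map), so the total is 1 + 2 + ... + L = L(L+1)/2:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the direct connections in a dense block with L layers:\n",
"# layer i is fed by all i preceding feature maps (including the input).\n",
"L = 5\n",
"n_connections = sum(range(1, L + 1))\n",
"print(n_connections, L * (L + 1) // 2) # both print 15"
]
},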
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Dense_Block\n",
"DenseNet is built mainly from Dense Blocks, so let's implement a Dense Block first"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T15:38:31.113030Z",
"start_time": "2017-12-22T15:38:30.612922Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable\n",
"from torchvision.datasets import CIFAR10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we define a convolution block whose layers are ordered bn -> relu -> conv"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T15:38:31.121249Z",
"start_time": "2017-12-22T15:38:31.115369Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"def Conv_Block(in_channel, out_channel):\n",
"    layer = nn.Sequential(\n",
"        nn.BatchNorm2d(in_channel),\n",
"        nn.ReLU(True),\n",
"        nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)\n",
"    )\n",
"    return layer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A Dense Block calls the number of channels each convolution produces the `growth_rate`: if the input has `in_channel` channels and there are n layers, the output has `in_channel + n * growth_rate` channels (e.g. `in_channel` = 3, `growth_rate` = 12 and n = 3 give 3 + 3 * 12 = 39 channels, which the check below confirms)"
]
},
  98. {
  99. "cell_type": "code",
  100. "execution_count": 3,
  101. "metadata": {
  102. "ExecuteTime": {
  103. "end_time": "2017-12-22T15:38:31.145274Z",
  104. "start_time": "2017-12-22T15:38:31.123363Z"
  105. },
  106. "collapsed": true
  107. },
  108. "outputs": [],
  109. "source": [
  110. "class Dense_Block(nn.Module):\n",
  111. " def __init__(self, in_channel, growth_rate, num_layers):\n",
  112. " super(Dense_Block, self).__init__()\n",
  113. " block = []\n",
  114. " channel = in_channel\n",
  115. " for i in range(num_layers):\n",
  116. " block.append(Conv_Block(channel, growth_rate))\n",
  117. " channel += growth_rate\n",
  118. " \n",
  119. " self.net = nn.Sequential(*block)\n",
  120. " \n",
  121. " def forward(self, x):\n",
  122. " for layer in self.net:\n",
  123. " out = layer(x)\n",
  124. " x = torch.cat((out, x), dim=1)\n",
  125. " return x"
  126. ]
  127. },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's verify that the number of output channels is correct"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T15:38:31.213632Z",
"start_time": "2017-12-22T15:38:31.147196Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input shape: 3 x 96 x 96\n",
"output shape: 39 x 96 x 96\n"
]
}
],
"source": [
"test_net = Dense_Block(3, 12, 3)\n",
"test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
"print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))\n",
"test_y = test_net(test_x)\n",
"print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Besides dense blocks, DenseNet has another module called the transition layer (transition block). Because DenseNet keeps concatenating along the channel dimension, the number of output channels keeps growing as the network deepens, and with it the parameter count and the amount of computation. To avoid this, a transition layer is inserted to bring the number of output channels back down and to halve the spatial size of the input: a 1 x 1 convolution shrinks the channels, followed by average pooling to halve the width and height"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T15:38:31.222120Z",
"start_time": "2017-12-22T15:38:31.215770Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"def transition(in_channel, out_channel):\n",
"    trans_layer = nn.Sequential(\n",
"        nn.BatchNorm2d(in_channel),\n",
"        nn.ReLU(True),\n",
"        nn.Conv2d(in_channel, out_channel, 1),\n",
"        nn.AvgPool2d(2, 2)\n",
"    )\n",
"    return trans_layer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's verify that the transition layer works as expected"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T15:38:31.234846Z",
"start_time": "2017-12-22T15:38:31.224078Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input shape: 3 x 96 x 96\n",
"output shape: 12 x 48 x 48\n"
]
}
],
"source": [
"test_net = transition(3, 12)\n",
"test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
"print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))\n",
"test_y = test_net(test_x)\n",
"print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. DenseNet\n",
"\n",
"Finally we define DenseNet"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T15:38:31.318822Z",
"start_time": "2017-12-22T15:38:31.236857Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"class densenet(nn.Module):\n",
"    def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):\n",
"        super(densenet, self).__init__()\n",
"        self.block1 = nn.Sequential(\n",
"            nn.Conv2d(in_channel, 64, 7, 2, 3),\n",
"            nn.BatchNorm2d(64),\n",
"            nn.ReLU(True),\n",
"            nn.MaxPool2d(3, 2, padding=1)\n",
"        )\n",
"        \n",
"        channels = 64\n",
"        block = []\n",
"        for i, layers in enumerate(block_layers):\n",
"            block.append(Dense_Block(channels, growth_rate, layers))\n",
"            channels += layers * growth_rate\n",
"            if i != len(block_layers) - 1:\n",
"                block.append(transition(channels, channels // 2)) # the transition layer halves the spatial size and the channel count\n",
"                channels = channels // 2\n",
"        \n",
"        self.block2 = nn.Sequential(*block)\n",
"        self.block2.add_module('bn', nn.BatchNorm2d(channels))\n",
"        self.block2.add_module('relu', nn.ReLU(True))\n",
"        self.block2.add_module('avg_pool', nn.AvgPool2d(3)) # for 96 x 96 inputs the feature map here is 3 x 3, so this acts as a global average pool\n",
"        \n",
"        self.classifier = nn.Linear(channels, num_classes)\n",
"        \n",
"    def forward(self, x):\n",
"        x = self.block1(x)\n",
"        x = self.block2(x)\n",
"        \n",
"        x = x.view(x.shape[0], -1)\n",
"        x = self.classifier(x)\n",
"        return x"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T15:38:31.654182Z",
"start_time": "2017-12-22T15:38:31.320788Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"output: torch.Size([1, 10])\n"
]
}
],
"source": [
"test_net = densenet(3, 10)\n",
"test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
"test_y = test_net(test_x)\n",
"print('output: {}'.format(test_y.shape))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T15:38:32.894729Z",
"start_time": "2017-12-22T15:38:31.656356Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"from utils import train\n",
"\n",
"def data_tf(x):\n",
"    x = x.resize((96, 96), 2) # upscale the image to 96 x 96 (resample mode 2 is PIL's bilinear filter)\n",
"    x = np.array(x, dtype='float32') / 255\n",
"    x = (x - 0.5) / 0.5 # normalize to [-1, 1]; this trick is covered later\n",
"    x = x.transpose((2, 0, 1)) # put the channel dimension first, as PyTorch expects\n",
"    x = torch.from_numpy(x)\n",
"    return x\n",
"\n",
"train_set = CIFAR10('../../data', train=True, transform=data_tf)\n",
"train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
"test_set = CIFAR10('../../data', train=False, transform=data_tf)\n",
"test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
"\n",
"net = densenet(3, 10)\n",
"optimizer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
"criterion = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2017-12-22T16:15:38.168095Z",
"start_time": "2017-12-22T15:38:32.896735Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0. Train Loss: 1.374316, Train Acc: 0.507972, Valid Loss: 1.203217, Valid Acc: 0.572884, Time 00:01:44\n",
"Epoch 1. Train Loss: 0.912924, Train Acc: 0.681506, Valid Loss: 1.555908, Valid Acc: 0.492286, Time 00:01:50\n",
"Epoch 2. Train Loss: 0.701387, Train Acc: 0.755794, Valid Loss: 0.815147, Valid Acc: 0.718354, Time 00:01:49\n",
"Epoch 3. Train Loss: 0.575985, Train Acc: 0.800911, Valid Loss: 0.696013, Valid Acc: 0.759494, Time 00:01:50\n",
"Epoch 4. Train Loss: 0.479812, Train Acc: 0.836957, Valid Loss: 1.013879, Valid Acc: 0.676226, Time 00:01:51\n",
"Epoch 5. Train Loss: 0.402165, Train Acc: 0.861413, Valid Loss: 0.674512, Valid Acc: 0.778481, Time 00:01:50\n",
"Epoch 6. Train Loss: 0.334593, Train Acc: 0.888247, Valid Loss: 0.647112, Valid Acc: 0.791634, Time 00:01:50\n",
"Epoch 7. Train Loss: 0.278181, Train Acc: 0.907149, Valid Loss: 0.773517, Valid Acc: 0.756527, Time 00:01:51\n",
"Epoch 8. Train Loss: 0.227948, Train Acc: 0.922714, Valid Loss: 0.654399, Valid Acc: 0.800237, Time 00:01:49\n",
"Epoch 9. Train Loss: 0.181156, Train Acc: 0.940157, Valid Loss: 1.179013, Valid Acc: 0.685225, Time 00:01:50\n",
"Epoch 10. Train Loss: 0.151305, Train Acc: 0.950208, Valid Loss: 0.630000, Valid Acc: 0.807951, Time 00:01:50\n",
"Epoch 11. Train Loss: 0.118433, Train Acc: 0.961077, Valid Loss: 1.247253, Valid Acc: 0.703323, Time 00:01:52\n",
"Epoch 12. Train Loss: 0.094127, Train Acc: 0.969789, Valid Loss: 1.230697, Valid Acc: 0.723101, Time 00:01:51\n",
"Epoch 13. Train Loss: 0.086181, Train Acc: 0.972047, Valid Loss: 0.904135, Valid Acc: 0.769284, Time 00:01:50\n",
"Epoch 14. Train Loss: 0.064248, Train Acc: 0.980359, Valid Loss: 1.665002, Valid Acc: 0.624209, Time 00:01:51\n",
"Epoch 15. Train Loss: 0.054932, Train Acc: 0.982996, Valid Loss: 0.927216, Valid Acc: 0.774723, Time 00:01:51\n",
"Epoch 16. Train Loss: 0.043503, Train Acc: 0.987272, Valid Loss: 1.574383, Valid Acc: 0.707377, Time 00:01:52\n",
"Epoch 17. Train Loss: 0.047615, Train Acc: 0.985154, Valid Loss: 0.987781, Valid Acc: 0.770471, Time 00:01:51\n",
"Epoch 18. Train Loss: 0.039813, Train Acc: 0.988012, Valid Loss: 2.248944, Valid Acc: 0.631824, Time 00:01:50\n",
"Epoch 19. Train Loss: 0.030183, Train Acc: 0.991168, Valid Loss: 0.887785, Valid Acc: 0.795392, Time 00:01:51\n"
]
}
],
"source": [
"train(net, train_data, test_data, 20, optimizer, criterion)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DenseNet replaces the residual sum with feature concatenation, which gives the network much denser connectivity"
]
},
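{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, torchvision ships a reference DenseNet implementation. The minimal sketch below assumes a torchvision version that provides `models.densenet121`; note that its stem and final pooling are sized for 224 x 224 ImageNet images, so it is not a drop-in replacement for the 96 x 96 pipeline above. Here we only build it and count its parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, not the implementation used above: build torchvision's\n",
"# reference DenseNet-121 (assuming models.densenet121 is available) and\n",
"# count its parameters.\n",
"from torchvision import models\n",
"\n",
"ref_net = models.densenet121(num_classes=10) # randomly initialized, 10 output classes\n",
"n_params = sum(p.numel() for p in ref_net.parameters())\n",
"print('densenet121 parameters: {}'.format(n_params))"
]
},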
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"* [DenseNet算法详解](https://blog.csdn.net/u014380165/article/details/75142664)\n",
"* [DenseNet详解](https://zhuanlan.zhihu.com/p/43057737)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
