{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResNet\n",
  8. "当大家还在惊叹 GoogLeNet 的 inception 结构的时候,微软亚洲研究院的研究员已经在设计更深但结构更加简单的网络 ResNet,并且凭借这个网络子在 2015 年 ImageNet 比赛上大获全胜。\n",
  9. "\n",
  10. "ResNet 有效地解决了深度神经网络难以训练的问题,可以训练高达 1000 层的卷积网络。网络之所以难以训练,是因为存在着梯度消失的问题,离 loss 函数越远的层,在反向传播的时候,梯度越小,就越难以更新,随着层数的增加,这个现象越严重。之前有两种常见的方案来解决这个问题:\n",
  11. "\n",
  12. "1.按层训练,先训练比较浅的层,然后在不断增加层数,但是这种方法效果不是特别好,而且比较麻烦\n",
  13. "\n",
  14. "2.使用更宽的层,或者增加输出通道,而不加深网络的层数,这种结构往往得到的效果又不好\n",
  15. "\n",
  16. "ResNet 通过引入了跨层链接解决了梯度回传消失的问题。\n",
  17. "\n",
  18. "![](https://ws1.sinaimg.cn/large/006tNc79ly1fmptq2snv9j30j808t74a.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
  25. "这就普通的网络连接跟跨层残差连接的对比图,使用普通的连接,上层的梯度必须要一层一层传回来,而是用残差连接,相当于中间有了一条更短的路,梯度能够从这条更短的路传回来,避免了梯度过小的情况。\n",
  26. "\n",
  27. "假设某层的输入是 x,期望输出是 H(x), 如果我们直接把输入 x 传到输出作为初始结果,这就是一个更浅层的网络,更容易训练,而这个网络没有学会的部分,我们可以使用更深的网络 F(x) 去训练它,使得训练更加容易,最后希望拟合的结果就是 F(x) = H(x) - x,这就是一个残差的结构\n",
  28. "\n",
  29. "残差网络的结构就是上面这种残差块的堆叠,下面让我们来实现一个 residual block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T12:56:06.772059Z",
     "start_time": "2017-12-22T12:56:06.766027Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "from torchvision.datasets import CIFAR10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T12:47:49.222432Z",
     "start_time": "2017-12-22T12:47:49.217940Z"
    }
   },
   "outputs": [],
   "source": [
    "def conv3x3(in_channel, out_channel, stride=1):\n",
    "    return nn.Conv2d(in_channel, out_channel, 3, stride=stride, padding=1, bias=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T13:14:02.429145Z",
     "start_time": "2017-12-22T13:14:02.383322Z"
    }
   },
   "outputs": [],
   "source": [
    "class residual_block(nn.Module):\n",
    "    def __init__(self, in_channel, out_channel, same_shape=True):\n",
    "        super(residual_block, self).__init__()\n",
    "        self.same_shape = same_shape\n",
    "        stride = 1 if self.same_shape else 2\n",
    "        \n",
    "        self.conv1 = conv3x3(in_channel, out_channel, stride=stride)\n",
    "        self.bn1 = nn.BatchNorm2d(out_channel)\n",
    "        \n",
    "        self.conv2 = conv3x3(out_channel, out_channel)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channel)\n",
    "        if not self.same_shape:\n",
    "            self.conv3 = nn.Conv2d(in_channel, out_channel, 1, stride=stride)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = F.relu(self.bn1(out), True)\n",
    "        out = self.conv2(out)\n",
    "        out = F.relu(self.bn2(out), True)\n",
    "        \n",
    "        if not self.same_shape:\n",
    "            x = self.conv3(x)\n",
    "        return F.relu(x + out, True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
  109. "我们测试一下一个 residual block 的输入和输出"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T13:14:05.793185Z",
     "start_time": "2017-12-22T13:14:05.763382Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input: torch.Size([1, 32, 96, 96])\n",
      "output: torch.Size([1, 32, 96, 96])\n"
     ]
    }
   ],
   "source": [
  132. "# 输入输出形状相同\n",
  133. "test_net = residual_block(32, 32)\n",
  134. "test_x = Variable(torch.zeros(1, 32, 96, 96))\n",
  135. "print('input: {}'.format(test_x.shape))\n",
  136. "test_y = test_net(test_x)\n",
  137. "print('output: {}'.format(test_y.shape))"
  138. ]
  139. },
  140. {
  141. "cell_type": "code",
  142. "execution_count": 9,
  143. "metadata": {
  144. "ExecuteTime": {
  145. "end_time": "2017-12-22T13:14:11.929120Z",
  146. "start_time": "2017-12-22T13:14:11.914604Z"
  147. }
  148. },
  149. "outputs": [
  150. {
  151. "name": "stdout",
  152. "output_type": "stream",
  153. "text": [
  154. "input: torch.Size([1, 3, 96, 96])\n",
  155. "output: torch.Size([1, 32, 48, 48])\n"
  156. ]
  157. }
  158. ],
  159. "source": [
  160. "# 输入输出形状不同\n",
  161. "test_net = residual_block(3, 32, False)\n",
  162. "test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
  163. "print('input: {}'.format(test_x.shape))\n",
  164. "test_y = test_net(test_x)\n",
  165. "print('output: {}'.format(test_y.shape))"
  166. ]
  167. },
  168. {
  169. "cell_type": "markdown",
  170. "metadata": {},
  171. "source": [
  172. "下面我们尝试实现一个 ResNet,它就是 residual block 模块的堆叠"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T13:27:46.099404Z",
     "start_time": "2017-12-22T13:27:45.986235Z"
    }
   },
   "outputs": [],
   "source": [
    "class resnet(nn.Module):\n",
    "    def __init__(self, in_channel, num_classes, verbose=False):\n",
    "        super(resnet, self).__init__()\n",
    "        self.verbose = verbose\n",
    "        \n",
    "        self.block1 = nn.Conv2d(in_channel, 64, 7, 2)\n",
    "        \n",
    "        self.block2 = nn.Sequential(\n",
    "            nn.MaxPool2d(3, 2),\n",
    "            residual_block(64, 64),\n",
    "            residual_block(64, 64)\n",
    "        )\n",
    "        \n",
    "        self.block3 = nn.Sequential(\n",
    "            residual_block(64, 128, False),\n",
    "            residual_block(128, 128)\n",
    "        )\n",
    "        \n",
    "        self.block4 = nn.Sequential(\n",
    "            residual_block(128, 256, False),\n",
    "            residual_block(256, 256)\n",
    "        )\n",
    "        \n",
    "        self.block5 = nn.Sequential(\n",
    "            residual_block(256, 512, False),\n",
    "            residual_block(512, 512),\n",
    "            nn.AvgPool2d(3)\n",
    "        )\n",
    "        \n",
    "        self.classifier = nn.Linear(512, num_classes)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.block1(x)\n",
    "        if self.verbose:\n",
    "            print('block 1 output: {}'.format(x.shape))\n",
    "        x = self.block2(x)\n",
    "        if self.verbose:\n",
    "            print('block 2 output: {}'.format(x.shape))\n",
    "        x = self.block3(x)\n",
    "        if self.verbose:\n",
    "            print('block 3 output: {}'.format(x.shape))\n",
    "        x = self.block4(x)\n",
    "        if self.verbose:\n",
    "            print('block 4 output: {}'.format(x.shape))\n",
    "        x = self.block5(x)\n",
    "        if self.verbose:\n",
    "            print('block 5 output: {}'.format(x.shape))\n",
    "        x = x.view(x.shape[0], -1)\n",
    "        x = self.classifier(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
  242. "输出一下每个 block 之后的大小"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T13:28:00.597030Z",
     "start_time": "2017-12-22T13:28:00.417746Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "block 1 output: torch.Size([1, 64, 45, 45])\n",
      "block 2 output: torch.Size([1, 64, 22, 22])\n",
      "block 3 output: torch.Size([1, 128, 11, 11])\n",
      "block 4 output: torch.Size([1, 256, 6, 6])\n",
      "block 5 output: torch.Size([1, 512, 1, 1])\n",
      "output: torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "test_net = resnet(3, 10, True)\n",
    "test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
    "test_y = test_net(test_x)\n",
    "print('output: {}'.format(test_y.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T13:29:01.484172Z",
     "start_time": "2017-12-22T13:29:00.095952Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from utils import train\n",
    "\n",
    "def data_tf(x):\n",
  290. " x = x.resize((96, 96), 2) # 将图片放大到 96 x 96\n",
  291. " x = np.array(x, dtype='float32') / 255\n",
  292. " x = (x - 0.5) / 0.5 # 标准化,这个技巧之后会讲到\n",
  293. " x = x.transpose((2, 0, 1)) # 将 channel 放到第一维,只是 pytorch 要求的输入方式\n",
  294. " x = torch.from_numpy(x)\n",
  295. " return x\n",
  296. " \n",
  297. "train_set = CIFAR10('./data', train=True, transform=data_tf)\n",
  298. "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
  299. "test_set = CIFAR10('./data', train=False, transform=data_tf)\n",
  300. "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
  301. "\n",
  302. "net = resnet(3, 10)\n",
  303. "optimizer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
  304. "criterion = nn.CrossEntropyLoss()"
  305. ]
  306. },
  307. {
  308. "cell_type": "code",
  309. "execution_count": 33,
  310. "metadata": {
  311. "ExecuteTime": {
  312. "end_time": "2017-12-22T13:45:00.783186Z",
  313. "start_time": "2017-12-22T13:29:09.214453Z"
  314. }
  315. },
  316. "outputs": [
  317. {
  318. "name": "stdout",
  319. "output_type": "stream",
  320. "text": [
  321. "Epoch 0. Train Loss: 1.437317, Train Acc: 0.476662, Valid Loss: 1.928288, Valid Acc: 0.384691, Time 00:00:44\n",
  322. "Epoch 1. Train Loss: 0.992832, Train Acc: 0.648198, Valid Loss: 1.009847, Valid Acc: 0.642405, Time 00:00:48\n",
  323. "Epoch 2. Train Loss: 0.767309, Train Acc: 0.732617, Valid Loss: 1.827319, Valid Acc: 0.430380, Time 00:00:47\n",
  324. "Epoch 3. Train Loss: 0.606737, Train Acc: 0.788043, Valid Loss: 1.304808, Valid Acc: 0.585245, Time 00:00:46\n",
  325. "Epoch 4. Train Loss: 0.484436, Train Acc: 0.834499, Valid Loss: 1.335749, Valid Acc: 0.617089, Time 00:00:47\n",
  326. "Epoch 5. Train Loss: 0.374320, Train Acc: 0.872922, Valid Loss: 0.878519, Valid Acc: 0.724288, Time 00:00:47\n",
  327. "Epoch 6. Train Loss: 0.280981, Train Acc: 0.904212, Valid Loss: 0.931616, Valid Acc: 0.716871, Time 00:00:48\n",
  328. "Epoch 7. Train Loss: 0.210800, Train Acc: 0.929747, Valid Loss: 1.448870, Valid Acc: 0.638548, Time 00:00:48\n",
  329. "Epoch 8. Train Loss: 0.147873, Train Acc: 0.951427, Valid Loss: 1.356992, Valid Acc: 0.657536, Time 00:00:47\n",
  330. "Epoch 9. Train Loss: 0.112824, Train Acc: 0.963895, Valid Loss: 1.630560, Valid Acc: 0.627769, Time 00:00:47\n",
  331. "Epoch 10. Train Loss: 0.082685, Train Acc: 0.973905, Valid Loss: 0.982882, Valid Acc: 0.744264, Time 00:00:44\n",
  332. "Epoch 11. Train Loss: 0.065325, Train Acc: 0.979680, Valid Loss: 0.911631, Valid Acc: 0.767009, Time 00:00:47\n",
  333. "Epoch 12. Train Loss: 0.041401, Train Acc: 0.987952, Valid Loss: 1.167992, Valid Acc: 0.729826, Time 00:00:48\n",
  334. "Epoch 13. Train Loss: 0.037516, Train Acc: 0.989011, Valid Loss: 1.081807, Valid Acc: 0.746737, Time 00:00:47\n",
  335. "Epoch 14. Train Loss: 0.030674, Train Acc: 0.991468, Valid Loss: 0.935292, Valid Acc: 0.774031, Time 00:00:45\n",
  336. "Epoch 15. Train Loss: 0.021743, Train Acc: 0.994565, Valid Loss: 0.879348, Valid Acc: 0.790150, Time 00:00:47\n",
  337. "Epoch 16. Train Loss: 0.014642, Train Acc: 0.996463, Valid Loss: 1.328587, Valid Acc: 0.724387, Time 00:00:47\n",
  338. "Epoch 17. Train Loss: 0.011072, Train Acc: 0.997363, Valid Loss: 0.909065, Valid Acc: 0.792919, Time 00:00:47\n",
  339. "Epoch 18. Train Loss: 0.006870, Train Acc: 0.998561, Valid Loss: 0.923746, Valid Acc: 0.794403, Time 00:00:46\n",
  340. "Epoch 19. Train Loss: 0.004240, Train Acc: 0.999500, Valid Loss: 0.877908, Valid Acc: 0.802314, Time 00:00:46\n"
  341. ]
  342. }
  343. ],
  344. "source": [
  345. "train(net, train_data, test_data, 20, optimizer, criterion)"
  346. ]
  347. },
  348. {
  349. "cell_type": "markdown",
  350. "metadata": {},
  351. "source": [
  352. "ResNet 使用跨层通道使得训练非常深的卷积神经网络成为可能。同样它使用很简单的卷积层配置,使得其拓展更加简单。\n",
  353. "\n",
  354. "**小练习: \n",
  355. "1.尝试一下论文中提出的 bottleneck 的结构 \n",
  356. "2.尝试改变 conv -> bn -> relu 的顺序为 bn -> relu -> conv,看看精度会不会提高**"
   ]
  },
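  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below is a minimal, untested sketch of both exercise ideas in one block, reusing the torch/nn/F imports from above. The names `bottleneck_block`, `mid_channel` and `proj` are hypothetical and do not appear in the original notebook: it combines a bottleneck residual unit (1x1 -> 3x3 -> 1x1 convolutions) with the pre-activation ordering bn -> relu -> conv from the follow-up paper *Identity Mappings in Deep Residual Networks*. Treat it as a starting point, not a reference implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a sketch of exercises 1 and 2, not part of the original notebook\n",
    "class bottleneck_block(nn.Module):\n",
    "    def __init__(self, in_channel, mid_channel, out_channel, same_shape=True):\n",
    "        super(bottleneck_block, self).__init__()\n",
    "        stride = 1 if same_shape else 2\n",
    "        # exercise 2: pre-activation ordering, bn -> relu -> conv at every step\n",
    "        self.bn1 = nn.BatchNorm2d(in_channel)\n",
    "        self.conv1 = nn.Conv2d(in_channel, mid_channel, 1, bias=False)  # 1x1 reduces channels\n",
    "        self.bn2 = nn.BatchNorm2d(mid_channel)\n",
    "        self.conv2 = nn.Conv2d(mid_channel, mid_channel, 3, stride=stride, padding=1, bias=False)\n",
    "        self.bn3 = nn.BatchNorm2d(mid_channel)\n",
    "        self.conv3 = nn.Conv2d(mid_channel, out_channel, 1, bias=False)  # 1x1 restores channels\n",
    "        # 1x1 projection on the shortcut whenever the shapes differ\n",
    "        self.proj = None\n",
    "        if not same_shape or in_channel != out_channel:\n",
    "            self.proj = nn.Conv2d(in_channel, out_channel, 1, stride=stride)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        out = self.conv1(F.relu(self.bn1(x), True))\n",
    "        out = self.conv2(F.relu(self.bn2(out), True))\n",
    "        out = self.conv3(F.relu(self.bn3(out), True))\n",
    "        shortcut = x if self.proj is None else self.proj(x)\n",
    "        return out + shortcut  # pre-activation blocks skip the final relu\n",
    "\n",
    "# quick shape check: 64 channels in, 128 out, spatial size halved\n",
    "print(bottleneck_block(64, 16, 128, same_shape=False)(torch.zeros(1, 64, 24, 24)).shape)"
   ]
  }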
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}