PyTorch_quick_intro.ipynb 63 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
6 years ago
6 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
6 years ago
4 years ago
6 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
6 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
6 years ago
4 years ago
6 years ago
6 years ago
4 years ago
6 years ago
6 years ago
4 years ago
6 years ago
4 years ago
6 years ago
6 years ago
4 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
4 years ago
6 years ago
6 years ago
6 years ago
4 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
4 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# A Quick Introduction to PyTorch\n",
  8. "\n",
  9. "PyTorch's clean design makes it easy to pick up. Before going deeper, this section introduces some PyTorch basics so that readers gain a general understanding of the framework and can use it to build a simple neural network. Some of the material may not be fully clear at first; there is no need to dwell on it, as later chapters cover these topics in depth.\n",
  10. "\n",
  11. "This section is adapted from the official PyTorch tutorial[^1], with additions, deletions, and revisions that bring the content in line with the newer PyTorch API and make it better suited to beginners. This book assumes basic familiarity with Numpy; for other background material, the CS231n tutorial[^2] is recommended.\n",
  12. "\n",
  13. "[^1]: http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html\n",
  14. "[^2]: http://cs231n.github.io/python-numpy-tutorial/"
  15. ]
  16. },
  17. {
  18. "cell_type": "markdown",
  19. "metadata": {},
  20. "source": [
  21. "## 1. Tensor\n",
  22. "\n",
  23. "Tensor is a key data structure in PyTorch and can be thought of as a high-dimensional array. A Tensor may be a single number (scalar), a 1-D array (vector), a 2-D array (matrix), or an array of even higher dimension. Tensors are similar to Numpy's ndarrays, but can additionally be accelerated on the GPU. Their interface closely resembles that of Numpy and Matlab. The following examples show basic Tensor usage."
  24. ]
  25. },
  26. {
  27. "cell_type": "code",
  28. "execution_count": 1,
  29. "metadata": {},
  30. "outputs": [],
  31. "source": [
  32. "from __future__ import print_function\n",
  33. "import torch as t"
  34. ]
  35. },
  36. {
  37. "cell_type": "code",
  38. "execution_count": 2,
  39. "metadata": {},
  40. "outputs": [
  41. {
  42. "data": {
  43. "text/plain": [
  44. "tensor([[0., 0., 0.],\n",
  45. " [0., 0., 0.],\n",
  46. " [0., 0., 0.],\n",
  47. " [0., 0., 0.],\n",
  48. " [0., 0., 0.]])"
  49. ]
  50. },
  51. "execution_count": 2,
  52. "metadata": {},
  53. "output_type": "execute_result"
  54. }
  55. ],
  56. "source": [
  57. "# Build a 5x3 matrix; this only allocates space without initializing it\n",
  58. "x = t.Tensor(5, 3) \n",
  59. "x"
  60. ]
  61. },
  62. {
  63. "cell_type": "code",
  64. "execution_count": 3,
  65. "metadata": {},
  66. "outputs": [
  67. {
  68. "data": {
  69. "text/plain": [
  70. "tensor([[0.3807, 0.4897, 0.0356],\n",
  71. " [0.6701, 0.0606, 0.1818],\n",
  72. " [0.8798, 0.7115, 0.8265],\n",
  73. " [0.4094, 0.2264, 0.2041],\n",
  74. " [0.9088, 0.9256, 0.3438]])"
  75. ]
  76. },
  77. "execution_count": 3,
  78. "metadata": {},
  79. "output_type": "execute_result"
  80. }
  81. ],
  82. "source": [
  83. "# Initialize a 2-D array with random values from the uniform distribution on [0, 1]\n",
  84. "x = t.rand(5, 3) \n",
  85. "x"
  86. ]
  87. },
  88. {
  89. "cell_type": "code",
  90. "execution_count": 4,
  91. "metadata": {},
  92. "outputs": [
  93. {
  94. "name": "stdout",
  95. "output_type": "stream",
  96. "text": [
  97. "torch.Size([5, 3])\n"
  98. ]
  99. },
  100. {
  101. "data": {
  102. "text/plain": [
  103. "(3, 3)"
  104. ]
  105. },
  106. "execution_count": 4,
  107. "metadata": {},
  108. "output_type": "execute_result"
  109. }
  110. ],
  111. "source": [
  112. "print(x.size()) # inspect the shape of x\n",
  113. "x.size()[1], x.size(1) # number of columns; the two forms are equivalent"
  114. ]
  115. },
  116. {
  117. "cell_type": "markdown",
  118. "metadata": {},
  119. "source": [
  120. "`torch.Size` is a subclass of tuple, so it supports all tuple operations, e.g. x.size()[0]."
  121. ]
  122. },
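Since it subclasses tuple, a `torch.Size` can be unpacked and compared like one. A quick sketch (shapes chosen to match the example above):

```python
import torch as t

x = t.rand(5, 3)
sz = x.size()

# torch.Size is a tuple subclass, so tuple operations all work
assert isinstance(sz, tuple)
rows, cols = sz          # tuple unpacking
print(rows, cols)        # 5 3
print(sz == (5, 3))      # comparison against a plain tuple
```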
  123. {
  124. "cell_type": "code",
  125. "execution_count": 5,
  126. "metadata": {},
  127. "outputs": [
  128. {
  129. "data": {
  130. "text/plain": [
  131. "tensor([[1.1361, 1.4054, 0.9468],\n",
  132. " [1.6410, 0.5193, 0.3720],\n",
  133. " [0.9482, 1.6716, 1.4168],\n",
  134. " [1.3925, 0.9253, 0.2908],\n",
  135. " [1.4907, 1.7178, 0.7246]])"
  136. ]
  137. },
  138. "execution_count": 5,
  139. "metadata": {},
  140. "output_type": "execute_result"
  141. }
  142. ],
  143. "source": [
  144. "y = t.rand(5, 3)\n",
  145. "# addition, first form\n",
  146. "x + y"
  147. ]
  148. },
  149. {
  150. "cell_type": "code",
  151. "execution_count": 6,
  152. "metadata": {},
  153. "outputs": [
  154. {
  155. "data": {
  156. "text/plain": [
  157. "tensor([[1.1361, 1.4054, 0.9468],\n",
  158. " [1.6410, 0.5193, 0.3720],\n",
  159. " [0.9482, 1.6716, 1.4168],\n",
  160. " [1.3925, 0.9253, 0.2908],\n",
  161. " [1.4907, 1.7178, 0.7246]])"
  162. ]
  163. },
  164. "execution_count": 6,
  165. "metadata": {},
  166. "output_type": "execute_result"
  167. }
  168. ],
  169. "source": [
  170. "# addition, second form\n",
  171. "t.add(x, y)"
  172. ]
  173. },
  174. {
  175. "cell_type": "code",
  176. "execution_count": 7,
  177. "metadata": {},
  178. "outputs": [
  179. {
  180. "data": {
  181. "text/plain": [
  182. "tensor([[1.1361, 1.4054, 0.9468],\n",
  183. " [1.6410, 0.5193, 0.3720],\n",
  184. " [0.9482, 1.6716, 1.4168],\n",
  185. " [1.3925, 0.9253, 0.2908],\n",
  186. " [1.4907, 1.7178, 0.7246]])"
  187. ]
  188. },
  189. "execution_count": 7,
  190. "metadata": {},
  191. "output_type": "execute_result"
  192. }
  193. ],
  194. "source": [
  195. "# addition, third form: write the result into a pre-allocated tensor\n",
  196. "result = t.Tensor(5, 3) # pre-allocate space\n",
  197. "t.add(x, y, out=result) # store the sum in result\n",
  198. "result"
  199. ]
  200. },
  201. {
  202. "cell_type": "code",
  203. "execution_count": 8,
  204. "metadata": {},
  205. "outputs": [
  206. {
  207. "name": "stdout",
  208. "output_type": "stream",
  209. "text": [
  210. "original y\n",
  211. "tensor([[0.7554, 0.9157, 0.9113],\n",
  212. " [0.9709, 0.4587, 0.1902],\n",
  213. " [0.0684, 0.9601, 0.5903],\n",
  214. " [0.9831, 0.6989, 0.0867],\n",
  215. " [0.5819, 0.7923, 0.3808]])\n",
  216. "after the first addition, y =\n",
  217. "tensor([[0.7554, 0.9157, 0.9113],\n",
  218. " [0.9709, 0.4587, 0.1902],\n",
  219. " [0.0684, 0.9601, 0.5903],\n",
  220. " [0.9831, 0.6989, 0.0867],\n",
  221. " [0.5819, 0.7923, 0.3808]])\n",
  222. "after the second addition, y =\n",
  223. "tensor([[1.1361, 1.4054, 0.9468],\n",
  224. " [1.6410, 0.5193, 0.3720],\n",
  225. " [0.9482, 1.6716, 1.4168],\n",
  226. " [1.3925, 0.9253, 0.2908],\n",
  227. " [1.4907, 1.7178, 0.7246]])\n"
  228. ]
  229. }
  230. ],
  231. "source": [
  232. "print('original y')\n",
  233. "print(y)\n",
  234. "\n",
  235. "print('after the first addition, y =')\n",
  236. "y.add(x) # ordinary addition; y is unchanged\n",
  237. "print(y)\n",
  238. "\n",
  239. "print('after the second addition, y =')\n",
  240. "y.add_(x) # in-place addition; y is modified\n",
  241. "print(y)"
  242. ]
  243. },
  244. {
  245. "cell_type": "markdown",
  246. "metadata": {},
  247. "source": [
  248. "Note: functions whose names end with an underscore **`_`** modify the Tensor in place. For example, `x.add_(y)` and `x.t_()` change `x`, whereas `x.add(y)` and `x.t()` return a new Tensor and leave `x` unchanged."
  249. ]
  250. },
  251. {
  252. "cell_type": "code",
  253. "execution_count": 9,
  254. "metadata": {},
  255. "outputs": [
  256. {
  257. "data": {
  258. "text/plain": [
  259. "tensor([0.4897, 0.0606, 0.7115, 0.2264, 0.9256])"
  260. ]
  261. },
  262. "execution_count": 9,
  263. "metadata": {},
  264. "output_type": "execute_result"
  265. }
  266. ],
  267. "source": [
  268. "# Tensor indexing and slicing work much like Numpy's\n",
  269. "x[:, 1]"
  270. ]
  271. },
  272. {
  273. "cell_type": "markdown",
  274. "metadata": {},
  275. "source": [
  276. "Tensor supports many more operations, including mathematical functions, linear algebra, selection, and slicing, with an interface that closely mirrors Numpy's. Chapter 3 covers these in detail.\n",
  277. "\n",
  278. "Converting between Tensors and Numpy arrays is easy and fast. For operations a Tensor does not support, one can convert it to a Numpy array, process it there, and convert it back."
  279. ]
  280. },
  281. {
  282. "cell_type": "code",
  283. "execution_count": 11,
  284. "metadata": {},
  285. "outputs": [
  286. {
  287. "data": {
  288. "text/plain": [
  289. "tensor([1., 1., 1., 1., 1.])"
  290. ]
  291. },
  292. "execution_count": 11,
  293. "metadata": {},
  294. "output_type": "execute_result"
  295. }
  296. ],
  297. "source": [
  298. "a = t.ones(5) # create a Tensor filled with ones\n",
  299. "a"
  300. ]
  301. },
  302. {
  303. "cell_type": "code",
  304. "execution_count": 12,
  305. "metadata": {},
  306. "outputs": [
  307. {
  308. "data": {
  309. "text/plain": [
  310. "array([1., 1., 1., 1., 1.], dtype=float32)"
  311. ]
  312. },
  313. "execution_count": 12,
  314. "metadata": {},
  315. "output_type": "execute_result"
  316. }
  317. ],
  318. "source": [
  319. "b = a.numpy() # Tensor -> Numpy\n",
  320. "b"
  321. ]
  322. },
  323. {
  324. "cell_type": "code",
  325. "execution_count": 13,
  326. "metadata": {},
  327. "outputs": [
  328. {
  329. "name": "stdout",
  330. "output_type": "stream",
  331. "text": [
  332. "[1. 1. 1. 1. 1.]\n",
  333. "tensor([1., 1., 1., 1., 1.], dtype=torch.float64)\n"
  334. ]
  335. }
  336. ],
  337. "source": [
  338. "import numpy as np\n",
  339. "a = np.ones(5)\n",
  340. "b = t.from_numpy(a) # Numpy->Tensor\n",
  341. "print(a)\n",
  342. "print(b) "
  343. ]
  344. },
  345. {
  346. "cell_type": "markdown",
  347. "metadata": {},
  348. "source": [
  349. "A Tensor and its Numpy counterpart share the same memory, so converting between them is fast and consumes almost no resources. This also means that modifying one modifies the other."
  350. ]
  351. },
  352. {
  353. "cell_type": "code",
  354. "execution_count": 14,
  355. "metadata": {},
  356. "outputs": [
  357. {
  358. "name": "stdout",
  359. "output_type": "stream",
  360. "text": [
  361. "[2. 2. 2. 2. 2.]\n",
  362. "tensor([2., 2., 2., 2., 2.], dtype=torch.float64)\n"
  363. ]
  364. }
  365. ],
  366. "source": [
  367. "b.add_(1) # functions ending in `_` modify the object in place\n",
  368. "print(a)\n",
  369. "print(b) # the Tensor and the Numpy array share memory"
  370. ]
  371. },
  372. {
  373. "cell_type": "markdown",
  374. "metadata": {},
  375. "source": [
  376. "A Tensor can be moved to the GPU with the `.cuda` method to benefit from GPU-accelerated computation."
  377. ]
  378. },
  379. {
  380. "cell_type": "code",
  381. "execution_count": 15,
  382. "metadata": {},
  383. "outputs": [
  384. {
  385. "name": "stdout",
  386. "output_type": "stream",
  387. "text": [
  388. "tensor([[1.5168, 1.8951, 0.9824],\n",
  389. " [2.3111, 0.5800, 0.5538],\n",
  390. " [1.8280, 2.3831, 2.2433],\n",
  391. " [1.8020, 1.1518, 0.4949],\n",
  392. " [2.3995, 2.6434, 1.0684]], device='cuda:0')\n"
  393. ]
  394. }
  395. ],
  396. "source": [
  397. "# On a machine without CUDA support, the following block is skipped\n",
  398. "if t.cuda.is_available():\n",
  399. "    x = x.cuda()\n",
  400. "    y = y.cuda()\n",
  401. "    x + y\n",
  402. "print(x+y)"
  403. ]
  404. },
  405. {
  406. "cell_type": "markdown",
  407. "metadata": {},
  408. "source": [
  409. "You may notice that the GPU does not speed things up much here: x and y are small, the computation is simple, and moving data from main memory to GPU memory carries extra overhead. The GPU's advantage only shows with large-scale data and complex computations.\n"
  410. ]
  411. },
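To make the overhead point concrete, here is a minimal timing sketch (the matrix size is chosen arbitrarily for illustration; the code falls back to the CPU when CUDA is unavailable, and `torch.cuda.synchronize` is needed because GPU kernels run asynchronously):

```python
import time
import torch as t

def timed_matmul(n):
    """Multiply two n x n random matrices and return (result, elapsed seconds)."""
    a, b = t.rand(n, n), t.rand(n, n)
    if t.cuda.is_available():
        a, b = a.cuda(), b.cuda()
    start = time.time()
    c = a @ b
    if t.cuda.is_available():
        t.cuda.synchronize()  # wait for the asynchronous GPU kernel to finish
    return c, time.time() - start

c, elapsed = timed_matmul(512)
print(elapsed)
```

Timing small inputs like this mostly measures launch and transfer overhead; only at much larger sizes does the GPU pull ahead.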
  412. {
  413. "cell_type": "markdown",
  414. "metadata": {},
  415. "source": [
  416. "## 2. Autograd: automatic differentiation\n",
  417. "\n",
  418. "Deep learning algorithms are, at their core, trained by computing derivatives via backpropagation, and PyTorch's **`Autograd`** module implements exactly this. Autograd can automatically differentiate all operations on Tensors, removing the tedious process of computing derivatives by hand.\n",
  419. " \n",
  420. "`autograd.Variable` is the core class of Autograd. It is a thin wrapper around a Tensor and supports nearly all Tensor operations. Once a Tensor is wrapped in a Variable, calling its `.backward` method runs backpropagation and computes all gradients automatically. The structure of a Variable is shown in Figure 2-6.\n",
  421. "\n",
  422. "\n",
  423. "![Figure 2-6: structure of a Variable](imgs/autograd_Variable.svg)\n",
  424. "\n",
  425. "\n",
  426. "A Variable has three main attributes:\n",
  427. "- `data`: the Tensor held by the Variable\n",
  428. "- `grad`: the gradient of `data`; `grad` is itself a Variable rather than a Tensor, and has the same shape as `data`\n",
  429. "- `grad_fn`: points to a `Function` object, which is used during backpropagation to compute the gradients of the inputs; the details are covered in the next chapter."
  430. ]
  431. },
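The three attributes can be inspected directly. A minimal sketch (note that `Variable` follows the pre-0.4 API used in this chapter; in recent PyTorch the wrapper simply returns a plain Tensor with `requires_grad` set, and the same attributes still exist):

```python
import torch as t
from torch.autograd import Variable

x = Variable(t.ones(2, 2), requires_grad=True)
y = (x * 2).sum()   # y = 2 * (x[0][0] + x[0][1] + x[1][0] + x[1][1])
y.backward()

print(x.data)       # the wrapped Tensor
print(x.grad)       # dy/dx: every entry is 2
print(y.grad_fn)    # the Function that produced y
```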
  432. {
  433. "cell_type": "code",
  434. "execution_count": 16,
  435. "metadata": {},
  436. "outputs": [],
  437. "source": [
  438. "from torch.autograd import Variable"
  439. ]
  440. },
  441. {
  442. "cell_type": "code",
  443. "execution_count": 17,
  444. "metadata": {
  445. "scrolled": true
  446. },
  447. "outputs": [
  448. {
  449. "data": {
  450. "text/plain": [
  451. "tensor([[1., 1.],\n",
  452. " [1., 1.]], requires_grad=True)"
  453. ]
  454. },
  455. "execution_count": 17,
  456. "metadata": {},
  457. "output_type": "execute_result"
  458. }
  459. ],
  460. "source": [
  461. "# create a Variable from a Tensor\n",
  462. "x = Variable(t.ones(2, 2), requires_grad = True)\n",
  463. "x"
  464. ]
  465. },
  466. {
  467. "cell_type": "code",
  468. "execution_count": 18,
  469. "metadata": {
  470. "scrolled": true
  471. },
  472. "outputs": [
  473. {
  474. "data": {
  475. "text/plain": [
  476. "tensor(4., grad_fn=<SumBackward0>)"
  477. ]
  478. },
  479. "execution_count": 18,
  480. "metadata": {},
  481. "output_type": "execute_result"
  482. }
  483. ],
  484. "source": [
  485. "y = x.sum()\n",
  486. "y"
  487. ]
  488. },
  489. {
  490. "cell_type": "code",
  491. "execution_count": 19,
  492. "metadata": {},
  493. "outputs": [
  494. {
  495. "data": {
  496. "text/plain": [
  497. "<SumBackward0 at 0x7f0158d2f198>"
  498. ]
  499. },
  500. "execution_count": 19,
  501. "metadata": {},
  502. "output_type": "execute_result"
  503. }
  504. ],
  505. "source": [
  506. "y.grad_fn"
  507. ]
  508. },
  509. {
  510. "cell_type": "code",
  511. "execution_count": 20,
  512. "metadata": {},
  513. "outputs": [],
  514. "source": [
  515. "y.backward() # backpropagate and compute gradients"
  516. ]
  517. },
  518. {
  519. "cell_type": "code",
  520. "execution_count": 21,
  521. "metadata": {},
  522. "outputs": [
  523. {
  524. "data": {
  525. "text/plain": [
  526. "tensor([[1., 1.],\n",
  527. " [1., 1.]])"
  528. ]
  529. },
  530. "execution_count": 21,
  531. "metadata": {},
  532. "output_type": "execute_result"
  533. }
  534. ],
  535. "source": [
  536. "# y = x.sum() = (x[0][0] + x[0][1] + x[1][0] + x[1][1])\n",
  537. "# so the gradient of every entry is 1\n",
  538. "x.grad "
  539. ]
  540. },
  541. {
  542. "cell_type": "markdown",
  543. "metadata": {},
  544. "source": [
  545. "Note: `grad` is accumulated during backpropagation. **Every run of backpropagation adds to the existing gradients, so the gradients must be zeroed before each backward pass.**"
  546. ]
  547. },
  548. {
  549. "cell_type": "code",
  550. "execution_count": 22,
  551. "metadata": {},
  552. "outputs": [
  553. {
  554. "data": {
  555. "text/plain": [
  556. "tensor([[2., 2.],\n",
  557. " [2., 2.]])"
  558. ]
  559. },
  560. "execution_count": 22,
  561. "metadata": {},
  562. "output_type": "execute_result"
  563. }
  564. ],
  565. "source": [
  566. "y.backward()\n",
  567. "x.grad"
  568. ]
  569. },
  570. {
  571. "cell_type": "code",
  572. "execution_count": 23,
  573. "metadata": {
  574. "scrolled": true
  575. },
  576. "outputs": [
  577. {
  578. "data": {
  579. "text/plain": [
  580. "tensor([[3., 3.],\n",
  581. " [3., 3.]])"
  582. ]
  583. },
  584. "execution_count": 23,
  585. "metadata": {},
  586. "output_type": "execute_result"
  587. }
  588. ],
  589. "source": [
  590. "y.backward()\n",
  591. "x.grad"
  592. ]
  593. },
  594. {
  595. "cell_type": "code",
  596. "execution_count": 25,
  597. "metadata": {},
  598. "outputs": [
  599. {
  600. "data": {
  601. "text/plain": [
  602. "tensor([[0., 0.],\n",
  603. " [0., 0.]])"
  604. ]
  605. },
  606. "execution_count": 25,
  607. "metadata": {},
  608. "output_type": "execute_result"
  609. }
  610. ],
  611. "source": [
  612. "# functions ending in an underscore operate in place, like add_\n",
  613. "x.grad.data.zero_()"
  614. ]
  615. },
  616. {
  617. "cell_type": "code",
  618. "execution_count": 26,
  619. "metadata": {},
  620. "outputs": [
  621. {
  622. "data": {
  623. "text/plain": [
  624. "tensor([[1., 1.],\n",
  625. " [1., 1.]])"
  626. ]
  627. },
  628. "execution_count": 26,
  629. "metadata": {},
  630. "output_type": "execute_result"
  631. }
  632. ],
  633. "source": [
  634. "y.backward()\n",
  635. "x.grad"
  636. ]
  637. },
  638. {
  639. "cell_type": "markdown",
  640. "metadata": {},
  641. "source": [
  642. "Variable and Tensor share a nearly identical interface, so in practice they can be used interchangeably."
  643. ]
  644. },
  645. {
  646. "cell_type": "code",
  647. "execution_count": 28,
  648. "metadata": {},
  649. "outputs": [
  650. {
  651. "name": "stdout",
  652. "output_type": "stream",
  653. "text": [
  654. "tensor([[0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n",
  655. " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n",
  656. " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n",
  657. " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403]])\n"
  658. ]
  659. },
  660. {
  661. "data": {
  662. "text/plain": [
  663. "tensor([[0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n",
  664. " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n",
  665. " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403],\n",
  666. " [0.5403, 0.5403, 0.5403, 0.5403, 0.5403]])"
  667. ]
  668. },
  669. "execution_count": 28,
  670. "metadata": {},
  671. "output_type": "execute_result"
  672. }
  673. ],
  674. "source": [
  675. "x = Variable(t.ones(4,5))\n",
  676. "y = t.cos(x)\n",
  677. "x_tensor_cos = t.cos(x.data)\n",
  678. "print(y)\n",
  679. "x_tensor_cos"
  680. ]
  681. },
  682. {
  683. "cell_type": "markdown",
  684. "metadata": {},
  685. "source": [
  686. "## 3. Neural networks\n",
  687. "\n",
  688. "Autograd implements backpropagation, but writing deep learning code directly on top of it is often still somewhat cumbersome. torch.nn is a modular interface designed specifically for neural networks. Built on Autograd, nn can be used to define and run networks. nn.Module is the most important class in nn; it can be seen as a wrapper around a network, containing the definitions of the network's layers and a forward method: calling forward(input) returns the result of the forward pass. Let us see how `nn.Module` is used by implementing LeNet, one of the earliest convolutional neural networks. Its architecture is shown in Figure 2-7.\n",
  689. "\n",
  690. "![Figure 2-7: the LeNet architecture](imgs/nn_lenet.png)\n",
  691. "\n",
  692. "This is a basic feed-forward network: it takes an input, passes it through the layers one after another, and produces an output.\n",
  693. "\n",
  694. "### 3.1 Defining the network\n",
  695. "\n",
  696. "To define a network, subclass `nn.Module` and implement its forward method, placing the layers that have learnable parameters in the constructor `__init__`. A layer without learnable parameters (such as ReLU) may be placed in the constructor or left out; the recommendation is to leave it out and use `nn.functional` inside forward instead."
  697. ]
  698. },
  699. {
  700. "cell_type": "code",
  701. "execution_count": 25,
  702. "metadata": {},
  703. "outputs": [
  704. {
  705. "name": "stdout",
  706. "output_type": "stream",
  707. "text": [
  708. "Net(\n",
  709. " (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
  710. " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
  711. " (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
  712. " (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
  713. " (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
  714. ")\n"
  715. ]
  716. }
  717. ],
  718. "source": [
  719. "import torch.nn as nn\n",
  720. "import torch.nn.functional as F\n",
  721. "\n",
  722. "class Net(nn.Module):\n",
  723. "    def __init__(self):\n",
  724. "        # a subclass of nn.Module must call the parent constructor\n",
  725. "        # the line below is equivalent to nn.Module.__init__(self)\n",
  726. "        super(Net, self).__init__()\n",
  727. "        \n",
  728. "        # conv layer: 1 input channel, 6 output channels, 5x5 kernel\n",
  729. "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
  730. "        # conv layer\n",
  731. "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
  732. "        # affine / fully connected layers, y = Wx + b\n",
  733. "        self.fc1 = nn.Linear(16*5*5, 120)\n",
  734. "        self.fc2 = nn.Linear(120, 84)\n",
  735. "        self.fc3 = nn.Linear(84, 10)\n",
  736. "\n",
  737. "    def forward(self, x):\n",
  738. "        # convolution -> activation -> pooling\n",
  739. "        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
  740. "        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
  741. "        # reshape; '-1' means the size is inferred\n",
  742. "        x = x.view(x.size()[0], -1)\n",
  743. "        x = F.relu(self.fc1(x))\n",
  744. "        x = F.relu(self.fc2(x))\n",
  745. "        x = self.fc3(x)\n",
  746. "        return x\n",
  747. "\n",
  748. "net = Net()\n",
  749. "print(net)"
  750. ]
  751. },
  752. {
  753. "cell_type": "markdown",
  754. "metadata": {},
  755. "source": [
  756. "As long as forward is defined in an nn.Module subclass, the backward function is implemented automatically (via `Autograd`). Inside `forward` you can use any function supported by Variable, as well as if statements, for loops, print, logging, and other ordinary Python constructs; the code reads like standard Python.\n",
  757. "\n",
  758. "The learnable parameters of the network are returned by `net.parameters()`; `net.named_parameters` returns both the parameters and their names."
  759. ]
  760. },
  761. {
  762. "cell_type": "code",
  763. "execution_count": 26,
  764. "metadata": {},
  765. "outputs": [
  766. {
  767. "name": "stdout",
  768. "output_type": "stream",
  769. "text": [
  770. "10\n"
  771. ]
  772. }
  773. ],
  774. "source": [
  775. "params = list(net.parameters())\n",
  776. "print(len(params))"
  777. ]
  778. },
  779. {
  780. "cell_type": "code",
  781. "execution_count": 27,
  782. "metadata": {},
  783. "outputs": [
  784. {
  785. "name": "stdout",
  786. "output_type": "stream",
  787. "text": [
  788. "conv1.weight : torch.Size([6, 1, 5, 5])\n",
  789. "conv1.bias : torch.Size([6])\n",
  790. "conv2.weight : torch.Size([16, 6, 5, 5])\n",
  791. "conv2.bias : torch.Size([16])\n",
  792. "fc1.weight : torch.Size([120, 400])\n",
  793. "fc1.bias : torch.Size([120])\n",
  794. "fc2.weight : torch.Size([84, 120])\n",
  795. "fc2.bias : torch.Size([84])\n",
  796. "fc3.weight : torch.Size([10, 84])\n",
  797. "fc3.bias : torch.Size([10])\n"
  798. ]
  799. }
  800. ],
  801. "source": [
  802. "for name, parameters in net.named_parameters():\n",
  803. "    print(name, ':', parameters.size())"
  804. ]
  805. },
  806. {
  807. "cell_type": "markdown",
  808. "metadata": {},
  809. "source": [
  810. "The inputs and outputs of forward are Variables. Only Variables support automatic differentiation (plain Tensors do not), so input Tensors must be wrapped in Variables before being fed to the network."
  811. ]
  812. },
  813. {
  814. "cell_type": "code",
  815. "execution_count": 28,
  816. "metadata": {
  817. "scrolled": true
  818. },
  819. "outputs": [
  820. {
  821. "data": {
  822. "text/plain": [
  823. "torch.Size([1, 10])"
  824. ]
  825. },
  826. "execution_count": 28,
  827. "metadata": {},
  828. "output_type": "execute_result"
  829. }
  830. ],
  831. "source": [
  832. "input = Variable(t.randn(1, 1, 32, 32))\n",
  833. "out = net(input)\n",
  834. "out.size()"
  835. ]
  836. },
  837. {
  838. "cell_type": "code",
  839. "execution_count": 29,
  840. "metadata": {},
  841. "outputs": [],
  842. "source": [
  843. "net.zero_grad() # zero the gradients of all parameters\n",
  844. "out.backward(Variable(t.ones(1, 10))) # backpropagate"
  845. ]
  846. },
  847. {
  848. "cell_type": "markdown",
  849. "metadata": {},
  850. "source": [
  851. "Note that torch.nn only supports mini-batches: a single sample cannot be fed in on its own; the input must always be a batch. To run a single sample, use `input.unsqueeze(0)` to create a batch of size 1. For example, the input to `nn.Conv2d` must be 4-dimensional, of shape $nSamples \\times nChannels \\times Height \\times Width$; setting nSamples to 1 gives $1 \\times nChannels \\times Height \\times Width$."
  852. ]
  853. },
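A minimal sketch of turning a single sample into a batch of one (the shape is chosen to match the LeNet input above):

```python
import torch as t

img = t.randn(1, 32, 32)     # one 1-channel 32x32 image: C x H x W
batch = img.unsqueeze(0)     # prepend a batch dimension: 1 x C x H x W
print(img.shape, batch.shape)
```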
  854. {
  855. "cell_type": "markdown",
  856. "metadata": {},
  857. "source": [
  858. "### 3.2 Loss functions\n",
  859. "\n",
  860. "nn implements most of the loss functions used in neural networks, e.g. nn.MSELoss for the mean squared error and nn.CrossEntropyLoss for the cross-entropy loss."
  861. ]
  862. },
  863. {
  864. "cell_type": "code",
  865. "execution_count": 30,
  866. "metadata": {
  867. "scrolled": true
  868. },
  869. "outputs": [
  870. {
  871. "data": {
  872. "text/plain": [
  873. "tensor(28.6268, grad_fn=<MseLossBackward>)"
  874. ]
  875. },
  876. "execution_count": 30,
  877. "metadata": {},
  878. "output_type": "execute_result"
  879. }
  880. ],
  881. "source": [
  882. "\n",
  883. "output = net(input)\n",
  884. "target = Variable(t.arange(0,10).float().unsqueeze(0)) \n",
  885. "criterion = nn.MSELoss()\n",
  886. "loss = criterion(output, target)\n",
  887. "loss"
  888. ]
  889. },
  890. {
  891. "cell_type": "markdown",
  892. "metadata": {},
  893. "source": [
  894. "Tracing the loss backwards through the graph (via the `grad_fn` attribute) reveals the following computation graph:\n",
  895. "\n",
  896. "```\n",
  897. "input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d \n",
  898. "      -> view -> linear -> relu -> linear -> relu -> linear \n",
  899. "      -> MSELoss\n",
  900. "      -> loss\n",
  901. "```\n",
  902. "\n",
  903. "When `loss.backward()` is called, this graph is built dynamically and differentiated automatically, i.e. the derivatives of the graph's parameters (Parameter) are computed without further effort."
  904. ]
  905. },
  906. {
  907. "cell_type": "code",
  908. "execution_count": 31,
  909. "metadata": {},
  910. "outputs": [
  911. {
  912. "name": "stdout",
  913. "output_type": "stream",
  914. "text": [
  915. "conv1.bias gradient before backpropagation\n",
  916. "tensor([0., 0., 0., 0., 0., 0.])\n",
  917. "conv1.bias gradient after backpropagation\n",
  918. "tensor([-0.0368, 0.0240, 0.0169, 0.0118, -0.0122, -0.0259])\n"
  919. ]
  920. }
  921. ],
  922. "source": [
  923. "# run .backward and inspect grad before and after the call\n",
  924. "net.zero_grad() # zero the gradients of all learnable parameters in net\n",
  925. "print('conv1.bias gradient before backpropagation')\n",
  926. "print(net.conv1.bias.grad)\n",
  927. "loss.backward()\n",
  928. "print('conv1.bias gradient after backpropagation')\n",
  929. "print(net.conv1.bias.grad)"
  930. ]
  931. },
  932. {
  933. "cell_type": "markdown",
  934. "metadata": {},
  935. "source": [
  936. "### 3.3 Optimizers"
  937. ]
  938. },
  939. {
  940. "cell_type": "markdown",
  941. "metadata": {},
  942. "source": [
  943. "After backpropagation has computed the gradients of all parameters, an optimization method is still needed to update the network's weights. For example, the update rule of stochastic gradient descent (SGD) is:\n",
  944. "```\n",
  945. "weight = weight - learning_rate * gradient\n",
  946. "```\n",
  947. "\n",
  948. "Implemented by hand, this looks like:\n",
  949. "\n",
  950. "```python\n",
  951. "learning_rate = 0.01\n",
  952. "for f in net.parameters():\n",
  953. "    f.data.sub_(f.grad.data * learning_rate) # in-place subtraction\n",
  954. "```\n",
  955. "\n",
  956. "`torch.optim` implements the vast majority of optimization methods used in deep learning, such as RMSProp, Adam, and SGD, and is more convenient to use, so this code rarely needs to be written by hand."
  957. ]
  958. },
  959. {
  960. "cell_type": "code",
  961. "execution_count": 32,
  962. "metadata": {},
  963. "outputs": [],
  964. "source": [
  965. "import torch.optim as optim\n",
  966. "# create an optimizer, specifying the parameters to adjust and the learning rate\n",
  967. "optimizer = optim.SGD(net.parameters(), lr = 0.01)\n",
  968. "\n",
  969. "# during training:\n",
  970. "# first zero the gradients (same effect as net.zero_grad())\n",
  971. "optimizer.zero_grad() \n",
  972. "\n",
  973. "# compute the loss\n",
  974. "output = net(input)\n",
  975. "loss = criterion(output, target)\n",
  976. "\n",
  977. "# backpropagate\n",
  978. "loss.backward()\n",
  979. "\n",
  980. "# update the parameters\n",
  981. "optimizer.step()"
  982. ]
  983. },
  984. {
  985. "cell_type": "markdown",
  986. "metadata": {},
  987. "source": [
  988. "\n",
  989. "\n",
  990. "### 3.4 Data loading and preprocessing\n",
  991. "\n",
  992. "Loading and preprocessing data for deep learning is complex and tedious, but PyTorch provides tools that greatly simplify and speed up this pipeline. For commonly used datasets, PyTorch also provides ready-made interfaces for quick access, mostly housed in torchvision.\n",
  993. "\n",
  994. "`torchvision` implements loaders for common image datasets such as Imagenet, CIFAR10, and MNIST, along with common data transformations, which makes data loading much easier and the code reusable.\n"
  995. ]
  996. },
  997. {
  998. "cell_type": "markdown",
  999. "metadata": {},
  1000. "source": [
"## 4. A first attempt: CIFAR-10 classification\n",
"\n",
"Let us now try to classify the CIFAR-10 dataset. The steps are: \n",
"\n",
"1. Load and preprocess the CIFAR-10 dataset with torchvision\n",
"2. Define the network\n",
"3. Define the loss function and the optimizer\n",
"4. Train the network and update its parameters\n",
"5. Test the network\n",
"\n",
"### 4.1 Loading and preprocessing CIFAR-10\n",
"\n",
"CIFAR-10[^3] is a widely used dataset of color images with 10 classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'. Each image is $3\\times32\\times32$, i.e. a 3-channel color image with a resolution of $32\\times32$.\n",
"\n",
"[^3]: http://www.cs.toronto.edu/~kriz/cifar.html"
  1016. ]
  1017. },
  1018. {
  1019. "cell_type": "code",
  1020. "execution_count": 3,
  1021. "metadata": {},
  1022. "outputs": [],
  1023. "source": [
  1024. "import torch as t\n",
  1025. "import torchvision as tv\n",
  1026. "import torchvision.transforms as transforms\n",
  1027. "from torchvision.transforms import ToPILImage\n",
"show = ToPILImage() # converts a Tensor to a PIL Image for easy visualization"
  1029. ]
  1030. },
  1031. {
  1032. "cell_type": "code",
  1033. "execution_count": 4,
  1034. "metadata": {},
  1035. "outputs": [
  1036. {
  1037. "name": "stdout",
  1038. "output_type": "stream",
  1039. "text": [
  1040. "Files already downloaded and verified\n",
  1041. "Files already downloaded and verified\n"
  1042. ]
  1043. }
  1044. ],
  1045. "source": [
"# The first time this runs, torchvision downloads the CIFAR-10 dataset\n",
"# (about 100 MB), which takes a while.\n",
"# If CIFAR-10 is already downloaded, point the root argument at it.\n",
"\n",
"# define the preprocessing applied to the data\n",
"transform = transforms.Compose([\n",
"    transforms.ToTensor(), # convert to Tensor\n",
"    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # normalize\n",
"    ])\n",
"\n",
"# training set\n",
"trainset = tv.datasets.CIFAR10(\n",
"    root='../data/', \n",
"    train=True, \n",
"    download=True,\n",
"    transform=transform)\n",
"\n",
"trainloader = t.utils.data.DataLoader(\n",
"    trainset, \n",
"    batch_size=4,\n",
"    shuffle=True, \n",
"    num_workers=2)\n",
"\n",
"# test set\n",
"testset = tv.datasets.CIFAR10(\n",
"    '../data/',\n",
"    train=False, \n",
"    download=True, \n",
"    transform=transform)\n",
"\n",
"testloader = t.utils.data.DataLoader(\n",
"    testset,\n",
"    batch_size=4, \n",
"    shuffle=False,\n",
"    num_workers=2)\n",
"\n",
"classes = ('plane', 'car', 'bird', 'cat',\n",
"           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
  1084. ]
  1085. },
  1086. {
  1087. "cell_type": "markdown",
  1088. "metadata": {},
  1089. "source": [
"A Dataset object represents a dataset: it can be indexed, and indexing returns a (data, label) pair."
  1091. ]
  1092. },
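Because a Dataset only needs indexing and a length, wrapping your own data takes just a few lines. Below is a minimal sketch of the interface; the `RandomPairs` class and its random tensors are invented for illustration and are not part of torchvision:

```python
import torch as t
from torch.utils.data import Dataset

class RandomPairs(Dataset):
    """A toy dataset: indexing returns a (data, label) pair."""
    def __init__(self, n=8):
        self.data = t.randn(n, 3, 32, 32)  # fake 3x32x32 "images"
        self.labels = t.arange(n) % 10     # fake class labels 0..9

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    def __len__(self):
        return len(self.data)

ds = RandomPairs()
data, label = ds[2]   # index it like any sequence
print(data.shape, int(label))
```

Any object with these two methods can be handed to a DataLoader, exactly like the CIFAR10 dataset above.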
  1093. {
  1094. "cell_type": "code",
  1095. "execution_count": 5,
  1096. "metadata": {},
  1097. "outputs": [
  1098. {
  1099. "name": "stdout",
  1100. "output_type": "stream",
  1101. "text": [
  1102. "ship\n"
  1103. ]
  1104. },
  1105. {
  1106. "data": {
  1107. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAALVElEQVR4nO1cW3MVxxGe2d1zk46EhEASlpCEBaEo43K5UqmUKz8jpCo/MQ/Jj0j5JQllLGIw2NxsK4iLERJH17PXPEz316OZxdLoeb4XWruzvbPDfNOX6Tn64cuRUkopVde1OomqEbmsaqcZhIKbFTVJVVV5jRsWRGdRlaRc8d2Gbmtu3/ADTdM4Ql4m0tXavYs+NI1mVepjX9pUckUXlXMX7RMVcWbEwQpAppkCEACttMgsJiw1uOK1EYHvJWhtvQWqUhY0s0FrJqbGYy5V00S6Bwjf5RpdSZKU/vaorRrpldau2oRfFGdWAOJgBSBLZMLy5OS/Ey1DCakRXqAZBDZJunEayRVrkguNEpe3iSwO3DkxWCCv9R0ed1J0hvsO9uFtYLTSNg3d5QhsjTMrAHGwApBZfHJnvj2zwYtaYTLz5GzQ5qQiS2w8U6tsHmm3WeL1wmK9e+VEJ+C7ijlkN5VvZQluJM5HqTbvF0Y6zqwAxMEKQBysAGRwWH0XO7GMqCxVWNr4AawFvivQ1O6CaLvdWpNLXXNEnYgn7boC0r0Ga6ulSvrAjgVPg6pkj58vQUOt8S2W68APIhiIHvx5EAcrAJk18wkWc+ygF3EsLnLuSey9qwNudC0+hJU5Ep/Ddf0tEvqeDFwHiztYHBpXgzg0csv3VERV43EzevDnQRysAEg+S8JgvldbbEgSmBKka+mWzHyhFWJyNIJOC4hs4VJLPotttJDPzZbZy0cj7V2zqFJXuZh7lmp7zqAPSMYpV4g4HXGwAiCBdIt32tjWENlk7d/FE87f2mOHnSxqGtCQA1rtPqi8KxBOmjD0wPOQvTy45fm2JA1ASY3eeHyMOB1xsAKQiY/n08q68BteX4vQss9DQm3lfxEb4lErGhVb62mSXLXVP5ek1h0v2e0pP/HtkuJ2I9w4swIQBysAWcPTrvIqAM6I1CcdiMIzuYBHmGR4MEFWl1mQ8pNlUzhv0QolCOzxCotV3fD/OvIwKLNgVbVmd9oLRWtRLouD1q6vHGdWAOJgBUBI0VKMcDYknt8olUaNn8axd114ejeunRKC1O5mkrWDeyI6dL7DN82+oYQ1TKz8rV8A4Wd7Ik5HHKwAZFq5+/26JegTaEmQIg/jjnjLhEey0dr8hJXBbgjebO3XukkebM3aqkBkLQkc3PXKJrxMqd3hunYXkxQ+s/tVER9HHKwAZLLm+5nEVrRUyzGJavdOW+7UripgS4dSPNGIJCZtKaZI47CG1OonzC7KpPx9B+isGrwOHRUaViBpTZ5qmqb8FRFnRhysAMTBCoCUSYK24H9b4liWNvHFUZXfkrjlhYP5n1mrQ8aBcCUmPOFuUbMclf68q4QOpFYVJ+JorINwgLAdlXiOQtW6THsZ7RhInwdxsAJgbbJaRbzm38r2tgE/Nhan2fWMMZNB0IP9D9C0vf3OCEXB2StW1ZuYcl47nBxSr/hQTpL15TO482VJroZfoCDuiFecUdvBAF/WnBqLZ3fOgzhYAcjatm1cwYb2JrNVe8R/s+1Dk4QLfp/98BCq7t69a4TxeGyEPCc+Fg1Zyi++/NIIn9++bQTQcHK2B1U4QqekNApW3k1tV6UbFdi5A1hPmGZvszXiDIiDFYAsscp86N/W3BPguay1xmRmTR5/Gy6xXbh0ERdXlz+hFzEdtt+/N0JeEw0zVvr4+wdGuH79Bt868Qb+CPSKbTrTFoF3guJcvlLZpcbMOkmXt5Q2RpyGOFgBaNndCd/fYQ1ylo6Jyf8X+TGZuV5X3njzxroRpqbIBf3mm3tG6A5njXBwdER9YtZfnL3g99M6g4cKRWTZvEIoT0r8PLhStXcsPs6sAMTBCoCYlMqLqqQw1vb6pJCBnTdVOQ+CAjjR8fbtKyN8d/9b6Dw+PjbC5i+/GCHNiKTXrpOw9XLLCF999SfuFP
WqKqQeIvUOi9f8OR22ffiZCvldB8mQW7UOqPzDOHBqO86sAMTBCkBWeT+XIlV6lt2QX3GQ/U9qX1aF00YOjLEvOneZrJvqiDVMFQV3U3Nz1GyOXNa8yo2w9YpoOL+wyMq5JMi22rUwivopd9wtnFq5YeOJPSfv5EyTRGsYjjhYAcgQOllzklDVYiPQLFMwgkix8hlLsaL0f3BhetoIPzx5YoT5K8vQeXBwYISpGaLh/v6+EV5vEfuevPjJCH/7+z+M8Jc7fzVCryuZUuvnlOhKXoBE2hFg2cUVtew+fNESzWKtwzkQBysAcbACkB0XpXNJ9kUsM4/cccXubJmT/52mXW5BQ//zTz8b4e3bX42wf3hohPxEJRScD96w6Q2MsLh01QhXr103wmBIy193YpJ7YvWZ/Ymyoe6N+St6aYe/y1udJeQQVVhwk9oNSOLMCkAcrABk9+7/10jwtuEldKzcU6/DfnNN/vrkgPzvJCEaNglduXdvwwgbG/eNsLu3Z4SF1TXoXF4mN+Lp06dGmGNXfmVlxQjrN24aYW2Nkl9vft02wrgQHoJZ45w2ipBTyziQxg6TtfdLRCtKey1q4SZpcC9EfBxxsAKQvf+wa6TBgCxRxkmlzLKGmoPJNSbIzDTlgvsDqkJ49uJ/dGuGMr/r69eMsDMi13x6fhE6//Xv/xhhc3PTCCWnqO7c+bMRZmcptH786LER3rwmGua2OWQTdshmt9MhIwinPpX9Hg6k4dNbNMTeKtYlv4Y64nTEwQpABpNSHNAEnp2l3FOv30W7hUt0scPcHI12jbC3T/Gw4jNqv7tJlmtpiUi3u0c03DnMofOPf/i9Eb74/DNqtks6+/zqmRnyRY8OaJvnYH/EfWeiWdVRiIgrzohhdwe0bbyAv2yj4W9UL0WcjjhYAcgSnszb22Rl9njCPzvaQbseVwpcmiVepFLaQCPe53I9mNGq5NxQ2bJBsrJ8hVRxVT4MMRzjfEz28ZPFy0bY3KRUV29yILqYUKMRkTTPmYZcnIsMV8qVvzCCRdFCQ+tcbsxnhSMOVgCyhmfdxUs0z1EOW42lWLbhY9mDASVzUQePCp5KUZuDQ7KPBVfyjXMOPGsxYTnzGDSE3cmYKSknWLocga6vXnUeV0qV7HlWnDhqeM8JDNOpe1K8kjNDkjgqeenAmlDHFM05EAcrABkog1mHdAccQqWULjkvyns5OVfN9jPKzHSEO8je8OOY+aX1Yww1NjvlPdyM+ctv2d+jDmRMzP60dC/nOG5+boaUF2TT9yoUPXT4HbKBRVcSoXQxphdVXAQMWxlnVgDiYAUgO2YaznEyBDwBv5RSyyuU1ex1aTI/evS9EV5uvTHCYEhbCUh4dlLyG3WXnUxl5yS50LxyDWuGA6mcGtIDEsbwNot9UcQBYMo1VDOTE0Y4PqRDL3VO2VosF3ND3h9ZmIcq1Dq8eU0PVtXgRHcjzoI4WAGIgxWAbOEy0fWIyzQS9iFu3/4M7VaWKTO1NyLmT0xQNvnwmIz00xfPjfDkx2eknVUhRzbJJ+GU5a9P8PrS4aieM2MSig/6tHCguPKoOIYq/KbTaIeC//l5itKHvJIOp+gtV68sGGHpCn17t2M5NLwX++7dB/5k+sA4swIQBysAGfI+MMljrtPf2JDK4offkYBULJJWq2trRrh165YRUGb14AEduHn+nBi6s7MLnb0eu/68EwNh0KFb3Q7Fz91u12lTWbWNSUqdQeHFCgf8K4urRri6St7PBU6E9bFzbKnCNm2vR+m50ZAS7nFmBSAOVgAyJGum+QDN+JBouPVqE+0O93aNAIp1mBf//PprI3Q9WoE7S0tLRsjzH6ETaazhkExkxldqjl1hm0bcAcTkCJ6VUkfHtIZ8yiVKO2wWYaw7XVI+9SkRM0mQ/hYavt+mF/X7ZD3n5siUx5kVgDhYAfg/pQ4eZ65sAxcAAAAASUVORK5CYII=\n",
  1108. "text/plain": [
  1109. "<PIL.Image.Image image mode=RGB size=100x100 at 0x7F1EC53B6588>"
  1110. ]
  1111. },
  1112. "execution_count": 5,
  1113. "metadata": {},
  1114. "output_type": "execute_result"
  1115. }
  1116. ],
  1117. "source": [
"(data, label) = trainset[100]\n",
"print(classes[label])\n",
"\n",
"# (data + 1) / 2 undoes the normalization applied above\n",
"show((data + 1) / 2).resize((100, 100))"
  1123. ]
  1124. },
  1125. {
  1126. "cell_type": "markdown",
  1127. "metadata": {},
  1128. "source": [
"A Dataloader is an iterable object that collates the individual samples returned by the dataset into batches, and provides multi-worker loading and data shuffling. Once every sample of the dataset has been visited, one iteration over the Dataloader is complete."
  1130. ]
  1131. },
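To make the batching concrete, here is a self-contained sketch that swaps CIFAR-10 for a tiny synthetic `TensorDataset` (the sizes and dummy labels are made up for illustration); one full pass over the loader is one epoch:

```python
import torch as t
from torch.utils.data import TensorDataset, DataLoader

# 20 fake "images" with dummy labels, standing in for trainset
dataset = TensorDataset(t.randn(20, 3, 32, 32), t.zeros(20, dtype=t.long))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

n_batches = 0
for images, labels in loader:   # each iteration yields one collated batch
    n_batches += 1
print(n_batches, images.shape)  # 20 samples / batch_size 4 -> 5 batches
```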
  1132. {
  1133. "cell_type": "code",
  1134. "execution_count": 6,
  1135. "metadata": {},
  1136. "outputs": [
  1137. {
  1138. "name": "stdout",
  1139. "output_type": "stream",
  1140. "text": [
  1141. " cat deer horse plane\n"
  1142. ]
  1143. },
  1144. {
  1145. "data": {
  1146. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAABkCAIAAAAnqfEgAAA09UlEQVR4nO19SZAk53Xel0tl1l7V1dvs0wBmAAx2AiAEimCQIilrs6wIy3JYctjhi0+KUIRPvvrkcPhqRVg32WFH2AeFFZYoOyTKpCRCJkWCBEgAxDL7TE/3dE8vtW+5+fC/72VNVTWla9P/u/Trqsw///wzK/N9b/keYMWKFStWrFixYsWKFStWrFixYsWKFStWrFixYsWKlf+/xFn86Lf/2T+W75zMKNFkDGDc78i/45FR0mQq+6Ty13VdjitKPI2NksHnMV0eKpvdOQgLRikWQ6P4BdnF90UJAtmmVCoZpcCvCp4MW3AdAFkmw2ZpwoOlPL5sWQzkQOVi0SiVUpHnLisTp7II//zf/AfMyMVXNo1SXz1rlO7BUHaJZBcvkANVm3KgUi2QyfgyfhRnAApcyEd7B0bZuHDeKGkk3/WPxrJv7HK2Mlp/0DOK68n61Os1AFEki58kMScuc5tOp3Nn6vuebOLJJ54vB/ID+aRYCgCMB3KmDu+feCrDfvD2D/G4/M6//FdG+fJrbxmlfSyzvX94DOCv3v66+ff2w4dGefPNrxjlaedYxu+Jcu7zv26Ug0jO/c7tD41yeHDDKN3jXQD1alXWpCpX6v1P7hilc/xAvsoio3zx818wyo3rd41y4dIlo2zfet8ov/qLbwHYurhl/r11Z1+mtLpqlJsc/79+7S+M8upXvmyU15590ijR9n2j/Nvf//eYkd/+F79jFL3VPU8uhxvImTr8KsuyOSXgnex4sk2SJLODeK4oSSw/h6AgP4d4MpEDpforczm+bDOZyP1jfkxONuHEZQM3rMmBQln2zJNpT/UmTKc8Ndk5iuSWTibR7OnEsVyX3/29f4fHxYUVK1asnBLxl3wWy9ssLPKJjgEAL+ubfx1XHpkFvpML+oB39QEvg43G8liNJzRweJwgDDBjb4Uh7Z1yyA3EXnBdtQJ8blyYG81xOL6TAfByO27+lRUEYkYVaa8FtCmcTGbrOnw1LTNCAaRTOdz+jthEDt/5NFmQuLJN4vBt78uhx7EYKYVKCKBcF4OxnkRcBNkygyhRJGfUO5R9hzR11WZUu2k4HIGvWQApzUzP05dnxtk6/ETPja9NDuupFea4AMo0b8OCLGC708cJctyVrz65c8soMV+5e8dHAFpnG/J5IMe9+f4PjHKxJed+jldzjdbZ+lf/vlE2X31OTqSzY5S/+J9/AKDIpdgdyXJ1Y4EIo1he7Jc3zxjllRc+I1O6LVbed97+c1kMyBX5zrsfA3ju+c/KuTuPjPLBt75plMBvGWV15aJR9g9l38++9kWjTNY/kXX5fcxKgSupii6+S5zh+PP3fERLJDeTaUl5jgcgpYnk8YeZ8cZOE7GS9GrGEUdz5VYo0GD3qGSpAwCp/ELVcnf5401dudNcXw5d8NXUAkeT2QZ8gMTeFECaqkF3oiFlLSwrVqycGrEPLCtWrJwaWQIJ6011bItSqzoAJiX67cZi29H6QyEoGyXjE3A8IRih0z1OxnMHCn0fQKFApBaKoUhXMqpVOt2J5tT0dR0ZP8l96iKu5wFwGTFIErEzpxOZ9ngkSMqt1eXQBKGuR2OYs/LS5c90n+7GCaFHRsjmevKVz+jBeBRxMrKxH8qy+BUPgEuvdq0l4DGhP3hIS9ohCi5VK0bpHQsqqVYqevbmT7/fBxDRyC8QVqiikDCHHlSCgLibG8eED+PhCECrIeumoY8pgf+iXL99xyiNspzaG6+/bpRPv3ULwL37AhVrZTmLF56/YpTi4Z5RDnZ3jbLaEVd3CXLE5oogyj//+h8ZZWf7AYDLq/L5rR+/Z5RWTQDmL/7SV43y5OUto9y9e90od7YFsg0GgvTf+rmfN8pzz70MoN2RmycZDoxybiKfxEdyOYpDXjLePA8PZNqF6QDLRH0deofnCiGg3oger1QKdW
WAG4sYL0rG+E+qPhqNhjEM4/De0/s2pmNef3cakJlOYgCOw9splN9qyhlkKc89476hukqIRhnIcghgUSgASDgnxcWLYi0sK1asnBqxDywrVqycGlkCCTvdtlEGPTERy2EAoMCYWsEXIKAJQUFBsE9EY9IlUitWxCB0mdyhIa1yuQSgWhM4WQg0nMc8KaZHKb5ziPw0bqUBRM9T1FMAgFRtVDmLlHPTUJpLWDEYCFyNmPRUDJmZlaOtx+TlF58wykFHhr13QxBBjZh2TMw8ZKCqALGB11uSt1IsVwAEFYYCaQyPOKXpVAaJafa7BVmfUlmWztXV4O4m4BLHmoelIVReIIIHDc2oOD9ByYCZ6HBCyKlJT4tS1qw6Xt82b7B+rwPgledeMP+eXZEo29OXn5adCbI0D2sqgyG+I+lXLhfh7gcSW5xM2gDe/sF75t8Wces/+YVfNcoLr71mlGuvvmKUvY6M75bkjP7HH/yhTOaZa0b5/OfeAuARAP7NLUn72rwuKPJi7YJR1mJZsG4iMc3f+0+/a5RnN+tYKry4izlWWaIXiBdR8+MYPNWL6Co2dBzMYDoN52kcX38vE8L5IJCNM41up3Rc8CfpSf4gcx75w9cDezz3mEeMuGKe5pEl+tvk3o4DIE71Ll0emoe1sKxYsXKKxD6wrFixcmpkCSTUGhePj7MszgBMIwkVqcHm+VpownBSLNtMUwFBil+qRWYbEm2Vy2UAtTqT+mm+akGAT6tVUxlT2pmaCanpjppKWqvXMRNJUYs6Yb6ixs6iCZNa+ZXm14UEMmFYwjJ57bWrRrmz3ZUtCfeePHfOKPcZG7qx25ZTpjF/9axsY1Lw2jyv9khAhKafJoTDKVGcmvdTAtgC1zbT+ItjJi9nMWH5xXSqRRXgtEVKCsA19sSvKiWBXWaRc5RNyOA4J4Z11taaRtnbk8TOo0MBehurGwDOr0tpy4U1Ua5eFUjoeM8aZTiVZfnff/onRgnee9soFy8JNj/bkEW4cPYqgKeeF4B2sSwobIPKB++8Z5TWplTtXHh2yyi/9Gu/bJRvf//HRrlzS4p1PnttD0CLAeWtp+S4hSPJNT18IOfVLAu2vX5fynp2+nIn7N1hCPxx0QxPF/orYBhXf27O/HXRH4i6YnLJzC5yv2nQOckYIs9kJhEjiQoJC76WyLCujuE842+JCeXiTBMGOKlMA5e8b1mRA24TEkhquHASjwGkC3B4UayFZcWKlVMjSyysp69KFsx0SL/vcALkZo4aL+omd7W0kqk0SWNFNuZLQA0rL3cEZgCCAr3yno7G8TWDA+qApBeZVpLPhCn16ZpH/JiedbXBtAp6PJHzcjL17osFUVY39kJ20pyUCzLa5qoct/WqrJvHhKyMGS4x1ydgOkyDFa2mlGHKtQ19pm5FNDNpu3hMfklY5JSFjDPQwp3wkrlIAUz53o743vNSrRti+hg96C6vS5V5XuD7s8BainIpBICMeTq+XtMTLazzGy1OUk7t1vU7MoVRAqDsyOdFKh6jOgnN2wlfxccDyWOaHom91uu1jaLBnEmcAXj9rc+bf+OurMlfv3fTKJvnxfjqj2Tdxn1ZqJXqulG+8BUp/dn96AM50MN7AGqbcjqdqWCIqCEJX/sPxIzqemLFNBixiVzZ6+jhPpZJHjtayMNSh7pDsJJnI/KUk3T+JpfymnxLOVCs151GjAZdtHSmzPvK02xHpTlABoCAJI8IpTS6tKbNDzTHivZaHhAghuNt4yYuZlHRQiBIxVpYVqxYOTViH1hWrFg5NbIEEl48R4Kno7ZRBt4AAOjXU8tNszOU8ibjeG6gSf2CfRIsFNNkGYBuR/ypmtNRqQouU5+uFotosYvrBJyDAkwxI0fjCYBBn9wSPK+AhSYBvcsFzUnJa9mVBouwa9GdaY7Ls2hWiETo9e/t0xPJ9XninACNpkcfN1FVr9OfPS6JMBCPGAcgmovpL89YKKPrr1k8RZ7j+uoKgDt7h+bfPm
uDWEKPDVa0ZDnTgxJFcJGhaJHu0sDHDLWZOk0znOglnXZlDm984UtG0WS07bu7AA6PJAfq/v1t2aAo3vGXPvumTJuXu5pwShW5SWLClg8/IRFCWAZQKEswZ0jYcv+RHKgliA3PvS4reQUyWi0UV0ahQIINrn8hiwFUSrI4RyQF+6Pvfs8o62QBubb1lFE+/8JLRvnu9z82SufBMZaJGg6aH6cgUf3xM84N3pz5AKw/Y2TMoEV1v+TYkIM5iTpGZJsolhs4JvFJyF/xJGJBUpICcDOlOZGJRwpF9TGhhHTMuko53+FYf9d0KTguAF/DSpmFhFasWDn9Yh9YVqxYOTWyBBJutsQqrjOLZ1gdAQiVpY9W64Qg5ZjgUaMGoNmvgYAiI0qaQjUZTQAMmBg0HIjF2Okz4csXJSxLtKhcEaVGIgeHkE0Vkz526axQDIdFmYCvBQ2F+eBjzuhAbDshPoqTEyAhUSSYgtTtMhoyFmXcE0N6tSX462yzKQciJOz2RwAO9yWXxyNcbTE6eNBlNGogkZppVyJlIZc0cpQxQuz8V178WQDXrslR/uzP/8ooTcLtr35JCIvffV9yhXZ5EZXRQa94HMv43V4XyyK/IxJgLEp1Iti8vyO1LI/274lyuA9gvSHpVy1Glm/cFwrju4/+l1HWGGlqEk0cfyQAcJ+h2CtPXjbKOC0AuPup5E997itfMsozz79olHu3BHse9+WUD1mak2Yy/kZVbpLzrwmsazRCAB3SkHz5F/+BUb7+7W8bpcKg9sULQuC32VozyirvWxKgoD2XD6d8ivzAXQjwAcpznW9l/ii7tSZkmWRGdb8U+FtWMmWP/paE0buUSFAT/RLNBeMVTxOjsMCLBA8e102nnSmDZjYf+5shj5RNTA7gdKKELjYPy4oVK6df7APLihUrp0aWQMLGStMoTSomUKI1LlofEzE5U/vcDAiCtM1GwoyyiEUnEakLsmkEIIsFpAxisai7EzFwh33BPqORlL8orNvckCDOuTVJydtsSWhppekBWFlhNUlRwamS/ykdIPgVo5BEi2pa/62QUDnX2QcHGZNaPe66UiVToJZK5PTVCYAC3xxPnBWWcZcTOFORfQ/asm43yCE3JnauNORkJ7TM723fA/Brv/Qr5t/L6wxTrgjsOmoLCDo8FKa6ep1xQ8bF+gMCPYIFzw8xA6Vz3osT0msBDNvCpH58IJUrF87IZFbX1wEM23Jj/Pjjj4xy8drLRgm5Stu7km/5Aq/y+gtStTM+FG6/ClM6G/VNAPfvyS4txhM3tqSY5skLW0aJSSTwta8J+V/GOpJ2h0x7vPf+8v0fALjy7PPm31/49d80ys+Qj/DqmQ2ZCQQA3iSjQ28oi1Cpyt2Cx0nwc6597UDlqAdDy2Lm+fg11qxMeCVl1EvS2S2VfV+pH/w8XEhXA+vqYtJYxmQiVEY9zykASBhhVEioPyUFgooN9YhT1oeVeEs3ecsdHR4AmE5l2FTZBRfEWlhWrFg5NWIfWFasWDk1sgQSBsRHga9+/hTAhFTuWn0U0Fht1gSghUwTHXPjEWHFhJGFlEXbncEIwEFfzL9+KltGzHabJGTL5r6ahloo+3OKdspaWasACAKdpMu5Efcpk3Ruxyoduxask8RaK6weF4e4KSyIkR/4smWzLhVkBbLgN0ge7zCaNuqLsr7SANBsyuQ3G7KSJVrUXeZ8VmmW7+9ISHFCtgYFmCBQ/fCj6wCe2RKQ9dqLAmQ0IfYP//hrMj7DtAGp5fpMiewTxWsqabNRx0x3Jq1cUwKMRVk7tyXbMEHx6JEwtT/zwmcArL8mKHhIlNLcbBrl6lXBfUe3BVupYwEs+mutCuLbPCMVgoNHfQCIBNPppdRwVUFf05z//o7EDT/4kbAA3r4tHU/diYyzUgkxUzR35ZrEHL/0pS8Z5RxzcUMSW97/xjeMsncol2xtQ6Z94yHTXAHMdNZKcyYMzlEDfwTmmqY7Q/KnsTm9tzOzhflvRA
Z6OJojrSx9DAWy2DN25iczJZ1k4DqYCTVqxF+5GWLmCSRsWJdxtJjg8Ec/lG67AaPMzVoNQIMI0bMEflasWPkpkCUWVveR1FIUlbPYzTDrDsyrrlnfz66f2vx9HMrzNSQtwTGrUvYGsteN4zGAg2N59mfkBigVWQwB2bdFTiUlzyoVlZ2ZBlQow5arRQAVdoEN6EHkqycnCdLnOK2l3EGoTlD3BKanUN3zXJZqhb587YBCn3VA0uRqIAZUzFfF8xdXATw6Eg9xoFSztOw89nbd3tnj3MQmrbXElJuQv0EpKiJkAO7tyS4/87q4sTsHbaPceyi+9pjXZcAUJ59xhipjBaOxuIh39/cAnKMLX2uDFnmWVe48POQpy+Sikfj7G80VAFvPihn11pelmU2Bfc+P9o+Mcv26dNaZHstoEdOIGrRZBgOxkh492AOQ0L1drMuaH+2Ie/5od4erIZ/84Pv/1yiffCo2qVaYFJj9FDoJgJ0dyRF793t/Y5Q3f+ErcoJchAJtijL79PTpud9Y28Qy0Uw3ZTLIHetsvaNtcd0FA0T/z+iGN/VzZabdqVc+oZmpju2Ye2esRtKKnIIrN2FF+c3jDEDGQfo03LS5vNKaJ3ThJ+yjo7/QJxn9UBe9AXPa2iezpTlWrFj5KRD7wLJixcqpkSWQ8OE9Jraw4U29XsEMd4JaoB7tTFf9mjn9mJL8ySCHXUlF2TkQn+4kdgFkNLmVXnk0lEFYz4BaILZinYqWzrgZU0Los0/HAYBxJIOMdNp0EIbMVamw14vOX7vmKEeadricE/XcT2kMZxBlQst2TKu7Q9q51qrglwsMUxS9NoBGRY4yUpIJ+kRbdZntE5fFOd0niPtgW/DRtM9UmqkACdNsdUKs21yT4/aYujVgrxS3JPhrQndph/wZDqvntao+jscAuj0SbHgnOkdV3v1E/OUXL8ocilPBR+P+IQA/kMt8/oKcYIEgaLIr94w/0YagbL1z2DZKEpP3cVUmY9agS6dvfyizPbwpxTo3PpRqpEe7UiT0YPcux+d9lTunZTKDKAZQZMnU994RkobLL0lA41VGNvbv3OTZs6sokf5owKDB4zLT1xZU9KfEHxeXJVuAhOrYVn+N2SLn2mTExuX8yxXxcPdJ0nA8ZksnJmAWMrkbpwPJoZtMepipxNLSH522/pzdPIyQdwQyf6ukyUTOYx4BmHLxNb9sUayFZcWKlVMj9oFlxYqVUyNL8M5wKHZgmpCeAT6APgtl1GwOGUasVNkRh00r4bIhKLP7t9lT5O49ApnYx0w0pErGtXpJAEhYEPNyheOvr0v5Ra0lympd7NiiJwca9DuYyb5xaSS7uR0tn0TEbhrjUv4JNdE15WROYkKqjKFMJVHrs+i8Tf6G47aE5FbKZzltOaNPbj4AkKRy3GaDsa0SQ7S0qMuMh5YZsskmpOXjqa0QxSfRBMDgSC7lkPHKPRbipAPBZR6nXaoKa0KayCC72+IcWGnIpVlZqQPod2XYFjOPPP/EN19Kn0BQk5jmxaKUB7368lUAXk4OIUihvy1RvP2PpHVNlTU0Zd4bet88uCf46+HtT2WbcgXA5nPS1ijal7M4vCmjjXYl5lgjkirk/HOcNjPyXIa9xnEEoEs+jyIjvy5z36rsLXS71+EocoGe3LpklEYoi4wf4zFR3OTMZzYpJMwWOtvqXZryk1SDjCkw08HX4z3pE4WlU1nSfibeiaQoJBNeWcD7MdPTpt0dzsoDWfMB+FzAIG94Qy5PbX7MWyPlJ0oCkUQKhF3M9HBwbWmOFStWfgrEPrCsWLFyamQJJFxdbxqlxMR5x/EBDIdiXo6HgobG2nNpLAZnpaaEYWVurNx4rDmg+Voo+gB6x5IcyARJBJkAkEpLBvGVfqws5muT4bZGWY5Y8hjgy0IAYcCET1r7mimX94MipcSI+G48Fmim/b6Uo25Oxiw9B+uTQmZR+oyd1aoy7SQRADhiPbpWSoxHCYAS8UWZLVFdBp
jiiIEYZuudYRDncy+1eEayF0OL2Hv4CMBKUzbY2xM83u4IWllbkaRQXhYMGMYtsW9VifzlKRNSjx4dANjYEFyjeDnDyTZ8UU7tY1IX/PyvCb3BM09fBrDNrNRGQ67pp9els1bnjqRxrvDmCIkmzgSykmNfoM0x+24V2n0Aq4cyyZU+Q43rcl+V+3KCh8RuVfLxt+klGKrfgBjQdKXLpszaXZdbUdNre0cyWkz81Scv4OsvkwWwJIwOf/zNP55ZJBQYdFYSdk3K1XCbo9wM2gEMWqyz0McUKQCf9wxIrK5ga8LSpQ6rnZIVuVsmKUGoNh9jr4EbN+4B+PimZM8+T9LEl648weMwYJ0uJKbm1WyEfgzBGzBYrzNRmY+aRbEWlhUrVk6NLGtVTxdusajPeAeAS3unHDIzpU/vMsmV2m15WyZa6co2lnsP5f0zINVvIZwAQMYu6rR3jpkMNY3kQTukk2/EYT2+1jbpfW+xbNjHFDPe04h+9DEf2wMahuOpvkW1YnPB05ktf6Z7bDZZJhfYZExCZLLrliraKV4+GfJt36CX+vlnrwHotcUC8hjQ8D368nmJGnWx9Wor8tXZqUwy9cV26LFItRIWATSaq/xc3qJHtLBSfRXz3a6l1HFXNl5pyiIXWREVTX3M0IepBEGAE+RzbwkXc78v5/jaF6UEp9hcBbAK8cEXQ5nShz+WCuSDB2KU1WM5rybtWY0VXGjKsly5JMbL6HAA4AKr4je0UQz9watk3O522kZpVeUED5iMdsyrGUdaWRIAWNuUJf3q3/uyUc6fER7k40dSX6U5ekX1x9PMLCz5wQHIfe1qAemt6LustqHTWkv01dQKNKGJ47lZBlbSYLbih9dOMUSFRWvdqfwutCYPzNV6+1ti6t69swPgeCA/80PS0o0yiSo4tIG8sCljVGTFFPRo/yo1JzuDIwD7exIeGVoLy4oVKz8FYh9YVqxYOTWyLA+LVE36nfH/RYRsKZ1/PtmHayTeTRJ5Ao7HtMO72leDdiyrakajPoACwVFIh3qpqA5vgVRFuqKLbIFTKimbAksB+PA1BmfO0KTVQvTweXQVV7WhjnJaaWcgxZQn1I0XmRDEyvac+6FAtDgZCpCZkHOqEgqAvXRRmvrEgzaAY3p2o5FsqcsC1lJoN1kFAjFd+IMRIxsjjYf0AJw7K2lfIZ3KUawxEIYgCoovmHlERDMecfyuzGpzvQWgypImbWIUs7frovzWP/0tDivw5PKWnLvpxXKmIJf71qfSbfSDG8IVdbQtFTMXuRq1TWE7SOl9L5NOY60mN2FcrAGo1mXYR7vCbPXJR5KoNSTFmMPLvcbOPTu8ZC4LxTKWXpkb7CyJmL/yxZ81isKx7i69/hz2+avXjNLYENh4i3Vvc6K9ncplpra5giJz9iulqUoUPyoReTz3leFWiXir5B1VSUhy3GOOHityEJA3hX6JoChIOSzLWZ895wC4XJGUukuXBAm2zoj3vV6Se0OxZ0TqBeVGPz4SEpFety0nkk0xQ0euNFuLYi0sK1asnBqxDywrVqycGlkCCaeMLEwZp5hGJu5Gur5qg4pYjIk++BjbirnxOqOE/cJ1o7Sj20YZPRoCiMasHqDpWGQNxyYN6SAv1pHxNRDj5HSxGukDZiChdnZR/jNtFAqCIKa8zNQ+6Pje8md6ljebFCVgxUycchhGmlihgc2LckYrqxKA++DuJwAc2s8xSSYU4RbyfBzwE05b6yEYCa2xjmRSLgKoMkx5/rxgww8/YR1GprhSJjemHc7OpFhZkdk6KYno+gMA00gOd+GCgLufUJrzMlOQVBSemHKUkP8ePJJY832SC2rRfodT6uXRWy7CiJlTd6R8pLXaALC+Lud+f1foie9QmTAoHDJPLSSbhcfqkFIsaDfq0wHi+QCeufK0+XeFfL5dcgpOWD80Q6cnylki2fF0OdjJi2wSjQmyIQ3ZLHymBCqDiLLl5UuaQ0IHM6zEE6L7eCJh4hFJLH
oD+SR02LSGkLDWkBtg66nnjLK5UgbQWJHlqnEBU57X8EhKvro9iZl2hqLE7PLravkOc7WMOyK/1W0jVStWrPwUiH1gWbFi5dTIEkhYaYiZp3mPpnC835MgQo+FLKknNl6ZXWEKRY2yMWanXWccBpIcEkVXygDqbH8CQsIeCyaG7LhTqgj2PJ9IZ5SnrjA/TUNmkQbIRpjpdqMxNeUwU0M6TeYTR7W5iEOSvDhdbp0mLM1xq5oTyCjeWJQaeeWnicIKmczDPSluMB1KHH6eKjbUYn0m1irPtaN5gwqQtXyHgaRi4AGokGu/wRTQ7QeCm7S0KCI2TPz5dESfk9GoTbFUAqB0jT0y+fmFE5n8ZhgQFbyzVsMMS/Q9ILlgn6nFLinrHvEE7/PemxKDTAl2AtKwnxuPAFRYYHTILMQ9Vswocz+YRhtqDQoD0zXSjYwnMmytXAPwmZdfNf+WQwHX/Uxqy3bJGZ8yAqvhVPVLrDKePicjlrsNSJWn6+b6RIKFcO6rvL8pu0npNsVyDUC5JouwvnnOKDFTcDUMff5YLuLNmwKZXaLR1TXZ68ozguvrwRTAwf4+j0soSq/HjRtCjuizWssrsH+VEmDwvsp46Q0dvkYJs5wrdF6shWXFipVTI/aBZcWKlVMjSyBha0MiGgGLCk0VXrE2nyLps5pfa80URGh4osoenGsrEvtrkMPg3nEXQJDI59WKZPoVmER3/UOpYAqKrBPUNmJklY6mDKuR8KBWbwDItCElZ5Jno/Er7QPqERFUtS0SIaEyEc5JwvaWU6KVKGbCbSZnVCTl3gpP+fBQAmG9EakRyg4AXztWMripHGlFXW1meCYsLnOYSlrOiD0JPQyTX5GZvT96/0dUPpRzpzVeIe1crMHBFiOJiYw2YcrfpQvnMVMC1meAyTmhGRoe61u1PAdXIfeEfIQZ+dR9ZiEeD9pG2SbSr69K5eDxkSCy8ZHEFgemQnBbQPeAV7nHINR4QmKQVPBRyHju+aclCHh4Q+jetWOu63gAfN4PDotMA8IxveV6LLUrsl5vRNxaZnb0nGgIPmAisubZehnD6MwTBmtgx5pByls744VIHR/AyrrQ5DdfFEz37o/lN9UhOYdDbgY3oxuEwxUIkDukbPw/3/oagDNrgjQvXdkySo3Ek00y62uPu4x8kBPS9UWMoyqT3zQ2bhxmTRdOvJ2shWXFipVTI0ssLJ/v7YSpRqaopappIMxj8vMHIX2W9PkFZPhVuqXzZ8VwOzqSF8XBcQqgTkaBc2eY3lWUfQ535SXplyQOUG+JzzII+P7RPClNLIoed5PzP+2fqvOPeYLq7VObAvxqkZmAW8poQ+YBuXQ8O+zfE9CKKfNNtXOk1Q8yzsVKFYBHXoqiGGfodpmrQusppYUy1nwWR17gzUbTKJ22mBuGQ0K7zH7ne9Lipc+Xf4vkR/lFpIWlRnG/Kzk1SowlPv085Y1lK9GJiTN5t1ptV6udXRwHM+TAA1rN2sF8dUXskfvbsj7PvC5cWr/8G79plDY5p771Z39qlFsffQDgJkejqQqfpxzFskrKdfUiG95ceU6U778vdULK1mAS1pT6LclLdkS2traMcuO2UDCnGmeg8RWcQNeQ0WmtjWM0OVDtWbAlsNaN+SzbKigxVvaY4pE3/PZtNoiNZLn2yRP9YJt9betijqW1FgeRr/a2JYmy23kE4OJ5CZR98y++ZZRPb0rZ0xfefNMoRdqMMR8Bak9hoZOxk4wxE7ZSo3VRrIVlxYqVUyP2gWXFipVTI0sMVM9TjEAn5WgCYMrKA/Vea7HLiMzCpYrY8DWCFC2mqZEwoMgsj+7+DoCDHSmmP96TXKFzZwQklpgRVmeNSIMl+IhkMt0DMXojlm8Y5OEzzaRE7rQiqy6mdO4qZV2SKsCUSarBDxAkPi5JTN47JjRVGxUOIrUUiYI4El1o2VM2JmXtxAXgaZEQL0hQ0katrJhhMYq6+VdYGtWoNY3S6YkNX2utAUg47I
iTvHhJMmsmpIXQNKkmmZd9+n11ETZWZf3rlSqATlcgVZX5cScB57+LKNTtD7v8hPk4BBFrLG351X/0D43y87/8K0bRhLKXX37FKDc++QSzLUs5Pr3Y+O//+b8Y5cMPhYv5lVdfM8rVa8KvsPonAjAPD/Y5TISZihZVUl2ldQFKD/aFjWAGDosSx8tvp8XPtW1PzIuo4Qv9ATo5piISZIjJFEtd/5HwXnz3h+8Z5a03XjRK90gSvrRipn0keVjOnixLWuEV6chX9+7fw0xIbUhyizLb8e4+pFNC64c0NMSomrY+Ukq/glfHDHDWB8uiWAvLihUrp0bsA8uKFSunRpZY8iOChYjG9mgwwkzqk/Ii+AuJLRMqI9qKdW3BUhe0ePZJ6dd4aecQQLEkSHBtVSKATbI1BDtijddr3Jdoca1JnjCSPYxGZPIreAAyzS5h+xm1ltXgnEw1SqgpQnJGWqwznSy34UfsBpQXEaQCOUOaxwViq/5EJqnceJoL5qMEoFRhIlUkUTxXg1u8Rg4JwmtVpgIx5DRhbKnclDVcPX8ZwHc/EESg+V+tVtMoY22xqaEZzlZLf0psi6vT3tnZBeCQgHxtQ0BQTpf4E2Q+SCiirV/G2tqzTyTCRjjPPifpUS8xn6jKQJIi5Zeek6+effp5zOB9h3ApYbLVN77+TaPcuCWRsvNnBSmvMgx99oyc2kcfCz7yvQxApyOQp0dc7CXKPSAHWueyzESfZRtd9jnRBYwXeoimjEdnZAGJCEJTHjHWPD5l8osdAH0WMF08L11tblwXTsTDQ0F5CZFmSAK/2lh+s1caTxnl+IEsnan9ijkBpVpZWyMzJXlB1AWkzVZd8o142sxV6UbSDLMI92SxFpYVK1ZOjdgHlhUrVk6NLIGEHZawZzRoo8kUQMwsu5wXQbEhYYVyJIz67NDVF4xz1BGrcpv9vuLxBMCjtuC+aCgINF4V3DcmFfpGS2zUFiNxTaYURgPS5hG7mRiKBo+0KaOam1pi7mtDR9ZbpAwXDtkKbMDGWXOiTVgdFjBpeYrLlDmXxrYW3Kh5XGLIcnN9HUDsCmTrH7MZGgOvASRps8E6D4+xlTEJ2h02Xitkss3b338PwDvvCyTM6foIGVoM/IWsdtIFGjIcXCYkVDR98OgRgFJVPi+0BbtVSPK9KEvq7vPYWQZygQB5y9sJJ6mVRltbV41yYUPoOlzG/pgbq61CUQoKAAKukuZksogLHjNjq+wm22LtlM+uoitapMVZu14AYEpc2dkX3gufrUO7rL/RWHOFhTjxWIZtd9hB93HRYhRtmKYAKl8uLSnLZBuFhFONr2k4chwDOLorBUYTpRJkSdPmatMoh20paXqwL8rHJJ4/c5G/Tf7233j5eQAhS6b2HgkuBoPyRToWfA1gav9XXjJHyRd5juZJk+Sc7id0UrAWlhUrVk6RLM3DIj/RQGyTbqcLoMvXqVYnqPu0yNIT/UTrJ3Obhf4/ZVNdqxUB7IzknbPP7JXjjmSInFkXD2id7+/QE7tj/768OpQmOPdnuz6ASCegLx0+8kO1sGhYaWdKbc9ZY4qT75ewTHxlj/LnncmZmqJa7pvKS7hZkjM5uymnVi07AHZph47YrEVZooZkcRrQc18iD3WfVt4KU6iOyDL8V995B8DukQyrPFytGluudnlxM1l/bRSuSu5bpfVaa9TwWJoM3cDREkNKziMn/9VaFu1INAUQs+TV4QaafhWSuezqk+L9BQ3nw35bRmNNeN5D1HVnJ6l+7knu2JYlrdVYbcYq6AKXv0Ur0uNL3bBmu3QMj5nyFrIWvUvryWGPmUZNzLQJk8J6h4+wTPRHl3umlfWM47MwKk/IAk2VYu6t5ldxCiDJ2EmXPGLrLfmkHsq+jYZMcr8jk/zwUKzI47/8vlFefEYM29VGFUDKwu8tpvVlnLY+UALaXA5o6ualOWo0yl/jfU9zE9JSJFuxYuX0i31gWbFi5dTIEkioTLUHB4IBb16/iR
nvrKIhNek6tHi1EalyAAyHYl5qr0e1YyvlCoC1NTIE0XNfrYqyUWGy0lTGf3R/wG3FqiwRNVSqgrbKUh5EWKGkXdwyoK9dG4X26GKPSDlUoc+10SximWifGIWEE0Kbiie7aCJbysILNyaJkiuwq9PbA7B/LG7OAfPXnKlAEoeYutsXDMIoAiL6leNMwM4hYXtvNAYwJJSbEt17ZL2aMnIyg+ZkDk9ckUS5IcGUo61cXB/AiOGR7W1xPNcbNZwg47Yg/THpp5UJw/TR7LYlznBM2miHkK1MX3gpkAnc+OgdGY0FJdrRNucUdlzMZDYleqZ6FlNByg2WbSWE2yMeeqXG24bNkEajAYBeX27FB7usaNlnjYsr98zWNQFKBV/u7bsPhcygTz7iOUkXilE0KSkn5FJqNq3RySnGFC3K/27qADhHF3vKlL1Of/z4Hqhq32LGWPq8vg+74nS/PGKRXHMdgM/DuEqRTEY5ZSjxmTaYOepZZ47Y3wb9LCS0YsXKT4PYB5YVK1ZOjSyBhNFUSw3kk+PjNoCAbGGanlNmUxYNBUYsKHFon9eaYpSC3RPHTK7pDaYAWi0BRyXmGdVrsm8rZJ19qlPS9jDgHAQJhgShyWQMwGfYokDb1+W+MZFOxjNFXrXDE2FCWfEExrUkD3URt0ZkUygxdsaErNSfhy39RJDyoNcDMCJw9hh/ySBmuWaNOZybl+fasBCKCWsZj2h6uBbZSKZODmuPAHDKuKT2wKmtSLJbnxlzfRLjalzYxLA0qKr9XzV7aFGuv/vXRkmI0Vwm5kzGQwCTMct6xuQLZFyswlFr5HQcdgWE6rLHE+bQcVYGekyYyKZdOZWTo8rksIwJWZ5OiTHZBo/dZLhwHE8AdFg25LHl7bvvvCuzbcit3jy7ZZSDA8lsuvERmamnyyFhlvMu8COFhGRrcB0t0qI7gmeUsfxIT9ZxUzxGC671N3LKEUPMHj0YdbJH9pXogkBVHQiOHwLwGK/UnsT5/JWaXHOpPA2zEsny55bp/CVKyF/oQn2SirWwrFixcmrEPrCsWLFyamRZI1WS5Gkk7pXPvIgZ3oJUe3Qo7UGslQGa0iZ/tSC7wCBdic1Ws8IYQFhg8C5gx1OPliFHc8lrrqGHcok5nzQ8p0RVJssxb2vKKBKI8rRqR0kaHKLdSs5Ip71elkPCkK19qkwOTBMx3ccMF3oMQq1tbHC2emoCSyvlMmZab2ptReqzoEFT8rjIgXb91MVmbGuDDW+ubm0CeLB/aP6tV5syPKk4XJfx0IbsEoTkHSTBYeqx/IWXpl6vAoi4bpW67FIqLV8lAMODB5wlV1vbZGYJgIA34dkVWfzVCkt/tJssM28TBvg8oonRQCnrGJ/yPACOEkhwECX1X+Et4RG2qIchY4ZtsymTWV+XTMuHhx0AD/eF5ODlt94wCsuicMAI4MGekEDssLVqiaHGp68xA/Ybb2NGlP1dZ+uoK8NR3KTrNs9qsEiEkbkOgIxwMqf+YDuoIe80dTVcZLsslwdSnKd0KVkSAcjyDrvzWdk6t9SZd8Vozq3iVs2szpDh5L5Ks2ItLCtWrJwasQ8sK1asnBpZYskrqnMJCTfPbQJot7UwTdBQoy5xpZ1dyZ3rkaRBbdAG8zkd8hNoALHZdAD4OafdfNGZlu9rhDEhmhgw/22qQTRoDMIBEBJSpUQTXkHbZNHi5ZELxD4BdC+GlpYtEYA6EwtLFe2syVxZwi7NxDt3ecsovbZgtM6xzL/klQAEFZl8maukJBNqJXtcU8IL1IjmylWpTHSZqfjsM78BYO9IcjKnrN8cMvA3JB1jSHykSLlA3D3oS6rnmFHCMAwBHLPozPVlSrW6DPIH/+3reFzikRwx0ppBhuRMBNbl7bRSlUEukU4v1jCrImaGoRXOK0O/r4Wb6RRAwg0KmspIIOMxAzMM5AL5gRw64d2ysiokfBvrAufbgymAARewUSPx5LUrcoIsvVSqxatba5y2fFUqaNL1Y6
LxMjUhNHlSs0O1u5cy+TkLRXnOYrTx8f983lBBoDFH+WqzJqe8slqbO2KFHhjfhCM5pZh5yDPJ5HpK88fWuWmZpPbZmyd1xDzmVbEWlhUrVk6NLDEf1ljPPWHByt7DfQADZu7oO63Dogpt7F5mLlVeWEBFa2jU+26eqpowpalPE3aUSflujPle1eSaESdTpAvf44FMmb7yE+hxS2W+Tlma4+hbekR7iqU5+Ymc4AgMQy4drb8Se757ZO91QvlEE5fUQoy0z1ChBODSplTDVMqy7/GhpPBoX1uPZ+TrW4YvKG04OuYbb+3sWQDPPCsvfyWDznKfLp2jmhSTc0vQn61FQBSz+1iJN5jQ5NHU+tf4j3O7KGmvu5A9NJA4iRxFAyk/88bnjPLUM6/I6azJSg5GbaNo59aU5FDawF1KWHg/jHjzuMycGkznm+Devrctn0zkvuonTOvjbyQMygAmzH2bcnGqVZlAhaTAai/UAjFVcnPphH4wGdSe4mwfpw8G/eiYdXDr1fzbyIXVhe8xx62gRGI0sYplTbTkiXC2ahQbA01ZrZW2QSW/yksm4cxt4/EHPo2NUTzPArIo1sKyYsXKqRH7wLJixcqpkSWQsH0otKeDriC+9tExAJcZFgVmLc0Yq4RU2s9D+9DQTk66ZHQgPDEpSGO6qDUfX/cNyQsYEmSVWFleq4ubWV25WrUTGuhEj3pKi3fI7iNePjXln+A60Fg96Iw57PI2J0WlfCChcFBjp9hQgEB7KGd08/odOUetzCBO8ZwAQMBkNGWbW2mIT1eRoDLVVZmmNGoLf4NHE1orS/b39gCUalJkr+BaO+Jo5lfAkiYdRJdUs5OUwNe4S0tlWa6EHYm0M82i+FxkZdbWxCK3EABIIjncgC4ILZCaktn3xg2BbN/+tjj1xwyPKBmA8mf4foCZOo8h4zPquZ+SpUNb43zjm38p585hP70v4ZHDHrPqPB9AkYTLMUG9x8vhaOegHNaBc5NFSE7IV8tyAEiEzltR0eLfBQAu8ByoZ0bxsnxRoJt8xAZRY/4wA3ZFcjkZ7/FQQH4MDqtoTieg9Mf5qS0oM/N+7LyyhUQzFWthWbFi5dSIfWBZsWLl1MiyrjmM/WlOfau5AqBUJkkb8cWI6Tl37kkNvVZoF0NWP3BjBXoTbcgaRZjJeMoWaL2SvBfm/INVTdARu5j4LBxv1JoAiuU6R2P0cDrP6DYiLMpYxqHYJ+XKuCc80stsHKLkFvu7Eter1gUtDsdyRsdMaAqZC6YVRQdHHQAeU94aZBlXuBoSgCsRQ0QMpTxqEcN25Yog5cy0eGEsLVsICU0nMoMeoXqRjAsaYJ2yoH8SycaGBNElQkjIxZhmy4NfmCl/yQj04pwOJAWQMWvpnXd/aJT9fQnV1Vj6UyEvyOYq+ydtSCxbqSLVpWBSzLTPwJAAcMAbr9tuG2XrkuRJXb78ihyxKbfNjbtSVfOjDz+SM8kyAHXGmissySoSEiakIdGUQM/TXCdNblwOdtT7MVHSRy0byuGYIqb5Op5sIQAnx104mtIBelx2TVjrHsuyh3SahLxtQrojzC3nsnguvxMWIGF+RCpLsKGWH7kuZhwLP6FGx1pYVqxYOTViH1hWrFg5NbIEEmpfLLXLgrAAYGdPCJ6Vpj1j4qKnJRS0Y7URaUJsmBN6KVkaEgBBUcsj2FiJikYhnbz3kTI6EOjR4lQO7/ZRB0DQI0sfT2dJBDMPMLH6nLAoL1w4ISKjjTw1xTQmy3iv05YjssFRmXA14V5TZnjGaQRgwFBpleUpYGA0dQUsjIZisWv5SImRSo+NvxI2JTNWd85LRxSmDPQjksdrC6+UxnuR5UFaGVNg9Uzq+AASrb/3tKxKcybnpXskoUxtKxvzSvV7PQAxP3/9jTdl8rzu2p+1we5kF84J5i2WeO6MdmU5i2SEmb6/PpnvC7y4ffKyq5RYRKUg6oWrTxjl2ScvzW7p85RzzkJ2ZkvIrK/gtEhuPJeWge
adzom2jE0XCPBmwDYh4dz/M6KxcpMY7CykceqvO81RpLZllVsi1kgiKQPTPOZuRluMYM5PSSuK3Lwp2QJa5FeOm83ukthGqlasWLFixYoVK1asWLFixYoVK1asWLFixYoVK1asWLFixYqVUyf/D3PcGe48X+nJAAAAAElFTkSuQmCC\n",
  1147. "text/plain": [
  1148. "<PIL.Image.Image image mode=RGB size=400x100 at 0x7F1EC53EAB38>"
  1149. ]
  1150. },
  1151. "execution_count": 6,
  1152. "metadata": {},
  1153. "output_type": "execute_result"
  1154. }
  1155. ],
  1156. "source": [
"dataiter = iter(trainloader)\n",
"images, labels = next(dataiter) # returns a batch of 4 images and labels\n",
"print(' '.join('%11s'%classes[labels[j]] for j in range(4)))\n",
"show(tv.utils.make_grid((images+1)/2)).resize((400,100))"
  1161. ]
  1162. },
  1163. {
  1164. "cell_type": "markdown",
  1165. "metadata": {},
  1166. "source": [
"### 4.2 Defining the network\n",
"\n",
"Copy the LeNet network defined above, but change the first argument of self.conv1 to 3, since CIFAR-10 images are 3-channel color images."
  1170. ]
  1171. },
  1172. {
  1173. "cell_type": "code",
  1174. "execution_count": 7,
  1175. "metadata": {},
  1176. "outputs": [
  1177. {
  1178. "name": "stdout",
  1179. "output_type": "stream",
  1180. "text": [
  1181. "Net(\n",
  1182. " (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
  1183. " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
  1184. " (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
  1185. " (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
  1186. " (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
  1187. ")\n"
  1188. ]
  1189. }
  1190. ],
  1191. "source": [
  1192. "import torch.nn as nn\n",
  1193. "import torch.nn.functional as F\n",
  1194. "\n",
  1195. "class Net(nn.Module):\n",
  1196. " def __init__(self):\n",
  1197. " super(Net, self).__init__()\n",
  1198. " self.conv1 = nn.Conv2d(3, 6, 5) \n",
  1199. " self.conv2 = nn.Conv2d(6, 16, 5) \n",
  1200. " self.fc1 = nn.Linear(16*5*5, 120) \n",
  1201. " self.fc2 = nn.Linear(120, 84)\n",
  1202. " self.fc3 = nn.Linear(84, 10)\n",
  1203. "\n",
  1204. " def forward(self, x): \n",
  1205. " x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) \n",
  1206. " x = F.max_pool2d(F.relu(self.conv2(x)), 2) \n",
  1207. " x = x.view(x.size()[0], -1) \n",
  1208. " x = F.relu(self.fc1(x))\n",
  1209. " x = F.relu(self.fc2(x))\n",
  1210. " x = self.fc3(x) \n",
  1211. " return x\n",
  1212. "\n",
  1213. "\n",
  1214. "net = Net()\n",
  1215. "print(net)"
  1216. ]
  1217. },
  1218. {
  1219. "cell_type": "markdown",
  1220. "metadata": {},
  1221. "source": [
"### 4.3 Defining the loss function and the optimizer"
  1223. ]
  1224. },
  1225. {
  1226. "cell_type": "code",
  1227. "execution_count": 8,
  1228. "metadata": {},
  1229. "outputs": [],
  1230. "source": [
  1231. "from torch import optim\n",
"criterion = nn.CrossEntropyLoss() # cross-entropy loss\n",
  1233. "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
  1234. ]
  1235. },
  1236. {
  1237. "cell_type": "markdown",
  1238. "metadata": {},
  1239. "source": [
"### 4.4 Training the network\n",
"\n",
"Training follows the same pattern for every network; we repeatedly:\n",
"\n",
"- feed in input data\n",
"- run the forward and backward passes\n",
"- update the parameters\n"
  1247. ]
  1248. },
  1249. {
  1250. "cell_type": "code",
  1251. "execution_count": 10,
  1252. "metadata": {},
  1253. "outputs": [
  1261. {
  1262. "name": "stdout",
  1263. "output_type": "stream",
  1264. "text": [
  1265. "[1, 2000] loss: 2.210\n",
  1266. "[1, 4000] loss: 1.958\n",
  1267. "[1, 6000] loss: 1.723\n",
  1268. "[1, 8000] loss: 1.590\n",
  1269. "[1, 10000] loss: 1.532\n",
  1270. "[1, 12000] loss: 1.467\n",
  1271. "[2, 2000] loss: 1.408\n",
  1272. "[2, 4000] loss: 1.374\n",
  1273. "[2, 6000] loss: 1.345\n",
  1274. "[2, 8000] loss: 1.331\n",
  1275. "[2, 10000] loss: 1.338\n",
  1276. "[2, 12000] loss: 1.286\n",
  1277. "Finished Training\n"
  1278. ]
  1279. }
  1280. ],
  1281. "source": [
"from torch.autograd import Variable\n",
"\n",
"t.set_num_threads(8)\n",
"for epoch in range(2): \n",
"    \n",
"    running_loss = 0.0\n",
"    for i, data in enumerate(trainloader, 0):\n",
"        \n",
"        # input data\n",
"        inputs, labels = data\n",
"        inputs, labels = Variable(inputs), Variable(labels)\n",
"        \n",
"        # zero the gradients\n",
"        optimizer.zero_grad()\n",
"        \n",
"        # forward + backward \n",
"        outputs = net(inputs)\n",
"        loss = criterion(outputs, labels)\n",
"        loss.backward() \n",
"        \n",
"        # update the parameters \n",
"        optimizer.step()\n",
"        \n",
"        # log progress\n",
"        running_loss += loss.item()\n",
"        if i % 2000 == 1999: # print training status every 2000 batches\n",
"            print('[%d, %5d] loss: %.3f' \\\n",
"                  % (epoch+1, i+1, running_loss / 2000))\n",
"            running_loss = 0.0\n",
"print('Finished Training')"
  1312. ]
  1313. },
  1314. {
  1315. "cell_type": "markdown",
  1316. "metadata": {},
  1317. "source": [
"We trained for only 2 epochs here (one epoch is one full pass over the dataset). Let's see whether the network has learned anything: feed test images into the network, compute their predicted labels, and compare them with the ground-truth labels."
  1319. ]
  1320. },
  1321. {
  1322. "cell_type": "code",
  1323. "execution_count": null,
  1324. "metadata": {
  1325. "lines_to_next_cell": 2
  1326. },
  1327. "outputs": [],
  1328. "source": [
"dataiter = iter(testloader)\n",
"images, labels = next(dataiter) # one batch of 4 images\n",
"print('ground-truth labels: ', ' '.join(\\\n",
"            '%8s'%classes[labels[j]] for j in range(4)))\n",
"show(tv.utils.make_grid(images / 2 + 0.5)).resize((400,100))"
  1334. ]
  1335. },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, compute the labels predicted by the network:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"predicted: cat ship ship ship\n"
]
}
],
"source": [
"# compute the score of each image for every class\n",
"outputs = net(Variable(images))\n",
"# the class with the highest score\n",
"_, predicted = t.max(outputs.data, 1)\n",
"\n",
"print('predicted: ', ' '.join('%5s'\\\n",
"            % classes[predicted[j]] for j in range(4)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The network already shows some effect: it is right on 50% of these images. But this is only a handful of pictures; let's look at the performance on the whole test set."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy on the 10000 test images: 54 %\n"
]
}
],
"source": [
"correct = 0  # number of correctly predicted images\n",
"total = 0   # total number of images\n",
"for data in testloader:\n",
"    images, labels = data\n",
"    outputs = net(Variable(images))\n",
"    _, predicted = t.max(outputs.data, 1)\n",
"    total += labels.size(0)\n",
"    correct += (predicted == labels).sum()\n",
"\n",
"print('accuracy on the 10000 test images: %d %%' % (100 * correct / total))"
]
},
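{
"cell_type": "markdown",
"metadata": {},
"source": [
"The overall accuracy can hide large per-class differences. As a minimal sketch (assuming `net`, `testloader` and `classes` are defined as above), the same evaluation loop can count hits separately for each class:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# count correct predictions separately for each of the 10 classes\n",
"class_correct = [0] * 10\n",
"class_total = [0] * 10\n",
"for data in testloader:\n",
"    images, labels = data\n",
"    outputs = net(Variable(images))\n",
"    _, predicted = t.max(outputs.data, 1)\n",
"    for label, pred in zip(labels, predicted):\n",
"        class_correct[label] += int(label == pred)\n",
"        class_total[label] += 1\n",
"for i in range(10):\n",
"    print('accuracy of %5s: %.1f %%'\n",
"          % (classes[i], 100.0 * class_correct[i] / class_total[i]))"
]
},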
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is far better than random guessing (which would get about 10% accuracy), so the network has indeed learned something."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.5 Training on the GPU\n",
"Just as a Tensor can be moved from the CPU to the GPU, a model can be moved to the GPU in the same way."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"if t.cuda.is_available():\n",
"    net.cuda()\n",
"    images = images.cuda()\n",
"    labels = labels.cuda()\n",
"    output = net(Variable(images))\n",
"    loss = criterion(output, Variable(labels))"
]
},
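{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since PyTorch 0.4 the recommended, device-agnostic idiom is `t.device` together with `.to(device)`, which runs unchanged on both CPU and GPU (and makes the `Variable` wrapper unnecessary). A minimal sketch, assuming `net`, `images`, `labels` and `criterion` are defined as above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# pick the GPU if one is available, otherwise fall back to the CPU\n",
"device = t.device('cuda' if t.cuda.is_available() else 'cpu')\n",
"net = net.to(device)        # moves all model parameters to `device`\n",
"images = images.to(device)\n",
"labels = labels.to(device)\n",
"output = net(images)\n",
"loss = criterion(output, labels)"
]
},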
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the GPU does not seem much faster than the CPU here, that is because this network is small and cannot keep the GPU fully occupied."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This concludes the basic introduction to PyTorch. To summarize, this section covered the following topics.\n",
"\n",
"1. Tensor: an array-like data structure similar to NumPy's ndarray, with a similar interface and easy conversion back and forth.\n",
"2. autograd/Variable: Variable wraps a Tensor and provides automatic differentiation.\n",
"3. nn: an interface designed specifically for neural networks, offering many useful components (network layers, loss functions, optimizers, and so on).\n",
"4. Training a neural network: using CIFAR-10 classification as an example, we walked through the training workflow, including data loading, building the network, training, and testing.\n",
"\n",
"From this section you can see that PyTorch has a simple interface and is flexible to use. Starting from the next chapter, this book covers each part of PyTorch in depth and systematically."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
