PyTorch快速入门.ipynb 85 kB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
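  7. "## 2.2 First Steps with PyTorch\n",
  8. "\n",
  9. "PyTorch's simple design makes it easy to get started with. Before going into depth, this section introduces some PyTorch basics so that readers get a rough overview of the library and can build a simple neural network with it. Some of the material may not be fully clear yet; there is no need to dig into it now, because Chapters 3 and 4 of this book cover it in depth.\n",
  10. "\n",
  11. "This section is based on the official PyTorch tutorial[^1], with additions, deletions, and changes so that it matches the newer PyTorch interface and suits beginners who want to get started quickly. This book also assumes basic familiarity with NumPy; for other background knowledge, readers are referred to the CS231n tutorial[^2].\n",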
  12. "\n",
  13. "[^1]: http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html\n",
  14. "[^2]: http://cs231n.github.io/python-numpy-tutorial/"
  15. ]
  16. },
  17. {
  18. "cell_type": "markdown",
  19. "metadata": {},
  20. "source": [
  21. "### Tensor\n",
  22. "\n",
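  23. "Tensor is PyTorch's central data structure and can be thought of as a multi-dimensional array: a single number (scalar), a 1-D array (vector), a 2-D array (matrix), or an array with more dimensions. Tensors are similar to NumPy's ndarrays, but they can be accelerated on a GPU. The Tensor interface closely resembles those of NumPy and MATLAB. The following examples show basic Tensor usage."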
  24. ]
  25. },
  26. {
  27. "cell_type": "code",
  28. "execution_count": null,
  29. "metadata": {},
  30. "outputs": [],
  43. "source": [
  44. "from __future__ import print_function\n",
  45. "import torch as t"
  46. ]
  47. },
  48. {
  49. "cell_type": "code",
  50. "execution_count": 2,
  51. "metadata": {},
  52. "outputs": [
  53. {
  54. "data": {
  55. "text/plain": [
  56. "\n",
  57. "1.00000e-07 *\n",
  58. " 0.0000 0.0000 5.3571\n",
  59. " 0.0000 0.0000 0.0000\n",
  60. " 0.0000 0.0000 0.0000\n",
  61. " 0.0000 5.4822 0.0000\n",
  62. " 5.4823 0.0000 5.4823\n",
  63. "[torch.FloatTensor of size 5x3]"
  64. ]
  65. },
  66. "execution_count": 2,
  67. "metadata": {},
  68. "output_type": "execute_result"
  69. }
  70. ],
  71. "source": [
  72. "# Build a 5x3 matrix: memory is allocated but not initialized\n",
  73. "x = t.Tensor(5, 3) \n",
  74. "x"
  75. ]
  76. },
  77. {
  78. "cell_type": "code",
  79. "execution_count": 3,
  80. "metadata": {},
  81. "outputs": [
  82. {
  83. "data": {
  84. "text/plain": [
  85. "\n",
  86. " 0.3673 0.2522 0.3553\n",
  87. " 0.0070 0.7138 0.0463\n",
  88. " 0.6198 0.6019 0.3752\n",
  89. " 0.4755 0.3675 0.3032\n",
  90. " 0.5824 0.5104 0.5759\n",
  91. "[torch.FloatTensor of size 5x3]"
  92. ]
  93. },
  94. "execution_count": 3,
  95. "metadata": {},
  96. "output_type": "execute_result"
  97. }
  98. ],
  99. "source": [
  100. "# Randomly initialize a 2-D array from the uniform distribution on [0, 1)\n",
  101. "x = t.rand(5, 3) \n",
  102. "x"
  103. ]
  104. },
  105. {
  106. "cell_type": "code",
  107. "execution_count": 4,
  108. "metadata": {},
  109. "outputs": [
  110. {
  111. "name": "stdout",
  112. "output_type": "stream",
  113. "text": [
  114. "torch.Size([5, 3])\n"
  115. ]
  116. },
  117. {
  118. "data": {
  119. "text/plain": [
  120. "(3, 3)"
  121. ]
  122. },
  123. "execution_count": 4,
  124. "metadata": {},
  125. "output_type": "execute_result"
  126. }
  127. ],
  128. "source": [
  129. "print(x.size()) # shape of x\n",
  130. "x.size()[1], x.size(1) # number of columns; the two forms are equivalent"
  131. ]
  132. },
  133. {
  134. "cell_type": "markdown",
  135. "metadata": {},
  136. "source": [
  137. "`torch.Size` is a subclass of `tuple`, so it supports all tuple operations, e.g. `x.size()[0]`."
  138. ]
  139. },
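Because `torch.Size` subclasses `tuple`, it gets indexing, `len`, and unpacking directly from `tuple`. The plain-Python sketch below shows the idea with a hypothetical `Size` class. It is only an analogy and is not PyTorch's actual implementation.

```python
class Size(tuple):
    """Hypothetical stand-in for torch.Size: a tuple subclass with its own repr."""

    def __repr__(self):
        return f"Size({list(self)})"


shape = Size((5, 3))        # the shape of a 5x3 tensor
rows, cols = shape          # unpacking works because Size is a tuple
print(shape, shape[1], len(shape), rows, cols)
```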
  140. {
  141. "cell_type": "code",
  142. "execution_count": 5,
  143. "metadata": {},
  144. "outputs": [
  145. {
  146. "data": {
  147. "text/plain": [
  148. "\n",
  149. " 0.4063 0.7378 1.2411\n",
  150. " 0.0687 0.7725 0.0634\n",
  151. " 1.1016 1.4291 0.7324\n",
  152. " 0.7604 1.2880 0.4597\n",
  153. " 0.6020 1.0124 1.0185\n",
  154. "[torch.FloatTensor of size 5x3]"
  155. ]
  156. },
  157. "execution_count": 5,
  158. "metadata": {},
  159. "output_type": "execute_result"
  160. }
  161. ],
  162. "source": [
  163. "y = t.rand(5, 3)\n",
  164. "# Addition, form 1\n",
  165. "x + y"
  166. ]
  167. },
  168. {
  169. "cell_type": "code",
  170. "execution_count": 6,
  171. "metadata": {},
  172. "outputs": [
  173. {
  174. "data": {
  175. "text/plain": [
  176. "\n",
  177. " 0.4063 0.7378 1.2411\n",
  178. " 0.0687 0.7725 0.0634\n",
  179. " 1.1016 1.4291 0.7324\n",
  180. " 0.7604 1.2880 0.4597\n",
  181. " 0.6020 1.0124 1.0185\n",
  182. "[torch.FloatTensor of size 5x3]"
  183. ]
  184. },
  185. "execution_count": 6,
  186. "metadata": {},
  187. "output_type": "execute_result"
  188. }
  189. ],
  190. "source": [
  191. "# Addition, form 2\n",
  192. "t.add(x, y)"
  193. ]
  194. },
  195. {
  196. "cell_type": "code",
  197. "execution_count": 7,
  198. "metadata": {},
  199. "outputs": [
  200. {
  201. "data": {
  202. "text/plain": [
  203. "\n",
  204. " 0.4063 0.7378 1.2411\n",
  205. " 0.0687 0.7725 0.0634\n",
  206. " 1.1016 1.4291 0.7324\n",
  207. " 0.7604 1.2880 0.4597\n",
  208. " 0.6020 1.0124 1.0185\n",
  209. "[torch.FloatTensor of size 5x3]"
  210. ]
  211. },
  212. "execution_count": 7,
  213. "metadata": {},
  214. "output_type": "execute_result"
  215. }
  216. ],
  217. "source": [
  218. "# Addition, form 3: write the sum into a given output tensor, result\n",
  219. "result = t.Tensor(5, 3) # preallocate the output\n",
  220. "t.add(x, y, out=result) # the sum is written into result\n",
  221. "result"
  222. ]
  223. },
  224. {
  225. "cell_type": "code",
  226. "execution_count": 8,
  227. "metadata": {},
  228. "outputs": [
  229. {
  230. "name": "stdout",
  231. "output_type": "stream",
  232. "text": [
  233. "Initial y\n",
  234. "\n",
  235. " 0.0390 0.4856 0.8858\n",
  236. " 0.0617 0.0587 0.0171\n",
  237. " 0.4818 0.8272 0.3572\n",
  238. " 0.2849 0.9205 0.1565\n",
  239. " 0.0196 0.5020 0.4426\n",
  240. "[torch.FloatTensor of size 5x3]\n",
  241. "\n",
  242. "y after the first kind of addition\n",
  243. "\n",
  244. " 0.0390 0.4856 0.8858\n",
  245. " 0.0617 0.0587 0.0171\n",
  246. " 0.4818 0.8272 0.3572\n",
  247. " 0.2849 0.9205 0.1565\n",
  248. " 0.0196 0.5020 0.4426\n",
  249. "[torch.FloatTensor of size 5x3]\n",
  250. "\n",
  251. "y after the second kind of addition\n",
  252. "\n",
  253. " 0.4063 0.7378 1.2411\n",
  254. " 0.0687 0.7725 0.0634\n",
  255. " 1.1016 1.4291 0.7324\n",
  256. " 0.7604 1.2880 0.4597\n",
  257. " 0.6020 1.0124 1.0185\n",
  258. "[torch.FloatTensor of size 5x3]\n",
  259. "\n"
  260. ]
  261. }
  262. ],
  263. "source": [
  264. "print('Initial y')\n",
  265. "print(y)\n",
  266. "\n",
  267. "print('y after the first kind of addition')\n",
  268. "y.add(x) # ordinary addition: y is not modified\n",
  269. "print(y)\n",
  270. "\n",
  271. "print('y after the second kind of addition')\n",
  272. "y.add_(x) # in-place addition: y is modified\n",
  273. "print(y)"
  274. ]
  275. },
  276. {
  277. "cell_type": "markdown",
  278. "metadata": {},
  279. "source": [
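  280. "Note: functions whose names end with an underscore **`_`** modify the Tensor itself. For example, `x.add_(y)` and `x.t_()` change `x`, whereas `x.add(y)` and `x.t()` return a new Tensor object and leave `x` unchanged."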
  281. ]
  282. },
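The trailing-underscore convention can be mimicked in plain Python. In this sketch, `Vec` is a hypothetical class and not part of PyTorch. Its `add` method builds and returns a new object. Its `add_` method overwrites its own storage and returns itself, which is also what PyTorch's in-place methods return.

```python
class Vec:
    """Hypothetical 1-D container illustrating the trailing-underscore convention."""

    def __init__(self, values):
        self.values = list(values)

    def add(self, other):
        # out-of-place: build and return a new object; self is untouched
        return Vec(a + b for a, b in zip(self.values, other.values))

    def add_(self, other):
        # in-place: overwrite self's storage and return self
        for i, b in enumerate(other.values):
            self.values[i] += b
        return self


y = Vec([1.0, 2.0])
x = Vec([10.0, 20.0])
z = y.add(x)       # y keeps its values, z is a new object
r = y.add_(x)      # y itself now holds the sum
print(y.values, z.values, r is y)
```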
  283. {
  284. "cell_type": "code",
  285. "execution_count": 9,
  286. "metadata": {},
  287. "outputs": [
  288. {
  289. "data": {
  290. "text/plain": [
  291. "\n",
  292. " 0.2522\n",
  293. " 0.7138\n",
  294. " 0.6019\n",
  295. " 0.3675\n",
  296. " 0.5104\n",
  297. "[torch.FloatTensor of size 5]"
  298. ]
  299. },
  300. "execution_count": 9,
  301. "metadata": {},
  302. "output_type": "execute_result"
  303. }
  304. ],
  305. "source": [
  306. "# Tensor indexing and slicing work like NumPy's; this selects column 1 (the second column)\n",
  307. "x[:, 1]"
  308. ]
  309. },
  310. {
  311. "cell_type": "markdown",
  312. "metadata": {},
  313. "source": [
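  314. "Tensor supports many more operations, including math, linear algebra, selection, and slicing, and its interface is designed to closely resemble NumPy's. Chapter 3 covers these in detail.\n",
  315. "\n",
  316. "Converting between Tensors and NumPy arrays is easy and fast. For operations that Tensor does not support, you can convert to a NumPy array, do the work there, and then convert the result back to a Tensor."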
  317. ]
  318. },
  319. {
  320. "cell_type": "code",
  321. "execution_count": 10,
  322. "metadata": {},
  323. "outputs": [
  324. {
  325. "data": {
  326. "text/plain": [
  327. "\n",
  328. " 1\n",
  329. " 1\n",
  330. " 1\n",
  331. " 1\n",
  332. " 1\n",
  333. "[torch.FloatTensor of size 5]"
  334. ]
  335. },
  336. "execution_count": 10,
  337. "metadata": {},
  338. "output_type": "execute_result"
  339. }
  340. ],
  341. "source": [
  342. "a = t.ones(5) # create a Tensor filled with ones\n",
  343. "a"
  344. ]
  345. },
  346. {
  347. "cell_type": "code",
  348. "execution_count": 11,
  349. "metadata": {},
  350. "outputs": [
  351. {
  352. "data": {
  353. "text/plain": [
  354. "array([1., 1., 1., 1., 1.], dtype=float32)"
  355. ]
  356. },
  357. "execution_count": 11,
  358. "metadata": {},
  359. "output_type": "execute_result"
  360. }
  361. ],
  362. "source": [
  363. "b = a.numpy() # Tensor -> Numpy\n",
  364. "b"
  365. ]
  366. },
  367. {
  368. "cell_type": "code",
  369. "execution_count": 12,
  370. "metadata": {},
  371. "outputs": [
  372. {
  373. "name": "stdout",
  374. "output_type": "stream",
  375. "text": [
  376. "[1. 1. 1. 1. 1.]\n",
  377. "\n",
  378. " 1\n",
  379. " 1\n",
  380. " 1\n",
  381. " 1\n",
  382. " 1\n",
  383. "[torch.DoubleTensor of size 5]\n",
  384. "\n"
  385. ]
  386. }
  387. ],
  388. "source": [
  389. "import numpy as np\n",
  390. "a = np.ones(5)\n",
  391. "b = t.from_numpy(a) # Numpy->Tensor\n",
  392. "print(a)\n",
  393. "print(b) "
  394. ]
  395. },
  396. {
  397. "cell_type": "markdown",
  398. "metadata": {},
  399. "source": [
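  400. "A Tensor and the NumPy array it is converted from (or to) share memory, so the conversion is fast and costs almost nothing. It also means that if one of them changes, the other changes too."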
  401. ]
  402. },
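Memory sharing between two objects is not specific to PyTorch. As an analogy only, Python's standard library shows the same effect with `memoryview`, which exposes an existing buffer without copying it. Constructing a new `array` from that buffer, by contrast, makes an independent copy.

```python
from array import array

a = array("d", [1.0] * 5)   # plays the role of the NumPy array
b = memoryview(a)           # a second view over the same buffer (no copy)

for i in range(len(b)):
    b[i] += 1.0             # in-place update through the view, like b.add_(1)

print(a.tolist())           # a sees the change because the memory is shared

c = array("d", a)           # an explicit copy owns separate memory
c[0] = 100.0                # changing the copy leaves a untouched
```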
  403. {
  404. "cell_type": "code",
  405. "execution_count": 13,
  406. "metadata": {},
  407. "outputs": [
  408. {
  409. "name": "stdout",
  410. "output_type": "stream",
  411. "text": [
  412. "[2. 2. 2. 2. 2.]\n",
  413. "\n",
  414. " 2\n",
  415. " 2\n",
  416. " 2\n",
  417. " 2\n",
  418. " 2\n",
  419. "[torch.DoubleTensor of size 5]\n",
  420. "\n"
  421. ]
  422. }
  423. ],
  424. "source": [
  425. "b.add_(1) # functions ending in `_` modify the object in place\n",
  426. "print(a)\n",
  427. "print(b) # the Tensor and the NumPy array share memory"
  428. ]
  429. },
  430. {
  431. "cell_type": "markdown",
  432. "metadata": {},
  433. "source": [
  434. "A Tensor can be moved to the GPU with its `.cuda` method to benefit from GPU acceleration."
  435. ]
  436. },
  437. {
  438. "cell_type": "code",
  439. "execution_count": 14,
  440. "metadata": {},
  441. "outputs": [],
  442. "source": [
  443. "# On a machine without CUDA support, the body below is skipped\n",
  444. "if t.cuda.is_available():\n",
  445. "    x = x.cuda()\n",
  446. "    y = y.cuda()\n",
  447. "    x + y"
  448. ]
  449. },
  450. {
  451. "cell_type": "markdown",
  452. "metadata": {},
  453. "source": [
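  454. "You may find that the GPU computation here is not much faster. That is because x and y are small, the operation is simple, and copying data from host memory to GPU memory adds overhead. The GPU's advantage only shows with large amounts of data and heavy computation.\n",
  455. "\n",
  456. "### Autograd: Automatic Differentiation\n",
  457. "\n",
  458. "At their core, deep learning algorithms compute derivatives by backpropagation, and PyTorch's **`Autograd`** module implements this. Autograd provides derivatives automatically for all operations on Tensors, which spares you from deriving them by hand.\n",
  459. "\n",
  460. "`autograd.Variable` is the core class of Autograd. It is a thin wrapper around a Tensor and supports almost all Tensor operations. Once a Tensor is wrapped in a Variable, you can call the Variable's `.backward` method to run backpropagation and compute all gradients automatically. The structure of a Variable is shown in Figure 2-6. (Since PyTorch 0.4, Variable has been merged into Tensor, and `Variable(tensor)` simply returns a Tensor.)\n",
  461. "\n",
  462. "\n",
  463. "![Figure 2-6: Variable data structure](imgs/autograd_Variable.svg)\n",
  464. "\n",
  465. "\n",
  466. "A Variable has three main attributes:\n",
  467. "- `data`: the Tensor held by the Variable\n",
  468. "- `grad`: the gradient corresponding to `data`; `grad` is itself a Variable rather than a Tensor, and it has the same shape as `data`.\n",
  469. "- `grad_fn`: points to a `Function` object, which is used during backpropagation to compute the gradients of its inputs. The details are covered in the next chapter."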
  470. ]
  471. },
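To make `data`, `grad`, and `grad_fn` concrete, here is a minimal reverse-mode automatic differentiation sketch for scalars. `Scalar` and its methods are hypothetical illustrations written for this section. They are not a PyTorch API and do not show how PyTorch is implemented, and only `+` and `*` are supported. Each result remembers its inputs and a backward function (playing the role of `grad_fn`). `backward` then visits the graph in reverse topological order and accumulates into `grad`.

```python
class Scalar:
    """Hypothetical scalar node for reverse-mode autodiff (not a PyTorch API)."""

    def __init__(self, data, parents=()):
        self.data = data            # the value, like Variable.data
        self.grad = 0.0             # accumulated gradient, like Variable.grad
        self._parents = parents     # nodes this value was computed from
        self._backward_fn = None    # pushes self.grad to the parents, like grad_fn

    def __add__(self, other):
        out = Scalar(self.data + other.data, (self, other))

        def backward_fn():
            # d(a + b)/da = 1 and d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        out = Scalar(self.data * other.data, (self, other))

        def backward_fn():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward_fn = backward_fn
        return out

    def backward(self):
        # Order the graph so that every node comes after the nodes it depends on.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        # Intermediate gradients are recomputed on every pass; only leaves accumulate.
        for node in order:
            if node._parents:
                node.grad = 0.0
        self.grad += 1.0            # dz/dz = 1
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()


x = Scalar(2.0)
y = Scalar(3.0)
z = x * y + x                       # z = xy + x
z.backward()
print(z.data, x.grad, y.grad)       # dz/dx = y + 1, dz/dy = x
```

Calling `z.backward()` a second time doubles `x.grad` and `y.grad`. PyTorch shows the same accumulation behavior for leaf gradients.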
  472. {
  473. "cell_type": "code",
  474. "execution_count": 15,
  475. "metadata": {},
  476. "outputs": [],
  477. "source": [
  478. "from torch.autograd import Variable"
  479. ]
  480. },
  481. {
  482. "cell_type": "code",
  483. "execution_count": 16,
  484. "metadata": {
  485. "scrolled": true
  486. },
  487. "outputs": [
  488. {
  489. "data": {
  490. "text/plain": [
  491. "Variable containing:\n",
  492. " 1 1\n",
  493. " 1 1\n",
  494. "[torch.FloatTensor of size 2x2]"
  495. ]
  496. },
  497. "execution_count": 16,
  498. "metadata": {},
  499. "output_type": "execute_result"
  500. }
  501. ],
  502. "source": [
  503. "# Create a Variable from a Tensor\n",
  504. "x = Variable(t.ones(2, 2), requires_grad=True)\n",
  505. "x"
  506. ]
  507. },
  508. {
  509. "cell_type": "code",
  510. "execution_count": 17,
  511. "metadata": {
  512. "scrolled": true
  513. },
  514. "outputs": [
  515. {
  516. "data": {
  517. "text/plain": [
  518. "Variable containing:\n",
  519. " 4\n",
  520. "[torch.FloatTensor of size 1]"
  521. ]
  522. },
  523. "execution_count": 17,
  524. "metadata": {},
  525. "output_type": "execute_result"
  526. }
  527. ],
  528. "source": [
  529. "y = x.sum()\n",
  530. "y"
  531. ]
  532. },
  533. {
  534. "cell_type": "code",
  535. "execution_count": 18,
  536. "metadata": {},
  537. "outputs": [
  538. {
  539. "data": {
  540. "text/plain": [
  541. "<SumBackward0 at 0x7fc14824b860>"
  542. ]
  543. },
  544. "execution_count": 18,
  545. "metadata": {},
  546. "output_type": "execute_result"
  547. }
  548. ],
  549. "source": [
  550. "y.grad_fn"
  551. ]
  552. },
  553. {
  554. "cell_type": "code",
  555. "execution_count": 19,
  556. "metadata": {},
  557. "outputs": [],
  558. "source": [
  559. "y.backward() # backpropagation: compute the gradients"
  560. ]
  561. },
  562. {
  563. "cell_type": "code",
  564. "execution_count": 20,
  565. "metadata": {},
  566. "outputs": [
  567. {
  568. "data": {
  569. "text/plain": [
  570. "Variable containing:\n",
  571. " 1 1\n",
  572. " 1 1\n",
  573. "[torch.FloatTensor of size 2x2]"
  574. ]
  575. },
  576. "execution_count": 20,
  577. "metadata": {},
  578. "output_type": "execute_result"
  579. }
  580. ],
  581. "source": [
  582. "# y = x.sum() = (x[0][0] + x[0][1] + x[1][0] + x[1][1])\n",
  583. "# so the gradient with respect to each element is 1\n",
  584. "x.grad"
  585. ]
  586. },
  587. {
  588. "cell_type": "markdown",
  589. "metadata": {},
  590. "source": [
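  591. "Note: `grad` is accumulated during backpropagation. Each backward pass adds its gradients to the ones already stored, so the gradients must be zeroed before each backward pass."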
  592. ]
  593. },
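The accumulate-then-zero pattern can be traced by hand for `y = x.sum()`, whose gradient with respect to every element of `x` is 1. The sketch below uses a plain Python list as a hypothetical stand-in for `x.grad`. It reproduces the sequence 1, 2, 3, 0, 1 that the cells in this section show for `x.grad`.

```python
def backward_sum(grad):
    """Hypothetical backward pass for y = sum(x): adds dy/dx_i = 1 to every entry."""
    for i in range(len(grad)):
        grad[i] += 1.0          # accumulate, do not overwrite


def zero_(grad):
    """In-place reset, analogous to x.grad.data.zero_()."""
    for i in range(len(grad)):
        grad[i] = 0.0


grad = [0.0] * 4                # x is 2x2, flattened to 4 entries here
history = []
for _ in range(3):
    backward_sum(grad)
    history.append(grad[0])     # 1.0, then 2.0, then 3.0
zero_(grad)
history.append(grad[0])
backward_sum(grad)
history.append(grad[0])
print(history)
```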
  594. {
  595. "cell_type": "code",
  596. "execution_count": 21,
  597. "metadata": {},
  598. "outputs": [
  599. {
  600. "data": {
  601. "text/plain": [
  602. "Variable containing:\n",
  603. " 2 2\n",
  604. " 2 2\n",
  605. "[torch.FloatTensor of size 2x2]"
  606. ]
  607. },
  608. "execution_count": 21,
  609. "metadata": {},
  610. "output_type": "execute_result"
  611. }
  612. ],
  613. "source": [
  614. "y.backward()\n",
  615. "x.grad"
  616. ]
  617. },
  618. {
  619. "cell_type": "code",
  620. "execution_count": 22,
  621. "metadata": {
  622. "scrolled": true
  623. },
  624. "outputs": [
  625. {
  626. "data": {
  627. "text/plain": [
  628. "Variable containing:\n",
  629. " 3 3\n",
  630. " 3 3\n",
  631. "[torch.FloatTensor of size 2x2]"
  632. ]
  633. },
  634. "execution_count": 22,
  635. "metadata": {},
  636. "output_type": "execute_result"
  637. }
  638. ],
  639. "source": [
  640. "y.backward()\n",
  641. "x.grad"
  642. ]
  643. },
  644. {
  645. "cell_type": "code",
  646. "execution_count": 23,
  647. "metadata": {},
  648. "outputs": [
  649. {
  650. "data": {
  651. "text/plain": [
  652. "\n",
  653. " 0 0\n",
  654. " 0 0\n",
  655. "[torch.FloatTensor of size 2x2]"
  656. ]
  657. },
  658. "execution_count": 23,
  659. "metadata": {},
  660. "output_type": "execute_result"
  661. }
  662. ],
  663. "source": [
  664. "# functions ending in an underscore are in-place operations, like add_\n",
  665. "x.grad.data.zero_()"
  666. ]
  667. },
  668. {
  669. "cell_type": "code",
  670. "execution_count": 24,
  671. "metadata": {},
  672. "outputs": [
  673. {
  674. "data": {
  675. "text/plain": [
  676. "Variable containing:\n",
  677. " 1 1\n",
  678. " 1 1\n",
  679. "[torch.FloatTensor of size 2x2]"
  680. ]
  681. },
  682. "execution_count": 24,
  683. "metadata": {},
  684. "output_type": "execute_result"
  685. }
  686. ],
  687. "source": [
  688. "y.backward()\n",
  689. "x.grad"
  690. ]
  691. },
  692. {
  693. "cell_type": "markdown",
  694. "metadata": {},
  695. "source": [
  696. "Variable and Tensor have nearly identical interfaces, so in practice you can switch between them seamlessly."
  697. ]
  698. },
  699. {
  700. "cell_type": "code",
  701. "execution_count": 25,
  702. "metadata": {},
  703. "outputs": [
  704. {
  705. "name": "stdout",
  706. "output_type": "stream",
  707. "text": [
  708. "Variable containing:\n",
  709. " 0.5403 0.5403 0.5403 0.5403 0.5403\n",
  710. " 0.5403 0.5403 0.5403 0.5403 0.5403\n",
  711. " 0.5403 0.5403 0.5403 0.5403 0.5403\n",
  712. " 0.5403 0.5403 0.5403 0.5403 0.5403\n",
  713. "[torch.FloatTensor of size 4x5]\n",
  714. "\n"
  715. ]
  716. },
  717. {
  718. "data": {
  719. "text/plain": [
  720. "\n",
  721. " 0.5403 0.5403 0.5403 0.5403 0.5403\n",
  722. " 0.5403 0.5403 0.5403 0.5403 0.5403\n",
  723. " 0.5403 0.5403 0.5403 0.5403 0.5403\n",
  724. " 0.5403 0.5403 0.5403 0.5403 0.5403\n",
  725. "[torch.FloatTensor of size 4x5]"
  726. ]
  727. },
  728. "execution_count": 25,
  729. "metadata": {},
  730. "output_type": "execute_result"
  731. }
  732. ],
  733. "source": [
  734. "x = Variable(t.ones(4,5))\n",
  735. "y = t.cos(x)\n",
  736. "x_tensor_cos = t.cos(x.data)\n",
  737. "print(y)\n",
  738. "x_tensor_cos"
  739. ]
  740. },
  741. {
  742. "cell_type": "markdown",
  743. "metadata": {},
  744. "source": [
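  745. "### Neural Networks\n",
  746. "\n",
  747. "Autograd implements backpropagation, but in many cases writing deep learning code directly on top of it is still somewhat cumbersome. `torch.nn` is a modular interface designed specifically for neural networks. nn is built on top of Autograd and can be used to define and run neural networks. `nn.Module` is the most important class in nn. Think of it as a wrapper around a network that contains the definitions of the network's layers and a `forward` method; calling `forward(input)` returns the result of the forward pass. Below, we use one of the earliest convolutional neural networks, LeNet, as an example of how to implement a network with `nn.Module`. The structure of LeNet is shown in Figure 2-7.\n",
  748. "\n",
  749. "![Figure 2-7: LeNet network structure](imgs/nn_lenet.png)\n",
  750. "\n",
  751. "This is a basic feed-forward network: it takes an input, passes it through the layers one after another, and produces an output.\n",
  752. "\n",
  753. "#### Defining the Network\n",
  754. "\n",
  755. "To define a network, subclass `nn.Module`, implement its `forward` method, and put the layers that have learnable parameters in the constructor `__init__`. A layer without learnable parameters (such as ReLU) may be placed in the constructor or left out. The recommended approach is to leave it out and call the corresponding function from `nn.functional` inside `forward` instead."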
  756. ]
  757. },
  758. {
  759. "cell_type": "code",
  760. "execution_count": 26,
  761. "metadata": {},
  762. "outputs": [
  763. {
  764. "name": "stdout",
  765. "output_type": "stream",
  766. "text": [
  767. "Net(\n",
  768. " (conv1): Conv2d (1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
  769. " (conv2): Conv2d (6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
  770. " (fc1): Linear(in_features=400, out_features=120)\n",
  771. " (fc2): Linear(in_features=120, out_features=84)\n",
  772. " (fc3): Linear(in_features=84, out_features=10)\n",
  773. ")\n"
  774. ]
  775. }
  776. ],
  777. "source": [
  778. "import torch.nn as nn\n",
  779. "import torch.nn.functional as F\n",
  780. "\n",
  781. "class Net(nn.Module):\n",
  782. "    def __init__(self):\n",
  783. "        # A subclass of nn.Module must call the parent constructor in its own constructor;\n",
  784. "        # the line below is equivalent to nn.Module.__init__(self)\n",
  785. "        super(Net, self).__init__()\n",
  786. "\n",
  787. "        # Convolution: '1' means a single-channel input image, '6' output channels, '5' a 5x5 kernel\n",
  788. "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
  789. "        # Convolution: 6 input channels, 16 output channels, 5x5 kernel\n",
  790. "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
  791. "        # Affine (fully connected) layers, y = Wx + b\n",
  792. "        self.fc1 = nn.Linear(16*5*5, 120)\n",
  793. "        self.fc2 = nn.Linear(120, 84)\n",
  794. "        self.fc3 = nn.Linear(84, 10)\n",
  795. "\n",
  796. "    def forward(self, x):\n",
  797. "        # convolution -> activation -> pooling\n",
  798. "        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
  799. "        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
  800. "        # reshape; -1 means this dimension is inferred automatically\n",
  801. "        x = x.view(x.size()[0], -1)\n",
  802. "        x = F.relu(self.fc1(x))\n",
  803. "        x = F.relu(self.fc2(x))\n",
  804. "        x = self.fc3(x)\n",
  805. "        return x\n",
  806. "\n",
  807. "net = Net()\n",
  808. "print(net)"
  809. ]
  810. },
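The `16*5*5` input size of `fc1` follows from how each layer shrinks a 32x32 input, which is the size used with `net(input)` further on. The sketch below checks that arithmetic with the usual output-size formula for convolution and pooling, floor((size + 2*padding - kernel) / stride) + 1. `conv2d_out` and `max_pool2d_out` are hypothetical helper names. `max_pool2d_out` assumes the stride defaults to the kernel size, as it does in `F.max_pool2d`.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def max_pool2d_out(size, kernel, stride=None):
    """Output length of max pooling; like F.max_pool2d, stride defaults to the kernel size."""
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1


h = 32                      # assumed 32x32 input, matching t.randn(1, 1, 32, 32) later on
h = conv2d_out(h, 5)        # conv1: 32 -> 28
h = max_pool2d_out(h, 2)    # pool:  28 -> 14
h = conv2d_out(h, 5)        # conv2: 14 -> 10
h = max_pool2d_out(h, 2)    # pool:  10 -> 5
flat = 16 * h * h           # conv2 has 16 output channels
print(h, flat)
```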
  811. {
  812. "cell_type": "markdown",
  813. "metadata": {},
  814. "source": [
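  815. "As long as a subclass of nn.Module defines `forward`, the backward pass is implemented automatically (via `Autograd`). Inside `forward` you can use any function that Variables support, as well as ordinary Python constructs such as `if` statements, `for` loops, `print`, and logging, written exactly as in standard Python.\n",
  816. "\n",
  817. "The learnable parameters of a network are returned by `net.parameters()`; `net.named_parameters()` returns the learnable parameters together with their names."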
  818. ]
  819. },
  820. {
  821. "cell_type": "code",
  822. "execution_count": 27,
  823. "metadata": {},
  824. "outputs": [
  825. {
  826. "name": "stdout",
  827. "output_type": "stream",
  828. "text": [
  829. "10\n"
  830. ]
  831. }
  832. ],
  833. "source": [
  834. "params = list(net.parameters())\n",
  835. "print(len(params))"
  836. ]
  837. },
  838. {
  839. "cell_type": "code",
  840. "execution_count": 28,
  841. "metadata": {},
  842. "outputs": [
  843. {
  844. "name": "stdout",
  845. "output_type": "stream",
  846. "text": [
  847. "conv1.weight : torch.Size([6, 1, 5, 5])\n",
  848. "conv1.bias : torch.Size([6])\n",
  849. "conv2.weight : torch.Size([16, 6, 5, 5])\n",
  850. "conv2.bias : torch.Size([16])\n",
  851. "fc1.weight : torch.Size([120, 400])\n",
  852. "fc1.bias : torch.Size([120])\n",
  853. "fc2.weight : torch.Size([84, 120])\n",
  854. "fc2.bias : torch.Size([84])\n",
  855. "fc3.weight : torch.Size([10, 84])\n",
  856. "fc3.bias : torch.Size([10])\n"
  857. ]
  858. }
  859. ],
  860. "source": [
  861. "for name, parameters in net.named_parameters():\n",
  862. "    print(name, ':', parameters.size())"
  863. ]
  864. },
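The shapes printed above also determine how many learnable numbers LeNet has. This plain-Python sketch copies those shapes by hand and multiplies them out. In current PyTorch versions, the same total could be obtained by summing `p.numel()` over `net.parameters()`.

```python
# Parameter shapes copied from the named_parameters() output above.
shapes = {
    "conv1.weight": (6, 1, 5, 5), "conv1.bias": (6,),
    "conv2.weight": (16, 6, 5, 5), "conv2.bias": (16,),
    "fc1.weight": (120, 400), "fc1.bias": (120,),
    "fc2.weight": (84, 120), "fc2.bias": (84,),
    "fc3.weight": (10, 84), "fc3.bias": (10,),
}


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count


per_tensor = {name: numel(shape) for name, shape in shapes.items()}
total = sum(per_tensor.values())
print(len(shapes), total)
```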
  865. {
  866. "cell_type": "markdown",
  867. "metadata": {},
  868. "source": [
  869. "Both the input and the output of `forward` are Variables. Only Variables support automatic differentiation; plain Tensors do not, so input Tensors must be wrapped in a Variable."
  870. ]
  871. },
  872. {
  873. "cell_type": "code",
  874. "execution_count": 29,
  875. "metadata": {
  876. "scrolled": true
  877. },
  878. "outputs": [
  879. {
  880. "data": {
  881. "text/plain": [
  882. "torch.Size([1, 10])"
  883. ]
  884. },
  885. "execution_count": 29,
  886. "metadata": {},
  887. "output_type": "execute_result"
  888. }
  889. ],
  890. "source": [
  891. "input = Variable(t.randn(1, 1, 32, 32))\n",
  892. "out = net(input)\n",
  893. "out.size()"
  894. ]
  895. },
  896. {
  897. "cell_type": "code",
  898. "execution_count": 30,
  899. "metadata": {},
  900. "outputs": [],
  901. "source": [
  902. "net.zero_grad() # zero the gradients of all parameters\n",
  903. "out.backward(Variable(t.ones(1,10))) # backpropagation"
  904. ]
  905. },
  906. {
  907. "cell_type": "markdown",
  908. "metadata": {},
  909. "source": [
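  910. "Note that torch.nn only supports mini-batches. It does not accept a single sample on its own, so the input must always be a batch. To feed a single sample, use `input.unsqueeze(0)` to add a batch dimension of size 1. For example, the input to `nn.Conv2d` must be 4-D, of shape $nSamples \\times nChannels \\times Height \\times Width$; setting nSamples to 1 gives $1 \\times nChannels \\times Height \\times Width$."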
  911. ]
  912. },
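`unsqueeze(0)` only changes the shape: it inserts a new axis of size 1 at position 0. The hypothetical helper below computes the resulting shape in plain Python, for non-negative `dim` only. It shows how a single 1x32x32 image becomes the 4-D batch that `nn.Conv2d` expects.

```python
def unsqueeze_shape(shape, dim):
    """Hypothetical helper: the shape after inserting a size-1 axis at `dim` (dim >= 0)."""
    shape = tuple(shape)
    if not 0 <= dim <= len(shape):
        raise IndexError(f"dim {dim} out of range for a {len(shape)}-D shape")
    return shape[:dim] + (1,) + shape[dim:]


single_image = (1, 32, 32)                  # nChannels x Height x Width
batch = unsqueeze_shape(single_image, 0)    # nSamples x nChannels x Height x Width
print(batch, len(batch) == 4)
```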
  913. {
  914. "cell_type": "markdown",
  915. "metadata": {},
  916. "source": [
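  917. "#### Loss Functions\n",
  918. "\n",
  919. "nn implements most of the loss functions used in neural networks. For example, `nn.MSELoss` computes the mean squared error and `nn.CrossEntropyLoss` computes the cross-entropy loss."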
  920. ]
  921. },
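For reference, `nn.MSELoss` with its default averaging computes the mean of the squared differences over all elements. The sketch below evaluates that quantity and its gradient, 2 * (output - target) / n. `mse_loss` and `mse_grad` are hypothetical plain-Python helpers that work on flat lists; they are not the PyTorch functions.

```python
def mse_loss(output, target):
    """Mean of squared differences over all elements (the default 'mean' reduction)."""
    if len(output) != len(target):
        raise ValueError("output and target must have the same number of elements")
    return sum((o - y) ** 2 for o, y in zip(output, target)) / len(output)


def mse_grad(output, target):
    """Gradient of mse_loss with respect to each output element: 2 * (o - y) / n."""
    n = len(output)
    return [2.0 * (o - y) / n for o, y in zip(output, target)]


output = [1.0, 2.0, 3.0]
target = [1.0, 2.0, 5.0]
print(mse_loss(output, target), mse_grad(output, target))
```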
  922. {
  923. "cell_type": "code",
  924. "execution_count": 31,
  925. "metadata": {
  926. "scrolled": true
  927. },
  928. "outputs": [
  929. {
  930. "data": {
  931. "text/plain": [
  932. "Variable containing:\n",
  933. " 28.5536\n",
  934. "[torch.FloatTensor of size 1]"
  935. ]
  936. },
  937. "execution_count": 31,
  938. "metadata": {},
  939. "output_type": "execute_result"
  940. }
  941. ],
  942. "source": [
  943. "output = net(input)\n",
  944. "target = Variable(t.arange(0, 10).float().view(1, 10)) # float target with the same shape as output\n",
  945. "criterion = nn.MSELoss()\n",
  946. "loss = criterion(output, target)\n",
  947. "loss"
  948. ]
  949. },
  950. {
  951. "cell_type": "markdown",
  952. "metadata": {},
  953. "source": [
  954. "If you trace `loss` backward through its `grad_fn` attribute, its computation graph looks like this:\n",
  955. "\n",
  956. "```\n",
  957. "input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d \n",
  958. " -> view -> linear -> relu -> linear -> relu -> linear \n",
  959. " -> MSELoss\n",
  960. " -> loss\n",
  961. "```\n",
  962. "\n",
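  963. "When `loss.backward()` is called, Autograd traverses this dynamically built graph and differentiates it automatically, computing the derivatives with respect to the parameters (`Parameter` objects) in the graph."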
  964. ]
  965. },
  966. {
  967. "cell_type": "code",
  968. "execution_count": 32,
  969. "metadata": {},
  970. "outputs": [
  971. {
  972. "name": "stdout",
  973. "output_type": "stream",
  974. "text": [
  975. "conv1.bias gradient before backpropagation\n",
  976. "Variable containing:\n",
  977. " 0\n",
  978. " 0\n",
  979. " 0\n",
  980. " 0\n",
  981. " 0\n",
  982. " 0\n",
  983. "[torch.FloatTensor of size 6]\n",
  984. "\n",
  985. "conv1.bias gradient after backpropagation\n",
  986. "Variable containing:\n",
  987. "1.00000e-02 *\n",
  988. " -4.2109\n",
  989. " -2.7638\n",
  990. " -5.8431\n",
  991. " 1.3761\n",
  992. " -2.4141\n",
  993. " -1.2015\n",
  994. "[torch.FloatTensor of size 6]\n",
  995. "\n"
  996. ]
  997. }
  998. ],
  999. "source": [
  1000. "# Run .backward and compare grad before and after the call\n",
  1001. "net.zero_grad() # zero the gradients of all learnable parameters in net\n",
  1002. "print('conv1.bias gradient before backpropagation')\n",
  1003. "print(net.conv1.bias.grad)\n",
  1004. "loss.backward()\n",
  1005. "print('conv1.bias gradient after backpropagation')\n",
  1006. "print(net.conv1.bias.grad)"
  1007. ]
  1008. },
  1009. {
  1010. "cell_type": "markdown",
  1011. "metadata": {},
  1012. "source": [
  1013. "#### Optimizers"
  1014. ]
  1015. },
  1016. {
  1017. "cell_type": "markdown",
  1018. "metadata": {},
  1019. "source": [
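  1020. "After backpropagation has computed the gradients of all parameters, an optimization method is still needed to update the network's weights. For example, stochastic gradient descent (SGD) uses the following update rule:\n",
  1021. "```\n",
  1022. "weight = weight - learning_rate * gradient\n",
  1023. "```\n",
  1024. "\n",
  1025. "A manual implementation:\n",
  1026. "\n",
  1027. "```python\n",
  1028. "learning_rate = 0.01\n",
  1029. "for f in net.parameters():\n",
  1030. "    f.data.sub_(f.grad.data * learning_rate) # in-place subtraction\n",
  1031. "```\n",
  1032. "\n",
  1033. "`torch.optim` implements the vast majority of optimization methods used in deep learning, such as RMSProp, Adam, and SGD. These are more convenient to use, so in most cases there is no need to write the code above by hand."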
  1034. ]
  1035. },
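The update rule can be seen converging on a toy problem. In this plain-Python sketch, the objective f(w) = (w - 3)^2 and the helper `sgd_step` are hypothetical. Each step applies `weight = weight - learning_rate * gradient` using the analytic gradient 2 * (w - 3). With a learning rate of 0.1, the distance to the minimum shrinks by a factor of 0.8 per step.

```python
def sgd_step(weights, gradients, learning_rate):
    """One plain SGD update: weight = weight - learning_rate * gradient."""
    return [w - learning_rate * g for w, g in zip(weights, gradients)]


weights = [0.0]                              # start away from the minimum at w = 3
for step in range(100):
    gradients = [2.0 * (weights[0] - 3.0)]   # derivative of (w - 3)^2
    weights = sgd_step(weights, gradients, learning_rate=0.1)

print(weights[0])
```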
  1036. {
  1037. "cell_type": "code",
  1038. "execution_count": 33,
  1039. "metadata": {},
  1040. "outputs": [],
  1041. "source": [
  1042. "import torch.optim as optim\n",
  1043. "# Create an optimizer, specifying the parameters to tune and the learning rate\n",
  1044. "optimizer = optim.SGD(net.parameters(), lr=0.01)\n",
  1045. "\n",
  1046. "# In the training loop:\n",
  1047. "# first zero the gradients (same effect as net.zero_grad())\n",
  1048. "optimizer.zero_grad()\n",
  1049. "\n",
  1050. "# compute the loss\n",
  1051. "output = net(input)\n",
  1052. "loss = criterion(output, target)\n",
  1053. "\n",
  1054. "# backpropagation\n",
  1055. "loss.backward()\n",
  1056. "\n",
  1057. "# update the parameters\n",
  1058. "optimizer.step()"
  1059. ]
  1060. },
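The order of the four calls in the optimizer cell matters because gradients accumulate. The plain-Python sketch below follows the same zero_grad, forward, backward, step pattern. The model is y = w * x with one parameter, fitted to a single hypothetical sample (x = 2, y = 6), so the best w is 3. `Param` and `ToySGD` are hypothetical stand-ins that only imitate the call pattern of `optim.SGD`; they are not PyTorch classes.

```python
class Param:
    """Hypothetical learnable scalar with a value and an accumulated gradient."""

    def __init__(self, value):
        self.data = value
        self.grad = 0.0


class ToySGD:
    """Hypothetical optimizer imitating the zero_grad()/step() pattern of optim.SGD."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            p.data -= self.lr * p.grad


w = Param(0.0)
optimizer = ToySGD([w], lr=0.05)
x, y = 2.0, 6.0
for _ in range(200):
    optimizer.zero_grad()                   # 1. clear gradients left over from the last step
    prediction = w.data * x                 # 2. forward pass
    loss = (prediction - y) ** 2            #    squared error
    w.grad += 2.0 * (prediction - y) * x    # 3. backward: accumulate dloss/dw
    optimizer.step()                        # 4. update the parameter

print(w.data, loss)
```

If `optimizer.zero_grad()` is removed from the loop, `w.grad` keeps the gradients from all earlier iterations. Each update then reflects that running sum instead of the current loss.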
  1061. {
  1062. "cell_type": "markdown",
  1063. "metadata": {},
  1064. "source": [
  1065. "\n",
  1066. "\n",
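  1067. "#### Data Loading and Preprocessing\n",
  1068. "\n",
  1069. "Data loading and preprocessing in deep learning can be complex and tedious, but PyTorch provides tools that greatly simplify and speed up the data pipeline. For commonly used datasets, PyTorch also provides ready-made interfaces that can be called directly. These datasets live mainly in torchvision.\n",
  1070. "\n",
  1071. "`torchvision` implements loaders for common image datasets such as ImageNet, CIFAR10, and MNIST, as well as common data transforms. This makes data loading much easier and keeps the code reusable.\n",
  1072. "\n",
  1073. "\n",
  1074. "### Trying It Out: CIFAR-10 Classification\n",
  1075. "\n",
  1076. "Next, we classify the images of the CIFAR-10 dataset in the following steps:\n",
  1077. "\n",
  1078. "1. Load and preprocess the CIFAR-10 dataset with torchvision\n",
  1079. "2. Define the network\n",
  1080. "3. Define the loss function and the optimizer\n",
  1081. "4. Train the network and update its parameters\n",
  1082. "5. Test the network\n",
  1083. "\n",
  1084. "#### Loading and Preprocessing CIFAR-10\n",
  1085. "\n",
  1086. "CIFAR-10[^3] is a commonly used dataset of color images in 10 classes: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'. Each image is $3\\times32\\times32$, that is, a 3-channel color image with a resolution of $32\\times32$.\n",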
  1087. "\n",
  1088. "[^3]: http://www.cs.toronto.edu/~kriz/cifar.html"
  1089. ]
  1090. },
  1091. {
  1092. "cell_type": "code",
  1093. "execution_count": 34,
  1094. "metadata": {},
  1095. "outputs": [],
  1096. "source": [
  1097. "import torchvision as tv\n",
  1098. "import torchvision.transforms as transforms\n",
  1099. "from torchvision.transforms import ToPILImage\n",
  1100. "show = ToPILImage() # converts a Tensor to a PIL Image for easy visualization"
  1101. ]
  1102. },
  1103. {
  1104. "cell_type": "code",
  1105. "execution_count": 35,
  1106. "metadata": {},
  1107. "outputs": [
  1108. {
  1109. "name": "stdout",
  1110. "output_type": "stream",
  1111. "text": [
  1112. "Files already downloaded and verified\n",
  1113. "Files already downloaded and verified\n"
  1114. ]
  1115. }
  1116. ],
  1117. "source": [
  1118. "# On the first run, torchvision downloads the CIFAR-10 dataset automatically;\n",
  1119. "# it is about 163 MB, so this may take a while.\n",
  1120. "# If CIFAR-10 has already been downloaded, point to it with the root argument.\n",
  1121. "\n",
  1122. "# Define the data preprocessing\n",
  1123. "transform = transforms.Compose([\n",
  1124. " transforms.ToTensor(), # convert to a Tensor\n",
  1125. " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # normalize each channel\n",
  1126. " ])\n",
  1127. "\n",
  1128. "# Training set\n",
  1129. "trainset = tv.datasets.CIFAR10(\n",
  1130. " root='/home/cy/tmp/data/', \n",
  1131. " train=True, \n",
  1132. " download=True,\n",
  1133. " transform=transform)\n",
  1134. "\n",
  1135. "trainloader = t.utils.data.DataLoader(\n",
  1136. " trainset, \n",
  1137. " batch_size=4,\n",
  1138. " shuffle=True, \n",
  1139. " num_workers=2)\n",
  1140. "\n",
  1141. "# Test set\n",
  1142. "testset = tv.datasets.CIFAR10(\n",
  1143. " '/home/cy/tmp/data/',\n",
  1144. " train=False, \n",
  1145. " download=True, \n",
  1146. " transform=transform)\n",
  1147. "\n",
  1148. "testloader = t.utils.data.DataLoader(\n",
  1149. " testset,\n",
  1150. " batch_size=4, \n",
  1151. " shuffle=False,\n",
  1152. " num_workers=2)\n",
  1153. "\n",
  1154. "classes = ('plane', 'car', 'bird', 'cat',\n",
  1155. " 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
  1156. ]
  1157. },
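`transforms.ToTensor` converts an 8-bit image to floats in [0, 1] by dividing by 255. `transforms.Normalize(mean, std)` then computes (x - mean) / std for each channel, so with mean and std both 0.5, values in [0, 1] are mapped to [-1, 1]. The plain-Python sketch below checks that arithmetic on a hypothetical 3-channel image with three pixels per channel. `to_unit_range` and `normalize` are illustrative helpers, not torchvision functions.

```python
def to_unit_range(pixel):
    """Scale an 8-bit pixel value (0-255) to [0, 1], as ToTensor does for uint8 images."""
    return pixel / 255


def normalize(image, mean, std):
    """Apply (x - mean[c]) / std[c] to every value of channel c."""
    return [[(x - m) / s for x in channel] for channel, m, s in zip(image, mean, std)]


# Three channels, each holding the pixel values 0, 51 and 255.
image = [[to_unit_range(p) for p in (0, 51, 255)] for _ in range(3)]
normalized = normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(normalized)
```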
  1158. {
  1159. "cell_type": "markdown",
  1160. "metadata": {},
  1161. "source": [
  1162. "A Dataset object is a dataset that can be indexed; each item is a `(data, label)` pair."
  1163. ]
  1164. },
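A Dataset only needs to support `len()` and indexing that returns a `(data, label)` pair. `DataLoader` then groups consecutive items into batches, shuffling them first when `shuffle=True`. The sketch below imitates that protocol in plain Python, and `ToyDataset` and `batches` are hypothetical. `batches` is a sequential, single-process simplification of `DataLoader(shuffle=False)` that returns lists instead of stacked tensors. Like `DataLoader`'s default, it keeps the final, smaller batch.

```python
class ToyDataset:
    """Hypothetical map-style dataset: supports len() and integer indexing."""

    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index], self.labels[index]   # a (data, label) pair


def batches(dataset, batch_size):
    """Sequential stand-in for DataLoader(shuffle=False): yields (data_list, label_list)."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        items = [dataset[i] for i in range(start, stop)]
        data, labels = zip(*items)
        yield list(data), list(labels)


dataset = ToyDataset(list("abcdefghij"), list(range(10)))
print(dataset[3])                    # indexing returns (data, label)
for data, labels in batches(dataset, batch_size=4):
    print(data, labels)
```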
  1165. {
  1166. "cell_type": "code",
  1167. "execution_count": 36,
  1168. "metadata": {},
  1169. "outputs": [
  1170. {
  1171. "name": "stdout",
  1172. "output_type": "stream",
  1173. "text": [
  1174. "ship\n"
  1175. ]
  1176. },
  1177. {
  1178. "data": {
  1179. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAALVElEQVR4nO1cW3MVxxGe2d1zk46E\nhEASlpCEBaEo43K5UqmUKz8jpCo/MQ/Jj0j5JQllLGIw2NxsK4iLERJH17PXPEz316OZxdLoeb4X\nWruzvbPDfNOX6Tn64cuRUkopVde1OomqEbmsaqcZhIKbFTVJVVV5jRsWRGdRlaRc8d2Gbmtu3/AD\nTdM4Ql4m0tXavYs+NI1mVepjX9pUckUXlXMX7RMVcWbEwQpAppkCEACttMgsJiw1uOK1EYHvJWht\nvQWqUhY0s0FrJqbGYy5V00S6Bwjf5RpdSZKU/vaorRrpldau2oRfFGdWAOJgBSBLZMLy5OS/Ey1D\nCakRXqAZBDZJunEayRVrkguNEpe3iSwO3DkxWCCv9R0ed1J0hvsO9uFtYLTSNg3d5QhsjTMrAHGw\nApBZfHJnvj2zwYtaYTLz5GzQ5qQiS2w8U6tsHmm3WeL1wmK9e+VEJ+C7ijlkN5VvZQluJM5HqTbv\nF0Y6zqwAxMEKQBysAGRwWH0XO7GMqCxVWNr4AawFvivQ1O6CaLvdWpNLXXNEnYgn7boC0r0Ga6ul\nSvrAjgVPg6pkj58vQUOt8S2W68APIhiIHvx5EAcrAJk18wkWc+ygF3EsLnLuSey9qwNudC0+hJU5\nEp/Ddf0tEvqeDFwHiztYHBpXgzg0csv3VERV43EzevDnQRysAEg+S8JgvldbbEgSmBKka+mWzHyh\nFWJyNIJOC4hs4VJLPotttJDPzZbZy0cj7V2zqFJXuZh7lmp7zqAPSMYpV4g4HXGwAiCBdIt32tjW\nENlk7d/FE87f2mOHnSxqGtCQA1rtPqi8KxBOmjD0wPOQvTy45fm2JA1ASY3eeHyMOB1xsAKQiY/n\n08q68BteX4vQss9DQm3lfxEb4lErGhVb62mSXLXVP5ek1h0v2e0pP/HtkuJ2I9w4swIQBysAWcPT\nrvIqAM6I1CcdiMIzuYBHmGR4MEFWl1mQ8pNlUzhv0QolCOzxCotV3fD/OvIwKLNgVbVmd9oLRWtR\nLouD1q6vHGdWAOJgBUBI0VKMcDYknt8olUaNn8axd114ejeunRKC1O5mkrWDeyI6dL7DN82+oYQ1\nTKz8rV8A4Wd7Ik5HHKwAZFq5+/26JegTaEmQIg/jjnjLhEey0dr8hJXBbgjebO3XukkebM3aqkBk\nLQkc3PXKJrxMqd3hunYXkxQ+s/tVER9HHKwAZLLm+5nEVrRUyzGJavdOW+7UripgS4dSPNGIJCZt\nKaZI47CG1OonzC7KpPx9B+isGrwOHRUaViBpTZ5qmqb8FRFnRhysAMTBCoCUSYK24H9b4liWNvHF\nUZXfkrjlhYP5n1mrQ8aBcCUmPOFuUbMclf68q4QOpFYVJ+JorINwgLAdlXiOQtW6THsZ7RhInwdx\nsAJgbbJaRbzm38r2tgE/Nhan2fWMMZNB0IP9D9C0vf3OCEXB2StW1ZuYcl47nBxSr/hQTpL15TO4\n82VJroZfoCDuiFecUdvBAF/WnBqLZ3fOgzhYAcjatm1cwYb2JrNVe8R/s+1Dk4QLfp/98BCq7t69\na4TxeGyEPCc+Fg1Zyi++/NIIn9++bQTQcHK2B1U4QqekNApW3k1tV6UbFdi5A1hPmGZvszXiDIiD\nFYAsscp86N/W3BPguay1xmRmTR5/Gy6xXbh0ERdXlz+hFzEdtt+/N0JeEw0zVvr4+wdGuH79Bt86\n8Qb+CPSKbTrTFoF3guJcvlLZpcbMOkmXt5Q2RpyGOFgBaNndCd/fYQ1ylo6Jyf8X+TGZuV5X3njz\nxroRpqbIBf3mm3tG6A5njXBwdER9YtZfnL3g99M6g4cKRWTZvEIoT0r8PLhStXcsPs6sAMTBCoCY\nlMqLqqQw1vb6pJCBnTdVOQ+CAjjR
8fbtKyN8d/9b6Dw+PjbC5i+/GCHNiKTXrpOw9XLLCF999Sfu\nFPWqKqQeIvUOi9f8OR22ffiZCvldB8mQW7UOqPzDOHBqO86sAMTBCkBWeT+XIlV6lt2QX3GQ/U9q\nX1aF00YOjLEvOneZrJvqiDVMFQV3U3Nz1GyOXNa8yo2w9YpoOL+wyMq5JMi22rUwivopd9wtnFq5\nYeOJPSfv5EyTRGsYjjhYAcgQOllzklDVYiPQLFMwgkix8hlLsaL0f3BhetoIPzx5YoT5K8vQeXBw\nYISpGaLh/v6+EV5vEfuevPjJCH/7+z+M8Jc7fzVCryuZUuvnlOhKXoBE2hFg2cUVtew+fNESzWKt\nwzkQBysAcbACkB0XpXNJ9kUsM4/cccXubJmT/52mXW5BQ//zTz8b4e3bX42wf3hohPxEJRScD96w\n6Q2MsLh01QhXr103wmBIy193YpJ7YvWZ/Ymyoe6N+St6aYe/y1udJeQQVVhwk9oNSOLMCkAcrABk\n9+7/10jwtuEldKzcU6/DfnNN/vrkgPzvJCEaNglduXdvwwgbG/eNsLu3Z4SF1TXoXF4mN+Lp06dG\nmGNXfmVlxQjrN24aYW2Nkl9vft02wrgQHoJZ45w2ipBTyziQxg6TtfdLRCtKey1q4SZpcC9EfBxx\nsAKQvf+wa6TBgCxRxkmlzLKGmoPJNSbIzDTlgvsDqkJ49uJ/dGuGMr/r69eMsDMi13x6fhE6//Xv\n/xhhc3PTCCWnqO7c+bMRZmcptH786LER3rwmGua2OWQTdshmt9MhIwinPpX9Hg6k4dNbNMTeKtYl\nv4Y64nTEwQpABpNSHNAEnp2l3FOv30W7hUt0scPcHI12jbC3T/Gw4jNqv7tJlmtpiUi3u0c03DnM\nofOPf/i9Eb74/DNqtks6+/zqmRnyRY8OaJvnYH/EfWeiWdVRiIgrzohhdwe0bbyAv2yj4W9UL0Wc\njjhYAcgSnszb22Rl9njCPzvaQbseVwpcmiVepFLaQCPe53I9mNGq5NxQ2bJBsrJ8hVRxVT4MMRzj\nfEz28ZPFy0bY3KRUV29yILqYUKMRkTTPmYZcnIsMV8qVvzCCRdFCQ+tcbsxnhSMOVgCyhmfdxUs0\nz1EOW42lWLbhY9mDASVzUQePCp5KUZuDQ7KPBVfyjXMOPGsxYTnzGDSE3cmYKSknWLocga6vXnUe\nV0qV7HlWnDhqeM8JDNOpe1K8kjNDkjgqeenAmlDHFM05EAcrABkog1mHdAccQqWULjkvyns5OVfN\n9jPKzHSEO8je8OOY+aX1Yww1NjvlPdyM+ctv2d+jDmRMzP60dC/nOG5+boaUF2TT9yoUPXT4HbKB\nRVcSoXQxphdVXAQMWxlnVgDiYAUgO2YaznEyBDwBv5RSyyuU1ex1aTI/evS9EV5uvTHCYEhbCUh4\ndlLyG3WXnUxl5yS50LxyDWuGA6mcGtIDEsbwNot9UcQBYMo1VDOTE0Y4PqRDL3VO2VosF3ND3h9Z\nmIcq1Dq8eU0PVtXgRHcjzoI4WAGIgxWAbOEy0fWIyzQS9iFu3/4M7VaWKTO1NyLmT0xQNvnwmIz0\n0xfPjfDkx2eknVUhRzbJJ+GU5a9P8PrS4aieM2MSig/6tHCguPKoOIYq/KbTaIeC//l5itKHvJIO\np+gtV68sGGHpCn17t2M5NLwX++7dB/5k+sA4swIQBysAGfI+MMljrtPf2JDK4offkYBULJJWq2tr\nRrh165YRUGb14AEduHn+nBi6s7MLnb0eu/68EwNh0KFb3Q7Fz91u12lTWbWNSUqdQeHFCgf8K4ur\nRri6St7PBU6E9bFzbKnCNm2vR+m50ZAS7nFmBSAOVgAyJGum+QDN+JBouPVqE+0O93aNAIp1mBf/\n/PprI3Q9WoE7S0tLRsjzH6ETaazhkExkxldqjl1hm0bcAcTkCJ6VUkfHtIZ8yiVKO2wWYaw7XVI+\n
9SkRM0mQ/hYavt+mF/X7ZD3n5siUx5kVgDhYAfg/pQ4eZ65sAxcAAAAASUVORK5CYII=\n",
  1180. "text/plain": [
  1181. "<PIL.Image.Image image mode=RGB size=100x100 at 0x7FC0FF0ECA20>"
  1182. ]
  1183. },
  1184. "execution_count": 36,
  1185. "metadata": {},
  1186. "output_type": "execute_result"
  1187. }
  1188. ],
  1189. "source": [
  1190. "(data, label) = trainset[100]\n",
  1191. "print(classes[label])\n",
  1192. "\n",
  1193. "# (data + 1) / 2 reverses the normalization so the image can be displayed\n",
  1194. "show((data + 1) / 2).resize((100, 100))"
  1195. ]
  1196. },
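Why does `(data + 1) / 2` undo the preprocessing? The sketch below assumes the training images were normalized per channel with mean 0.5 and std 0.5, as in the standard CIFAR-10 tutorial. Under that assumption, `Normalize` maps a pixel value `x` in [0, 1] to `y = (x - 0.5) / 0.5 = 2x - 1` in [-1, 1], and `(y + 1) / 2` is its exact inverse. The helper names `normalize` and `denormalize` are hypothetical and exist only for illustration.

```python
# Assumption: images were normalized per channel with mean=0.5, std=0.5,
# i.e. y = (x - 0.5) / 0.5, mapping pixel values from [0, 1] to [-1, 1].
def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std

def denormalize(y):
    # (y + 1) / 2 is the inverse of normalize() only when mean=0.5 and std=0.5
    return (y + 1) / 2

for x in (0.0, 0.25, 0.5, 1.0):
    y = normalize(x)
    print(x, '->', y, '->', denormalize(y))
```

If the dataset were normalized with other statistics (for example per-channel means computed from the data), the inverse would be `y * std + mean` instead.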
  1197. {
  1198. "cell_type": "markdown",
  1199. "metadata": {},
  1200. "source": [
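  1201. "`DataLoader` is an iterable object: it combines the individual samples returned by the dataset into batches, and it also supports multi-process data loading and data shuffling. Once the program has gone through every sample in the dataset once, one full iteration over the `DataLoader` is complete as well."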
  1202. ]
  1203. },
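To make the batching and shuffling behaviour concrete, here is a toy, pure-Python sketch of what one pass over a `DataLoader` does conceptually. `simple_loader` is a hypothetical helper, not PyTorch's implementation: the real `DataLoader` also collates samples into tensors and can load data in worker processes (`num_workers`).

```python
import random

def simple_loader(dataset, batch_size, shuffle=True, seed=None):
    """Hypothetical sketch of DataLoader-style iteration (not PyTorch's code)."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Shuffle the sample order; a fixed seed makes the order reproducible
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        # "Collate": turn a list of (data, label) pairs into (datas, labels)
        datas, labels = zip(*batch)
        yield list(datas), list(labels)

# Toy dataset of (data, label) pairs standing in for CIFAR-10 samples
toy = [(i * 10, i % 3) for i in range(10)]
batches = list(simple_loader(toy, batch_size=4, seed=0))
print(len(batches))                    # 10 samples, batch_size 4 -> 3 batches
print([len(d) for d, _ in batches])    # the last batch is smaller: 4, 4, 2
```

Exhausting the generator corresponds to one epoch: every sample appears in exactly one batch.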
  1204. {
  1205. "cell_type": "code",
  1206. "execution_count": 37,
  1207. "metadata": {},
  1208. "outputs": [
  1209. {
  1210. "name": "stdout",
  1211. "output_type": "stream",
  1212. "text": [
  1213. " bird cat ship ship\n"
  1214. ]
  1215. },
  1216. {
  1217. "data": {
  1218. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAABkCAIAAAAnqfEgAAA3wElEQVR4nO19SY9kWZrVeYPNg5v5\nGO4xR2RkZEZmVg6VXdXVTdPdBU0xCegNICHYNGqEQAIk/gILVqxYAUIgIbFAiEE0aprqboruGjor\nszIrx8iMeXL38MnMbbY3sXjnfObhbl4Su3a4ZxFx3ewN99137b3v3O/7zgc4ODg4ODg4ODg4ODg4\nODg4ODg4ODg4ODg4ODj8/wXv5Ed//a//xbxx/aVreWNlZRlAFE3yPwfDvvYO8/+XFpfyRrlU4ldp\nlv/v6xRhGPCrjF+Vy1UAhUIx/zNJMm2gXnl+/v94MsobhQLPmKXcZHtrP2/0euzeNIoBZB63SJNY\np03yxnDE/gchj7+7v8Pe+vzk8qVLeeP8ufW88Zt/9x/iCP7Br76ZN0pFXlcX7FtnwOOfa3M0nnf4\nSd36H/Kr3ngCINDYdONK3lhaOZ834njKvmVsYDrM/w+jQd6oVqt5oz/lNvVmA0AS8POf3HnGvvW5\nbxJxkM83eNTlAgcwXFjMG6OMt2bc7eaN1UYJwDjmnfLK9bxx0OVhv/veT/Ei/uTf/nvsUr2cNwqh\npz74AKKIWx4c8ixjXXKr0cwbk0Ne6WTEr6KMdwolNtZXa3kjiFMAva52SXjfJxlnQq3GnjQqbGQx\nO9EscPwvrl/IG9sHe3ljihTAJOKVBj53iSMeX+dBptkZlnRfNV3jKVu/86/+NY7gYTLmvvqkoF3K\n+pWN9JuybTxNV9/jkHo6U9GLAM1IoKCfUqgfl6cBTLNADTuwjq/D2heplwGY6oPIY2Oqn1ukr+KU\n+06jqY7GM9r4+FD//QKAzOcuqRq3fA2g4MPBwcHhjCA8+VGvx3dI56CXN5IkxRFTJUn5Kkn1tLZd\nGlW+5SoytYphgY0iX9dpyufreBwB0MGQ2CulxPfeYMTXzpOnT9ldGSYrK6t5I5MFVyzzWvKHcqIX\n9yRitwuyp6LZ64io1dht6KtCgd0ejUaYh8IazZCgwi6Vh7qQAU2VHRl9iUySScCL9NIXzMnQ5+lK\nHofjcEBzIwz4FVK+qUohrzQEhzSSSVLUKzco1AEkOuzlJdpR6xxaFJu8ZF8vwN4eTYmGTmS2z0j3\nKCyWAYxkoibTTt6ol3XcExj2ubEZVgdjDunB3gjAsM8bFBTZ2/6QQznlTUZL96U10eTx2diTpbO9\nc5A3looNAG/deJ19a7Tzxh99+pO80emx25jw1BOZ8OvXecrplId9+nQzb9QadQCTWH1L2DiUdRlN\nRSY01YOyLBENwmJx/kBFxgP0SeDRuPBlbpT84/aOZ1PY/s/MkgoAhDJqAv1U/ROkyp/tOzuuIFN6\n9pGH2a9ktq/ZSoF2SbXRrA+B6EWWHdsLmY8jtnB03NQ70ttTv3FwcHD4Ywb3wHJwcDgzmEMJA/GI\ncpkLkJVyFUcW3Yuie9V6U7vQfC2KrcyYoBqx1jXNnszJ4Ps//nH+55Mn5H1vfu2tvHH+8sW8sbS6\nrhPx1KVKXf2lLR2GWmWfxACSCXlNoUhG0O+T4Za1TP50k2d8ukWz/9zGRt44OCC/aNTsRC+gMesS\nL3kS8/iHsmjLMU3cQsqhK/kcUjOmh5kPYCwCUgzJJiYpr+twwn0D0eGFFl0czYQkbtBl4/mApy4k\n+wAi8e1MfpKXL57LG+vnWnmjs7ObNx7G8qXoLRZUee9WKvyoVkgB7CYc22nMTlZLRZyCKkS3xyKw\nQ+5Vz1IA51qaRWJS08ZC3vDEhtvq0uvrdATVlrjNH21+lTc25Tk511wD0C4v6HScGBvNtbxRSjkB\nzDMTVLnxXpeD8NEXt/NGf8DFe39/H0ClyLuwKtdEpc7G5hY7MD3kLfv5d9/IG0taKlmt8WL/8L/8\nVxyB0SVbb0m0cpKok8b3j+AEixPb
CtIAQKBpFuhXNyOA9jPMjMQZATx+VPskX933jn8PXz0JbKlD\nJDQQSy36x582RkKz1MMROjmSf+wknIXl4OBwZuAeWA4ODmcGcyjhUpsepY3VlbyxsrICYCJuIlsV\n1SrpUr3Ohnk0zFgdjenpO9yj2+uw3+HGXgigVKON/cqrt/LGSzdezhvN9nLeaCcK0NC5BwPyr1is\npCAfU6lcBtDd51mePiPv293dzhsXLzLEqd8/5FfbW+ztcKCDMIIpTeZ7LA63eTmT8VRbih/Jw2he\nTmRkfD5ICS2+ZBjlHliO7cUNcpNuQnfS4YRdqlW1ryLaegk/GaQkUw0xzjQeAxj3OEqDIa90vMwt\n0iE7uSp3YRCRrTw8YLcVIYelBd7f3b0BgLGoLnx28uLCnImU4zvfZFhfXyQ9kv+r5qcAvCk9dJHC\nczo99jYo8rAteSEX9YqtKJbq3TpHrDdSeFriAUhHPNrhIQdhrcrFgWtrnACdPon/OOIl74pcX1oh\n5e8UeKNzDrW+yDm5LM9yqpu7UWdvA5G4b968zN6WOJ12n+9jHopah8nEmwKNRkHcKjSCb162wHxz\n5lLUJz5wJFwxk1faKKfBDuLh+FS3vYxA5kTPfuUFkcJA90Uk3iLPZryvdMJBOXtYZAAQKBzxZ5hR\nzsJycHA4M3APLAcHhzODOZb8sihhuaCY/SgCUPDNkUFzbzKgEZ6JI3iBeZgsMJ8Gs3msIo/UKd/0\n5Vdf0elISUolEpDpJNIZ2Zgq5rC785xHm9KYn4g1pFEGoCciMBizk1U5aELRvYliVd95h3k2a2t0\nJE2mYpqnhETuH5C2mNkdikGV5TLLYo7GZMShq8r/VVO46W7/AEChQH6x1OC1l5UWE3hyICoytqSx\nHQfcpthmHslrNzmY+0/uAth9epdbKmPqyRZ9giNFOS4v0ee4ucexLZd533e6/GR/SDY3HsY4kuRU\nlU/Is6jfE7h14x2eUdsY6wm9FGKvAIx8xzq+RQUHltUxJXeOJryt64pHXdNx/Dz+sMDPR5GlNPGe\nGvVZatMpPIzoHLxwjhsH4laRJnDuDg59xYJOxFsVDGwM2pfnemefo21JbH6xinnw9JuyJBtf1E13\nG0WRq0SLD5GuaE9O7V35fJuVBoB1eb0LmpP2ezSWV/ItN46nNF4ZysWfqXt5MpD96Rs59Y97DuMT\n7s7CCcpp8ai5I9GsJ0viOQlnYTk4OJwZzLGwEr0J7WFcjGIAxQJf07ZKN9SCem/EheFQmQeHPb6y\nLH2nqKCepWUueVbKNQBeotCtQOvBidIylRGapjxaX/mxiXIpAmW9pLHevXGGIymp1u2ylmkzPaZr\nDdqSr996lZ9UuY0FRtWaC5iHck1BaloLr+poo4eM6ipN+RLeWOBBGtYHrbrXa0UAshIw0MpxqcSP\nLjc5bluyNy1bqKzXzTd++Tt542u/8Kfyxqff/58Anj65z04qyzcpt/LG2jnakr6inxQdhaLee1X5\nMZ5bjpEX4sgoLda47yg61cIazyJ/NNmUFhvDB+Ap39je0YFe9YmFCOmrtKAV4pB9sKSWsMjx99ME\nwDjUy3+2szoSa07CrAxZZ0YRbG06M+MiBBDPMoQVw2WxTmYvJLxTmfJ4QsUn4kQ2Lz8WR/G18FyV\ndTPap/X05d0v8sa9ezScb99msNjWFr1GHW3cqjUAvPYGA8Fee5MNua+wqV2ePKZXam+XDofVVeYn\n3XqNfrCrV6/yq7U1AM1m81j/PcuplnFkppCFaPknkquPHgBHHkbhPFGGY4d1cHBw+OMO98BycHA4\nM5hDCfsK3jk8JKMphEUcWSELtLpooU/70kvqbHEt/OkzNkZjWuyL0ldYv8AV4oVGAcD9e5/lf46V\nPlIpk8QtrnGXghhNo8zF6aZScyIt/E9FpnIJA9O6aihe5splSlzt7TKF4uo5Hr9WVPySLOaCp5XI\n
yXyys7LKtepIBC3pMMQm6tEsj5Ves6ildKMG2x2S6EF/CCCStexN2e12iZ8slKQ0NCFZCBKOdmGB\nlnlrlWur979insoHP/gegCf7HJObr97IG5U+9x11eYNuXmOs0F6Lg7zd7eSN8wvSyiqwESMBkE7I\nXzojrRAHp8ZhzXJAhMwkk+IYgGcZIVrrnWWl2DKwQs+MkWUzVQAe1vhj5gUAFAI3y0qxNd80sIAm\nNRJOOUsyyySbkWqbXPBgRnm0Vp3pE2+2rzw2ch+ZsJR3YjRyhPI4bG9Suay3x1n0k+//IG88fUhK\naEoksehwUYddX2GYWJ7H80xrAs/3uEwxkrDH1jYngM1wu7SDPWVr3b+XN9bWGZW2fukCgGvXmCD1\nxmuv5Y1Wq6VLOcn7fgYTJPIzW8SWd/oezsJycHA4M3APLAcHhzODOZb8zmGH3+2TTOWyriXxJk9x\nNOZGnEiZt6Ks9IvifffvP8wb2xJj2HpKgzMPcRrIw7h/QKbWbjFWZV0yvmOl13/nO9/OG6+/8lLe\nuHf707yx+ZDW70J7AcDKOr1g4wFN68khnSB7W+RNPkgECtJISPQE9xT9kp3i/nrwyefcckrKUJGD\nqSUnGhQntd8hNWuHvLRITszFehnA7gHpcKxEn/u7HW4gbYmlJY7/UI7RQLFaH/+3f8lLG7C7d58+\nAXD5Jr2frVWKNHx8l0Z+7EnpQbIHO6LUOzFvYrvKvRLpz12tpQC+2mPn7+2TBReNU58O32KpPKN1\ns38BeMaL5XcryMObzpTk5HIykjXjhMqRQgggkgyhJ95UlidrKmGP2CSGp6Lb6kMs56JNicB/oZOB\nbrcJ1M1y1oRY+5p/zDvF//Vv/uW/yhuPHj/KG/vSfliuk48f9phb9kBMzeKhCjNVBovVSnDE6Vms\naiS15cUr0r1YpNqE0dVE6W6hXJZjpU99+OlHAB4/e5L/eX6DVLHVah+70iOu3eN+1Rls6F4UgZi5\naE/AWVgODg5nBu6B5eDgcGYwhxJuy9uFgpJd/CKAsvxWkaIEjS4V5MRRdj2W2yQayy1Ka2cxt95S\ndsidR1sA6sqYWTvHgNKR3IWPH9KRUWurKo+E1pbO0+VXFAldlqcsSlMAvgI7u7Jmv7xD8vj8AS3q\n19/4E+yb/J5mY0dSgj9N090kqwu65kTcLavQebq4RkaGXZLQ4SEvLVL1Gs+LATSqylZRQkmrpOtS\ncOa+/KF79OKitchtqsr1r9W5+8aF6wBWVsiLu106fP2Yl+Or0M5HT0iZu0PuW1GiT3cofqE8p3Gh\nAmA8Oe6c6o8mOAWzzAzLxRcrKVfKACaickYnPMk5GEkwkmJHKyne1Zc7eDp+IekkVPRmYULyWwZP\nlIWKeww4nYpg/xueFO4To4ScSEFQBODJ82vSAia+HutCZq5AbZMqLjRO5rOd7/3u/8wblQpPV6vw\nBiXKf+poTtbP07Ebj7mcsljSuoQyeu4/OwAQiiFeWCFl++inH+eNzad0R9549+e4r4b09ief5I1r\n5zkT3vw55le9d/szACt7/D0Wde1eZgGxxzlvasV+suOuUiPIx92Cp9tRzsJycHA4M3APLAcHhzOD\nOZRwJJfWpKoMsnYGYAyZ3KYWpr8twG8yYbxeXZ4yk3LP5Ky5eIGxbWHRB/D+R2RqfnpOn3Nfc238\n2b/85/PGK6+TYFrOeb1BlcHWLdKfXNphV+J8W8/JK3vSFG+uXs8bhSYDR7sDlbGUwT/o09geTywC\n8QUsXqHevEUwjsWUl9boOrkl7YQ731cKpMQLwzI3HsYRgJpCcJfqZHmxT0YQi9Kef5m6hu++8a28\ncfk6KaevkMWRUix73S6O0BbzpH3Zokvox9//Xe6rml1r4p6xxwlg5VEl3Qi/PwbwUA7FWFS0VT9V\n070QHp9jJgaQeBmOsLxYN9UCkmMxqaGIuSk4wj/uc0xnBDICUF
b5r95TTrB79z7IG4uv/nLeWL3O\nkfRjTpLuEzp/y0ukXeESvWleWAKQyWFtevYWWVo4qbkuxQILji0W50fYPrp3h33XLm8oDfDB/S/Z\nN/nTmxIR7HVI5wcpOfvF1RavqFgAEFjtWAVjt6W1mVU40955+628URcb9fVTfekKUwh/9du/kjf2\npz0ASY9j27KkQtG9k071n5FBeJrP9PRUQmdhOTg4nB3Med6//fpbeSOVJRX6RQAHqrde1nLgzIyS\nStG+JIZ7StYZqm57Vavga+u0iS5cWseRrIvbdxncsbNHo+bGO1wOXL1Em2jvgMffndIzUKny3b7Q\nbuWNUpABaDf5580bfFMNnjzmvk8YGralWKdiwWzGF7JGcFQl9kW011VKXq/Ehp7+N2UKXb3BhJjG\nCo3HSCVYnktOazIeAVhp8AXY0JA+V9q9LzGDV37ulzgaJnKkMBZb1xxLHWw8meKIlNJEG2ycp2FY\nW2aXvv87/5lXGisTa6zi7DKsWhrbVjkGUKtJTUH3rlU71cKypJ2JDhfMopM8AJCgsG+aFXpPm7SB\nrcKXZHylCrOaKgbQJKoLhSKOrAf3dpmVcucTWlhXK1xLbrd5pzYfs3RT7ye/nzfqMr42vskbnVeE\nMoM6kTEeKoknPmTkVE8yWBeuMHPFfDh92OL0C1iUuPVEMiEDy/SSGng2ZGNP7pflZZpai9KJLlR4\nI66utAC05P8paTHeN9U2lbxak0erLrmRxWUyj7a4wqIm8OWVDQBL1+QEk7ScDcucxJyfJdLwfw1n\nYTk4OJwZuAeWg4PDmcEcSvjWLdqxdx+QOn119w6O1Jj0FZ+1UFcOzTmapuMxtwnFRmpaRbZk+seb\nzDAYRwmAojIPvvbuL+aNzohm88pV5t9ECS3b9977KG/89APa8K+8ym1+6du/wL2aFQDpiM/iapEd\nuHyZWx5sbqm3yshX0E1fYgbmPWjP0tBfQHuJgS2phNYWFmgnW1Wecxev5I1LN2h+TxSvZFlNeYp8\nWULMFsETzVagpYOs+rWJKtpmWlItS3C5JAXeRqUEoCqVi33pKVZLpB6//GssZtNT2s29P2Bpz5Iy\niiywqCJ5g/zGWLCYLQ6sSffiJCZSYZyIxAW2bh4WAHgigLGEBEySwQuOhzjF8mzM2KPunUnv5b2M\nFBnUXCSvWVghE1wUAY+70r0b8L6XUjkc9rlAsSzdhyT2AWRSl4xNNTjq5I2dL3+UN4Y7XHxoabQz\nyQ1OF9uYh9deITm12lTmMLl8nt6kglbufdXgieRMsYULU0v0CyEAqILvc0mK9yW62ZVaw+/99nfz\nRlGLNo/0A7n3lI3bdykZ6PUOAfzVX/9r7NLMz2DqiOL7s/Crn7Xq/n8LZ2E5ODicGbgHloODw5nB\nHEo4VTmQc4syaKMlAFlCa/npc3pDCiAJ2tlRHoNHg7NWt5KefCaOFOtk1VH6owjAMCVlg1TMNy59\nPW8c9mlIf/YJQ2mePWVWDeQ7g09GOeyTQ3ULAYAoUqKM2MRz+VYebtMLhrEFj0h2IiY/ai6ouolV\nhnwRqZyDfsDAlsVlOu+qDcY6FWUwF1Ll91iFFVOkixMAaUTvYSwj35tVxDRJbIuDUwSQOIKxoUIg\nJ1oWACgoW6hWlm+oyE3L8ha9+tY388aTz8mypwd0q7Wq3H0oOflxtwegnPEgTQkTtlR65yQ2Vcel\nLAfoQlW6gFmGI5IAUAFa86mZY3FW/vOEEp6nOK/EaON0DGXSAEhL5OP7WgE42KVwSBDSuVYqsf+j\nTKVhNfnTPr1143KAI5UHChLwm9VAVXme6aCTN3afc7oGC5wbi/LwHoPp9tlaRN+qIsxU+nS7Fah4\nMOD4PNnmIKeZTZsUwHmFZS0tsNtFlSAYT7VEc/8Br0iLMxXpN1SaJLD3H3NK1Ad9AD/+Icnv5St0\n369euawOWO1W/m9Bmtks9e
p4XWR9MlPuwClwFpaDg8OZgXtgOTg4nBnMoYRf3KFz0HhEe2kZwAX5\nRw761B7YOSC3erxJkpjMfCg0jycRo90uXWTs2ZVLjPff25sA2O/JUCyRFu112IGCKoOhSIt3ZYVd\neknOwbKSQnYO6dGoLWwAmEoc3STq9yQQmGTGHch0Hj1jYkRYpI260GbMZ7e7g3nITCBBahP9Hqn0\nvQ6HZbHObcotbqMEhln1s5B1qzRuMulNJb0oNYKCGolCIi1aL5BP1iTJ09QH4KuTxfB44SyLxbxw\ngT7Nlav0Dn+1Tfu/IC+wBa9G4wGAos8BbNZVoPV0LCgZyCJsA/lV02kEIFLkp0niBeKGM5E8uUot\nFnagZJ1YHKSsvdIkwZHAzmcH9ADuiUC9qejKRpvXVa/RE/dM4o7BiNNmssPRSDaWARQU3VrQDSoH\nHISKqgfsqeLcYESHeLmtomrJfBOhVDGmxn139jiLqvpqJohgwpkj5b3pm1hkKtcfNHJtdRjOX77C\nS5ZT+7WXpcu+xl/o5j5H7OLNm3mjKO/jJ9/9HzhSBO+DDxiL+40mFwfqbXnP4+ME0CihNWZFaoMA\nR+ik7wqpOjg4/D+AORbWll6wyy0uwj15dB/AM+lYWQRKxeRs9Iif6pUexTIQyoyC6Y/5DP7sNpNO\nxlEJwPL6lfzPap2venha19QS8rUbzHm+cZPP70Av+bJyNRYXuYCaC+wOJaq1ucWAmsMe+9/WeupA\ncUxBoLKpsiliBbB0DqQ+9SJ8T9aThKUm8ioUrZK4cimwIMfCLF5Gakp+AMCXFnBFYs2hGpYqnNly\n5okyM7axKUnlL9iZYJnZLGZhSQJ4VXVWvvb2N/JGf5cLwKkZaCp03vETAKlMyIW2VmdrWkc/gaYs\nULOwfFPIKhRxROzYUJQ9lVpFVe0yVAxaT7ZDWFTBm8yOPwXw4MGD/M///f3f4+mUSr0tRbHF19jt\nkkR5D8Ycnwc/+qO8cc2n8trylW8AKMLOosQlRaX1dHM3lZcWjrhgX9BSdwrJZ7+Iuw8YumXhVyWl\nzvjyKvhyoeh/TDJeSDmWAvJUBiBCAMWSCgNP+Ht5usVf97delny28nsePeYvxUo6bSoT7tpVDsLB\nZATguWqvmhtoTyn07/78z+eNK0qn80/6SdSwGz1NYhwtkuSfGrrlLCwHB4czA/fAcnBwODOYQwlX\nlZpfU2DUQqkN4MZVJvpDC8OJwvxjT3E6MmjvbtFgfvCMp6jVuKRncqhLtWUAzRatzSgVCdIC4aWL\nXKr8+ps0L9uLCkFKeeqqalUWtRraP+wDuPeARXSePmZjX0Ex3rDDM/Z4EFMj2LioNI4GDfLBoQR8\nX0Sm5e2i2JbSMFDx2cmSmFRiFq4KDh1LVwi0gusbE7RPTLLACsboE6OEszqmJkvkZQBipa2E6qQF\nNAUiWc0G6eoFxdSsX7+l80mHQPyrXC4BWFfEU31xSac99c3nSRpbEWAoif/mlLUoUu/PLucFYgsg\nk3hGWUM6kwcQC66ZdwIhgN/6rffzP+/fpyjwX/wWk7f8Yitv7Aw5b8vy2PzgA268JNq43iG/a5aq\nAHxvor6xS1vPyaAf7JEADpQg9UD8q/0uP7Hop2N4qCSYkhKwAjHBuhSsIg1HoFitVBF5xqEq4p6T\ngy6ORNv96T/zZ/LG1VtkgpUF0nmti+C1N9/KG9+Ui+Cruyzhs77BT955/SUAD7+4nf955wvKh/32\nb/8WL0SaKLdu8WjvvP02O6krMuEzS1ArVso4Iryxr2pMJ+EsLAcHhzMD98BycHA4M5gnkWxCAgow\n2dvdB+AHtI3rVdqoTTnIMlnjpRo/uXlTJXBi+uaePaPlHKqYaKW5DiDxaLWWKnTZNBXTceky/RcL\nbWU/SP8sCOj+uCvZv+l4pONnAO5//hk7cPCAl9Pr5I2yeFlZ/Kss4druLv1fiY621KpjHqIe
RyNW\nCZlKSpdKTRktqbycQ9HKgpQCTWMgl68rKVPHuFvgG1Gy0CqlOHjmHGS3PX0yS4NAhiOBLV52PBbG\nuJWnkq6NJt2s9QUO+/5zsulKmUSjsbaGI7LXpqZgdWFPwtKDCuIvFQ3CaBxBoh0AxvLi+Z7VEFL4\nVaqAHeUnlSSbtyDx34rSwp7evQ3gwecMERpHjI/74tGDvPHyDWpLfF10+LM7H+aNu9IwGCoE7CXV\n2onGfQBTDWBJqsS3P/rDvLE95JTwz1Mae6/DIL6mqj3VyvNNhEBKeKlud6hoNV9FajtjXkgyZCOT\nfLZJU5QbPM658xsA/uZv/Aav9N2v8WjijEV5Ic2V3GrxvtcbZPqv3+LigBk2njcGsNri5Vy7zPWc\nd1RW55HqJccjZsIFOsHhLh8Fewox60v0udFuAXj4UEVkDw5wCpyF5eDgcGbgHlgODg5nBnMo4fkL\n8gYqLjTPDt+Uy2N7V2v4q9y9paDNQoE2dlkaYzev0Tk4PLyv3WnHNqYpgJZi2yybf3WNR2tJJK87\npP0/HdLOvP3xT/PGRz+h5Z+Kp5xfCAFgj46MdJvaY5WIl9NVoOIgFLVp8UTlJmmpPFdYWJofEjl6\nyByOZwrJu/EmUxwqZmx7ZCuRfEMjJfR7ctLluviZHK+Z1BpS84IpHtIaRsSyIzJ2/EopFHlihJUs\nTROTPTC5dN6piZQYTM9+Q+Uzh73nx/YKPR8S3gMwVYqGEcyTqJR5okzceaoT5X2amiC9fJG+yG8h\nNuKsUEMxtXomhcUBicaqgn5Ho0MAv/o1Oqf+0wHXIj7+9CfsUon3tDxm6snONsmI1eDZlJDelgZ5\niAhARZVWdx98yMP+5Ad5Y8+UCeW5fu2Vd3lGLYOEJyvr5NdXPu5QTkXzez3ypuXrvC/V4pW8UdPi\nxoryn1aUXvPmN74O4Nf/8l/i4UzyQes89uueJOS2qe5mMCKT9aaqKzySX3I0AlBWkdq1Gn+zVZ+/\noFVdCA75M9/6gCG4I/1SRib+oeM88zIAu91O/mfZivGcgLOwHBwczgzcA8vBweHMYA4lvHqVagoW\nx3Xt2lUAP1dg0S0rMto/VE6WlKrNSxjLnXfhPG3cww5t7F35CBYaizjigVpQ6uKKVMfMdNx7yjM+\nvv1h3vjh7/73vLG/JWNejAmXWwBuiq4W2uzScIe+lVTJX1FROgHS7V5do9dj/RIj5cpNpQG+iL2H\npJzb2wz5Q2vp2LUXlkSCdCEHMnoXFsVKqmUAkbx7nsoxmbbZLMddTKpostkz341CLk32L8hwpBDp\nEe00bjBU0O9gSNM9kKt3ZYV12J7cI32YyGeaOzctdXTm6zxdcW2yzzqgFgbcHfNE0zQFMJCrK1V2\nnpdoLSIR+ZWnLJMDsdonE7zz4/+UN3Yr5BG1WhvAS+cYA1yTwuJYg/PqBU7XSsZ5ZUXhqppFY13a\nwcTXIPQBlEYMjLzz+Xt5Y6YGMenkjUTOr59/61fyhsn+pen8gfIVCJ2Il9mQ/q2/8bfY+E02SvI+\nN0LeoFBMMhGtjvp9AJ0/IgsedTj5R/vs5LDT0SckgB158YaHTAycSPNyqKWYw2gCoC9Jie6YV7rd\n5TLRVHR+JE5tJDeWq9qk6Kea5EkhBDDW4LfWVnAKnIXl4OBwZjDHwrJc6lkx0SwDUFXU1cZ5yrxG\nsiA8BY/YmmWih2VV+QQX1rmU/vQST3ru3CqAuuo4Lq3yDTyL0N/kY3uvw2Cr977Lyi4H91k+J1SZ\nE+vVyxdXAaTgK+Wx6nSOFaJy5doVdkDr6XHAC2kuaRGxzbXDanW+hTXU6wglyUJ8xreZH/GTzg5t\nycYCX/6HUnE66DIUqLWwCKA4kSxBzDetJevM8m/UsPtiS9HpTCrJTKrU/oWVLAXsvW1KSWPleVhW\nfUOFjkxv1wKvynl6vekWqZNmpp3E5u3/kDduaT342joT
+gf+eQCdSCbwWKFhiaQ+YtmDeidPhir6\n8gWdHs0GZ8uO3Cw/fO8xgK2QpysoPuuXvsZYoVtLrbzRmkpWQXZNXcFia00aLzcWaaBdLaYA7nzG\nswQVfv7Lf/7X88aPNAG2v2IM4KKcOdEstnF+IdVgQrJSl/via9dJdL79FqXZHv3v388b+1vPjzX6\nZjep4e0cAJhquOKxivHI3gkt5Ut+DDPDZ6lRphSibaJSAKCgP9sl2nqmqjzUUQ+170AL9olNJ83G\nRI1p5gPoqubWRPWTTsJZWA4ODmcG7oHl4OBwZjCPEoaWK6NPfB/AQEVAx7JvLfSpWla9Dcm5pqKE\nFckeNJfJ+M5t8DjN1iKAygJt7ylon3d3uMFQBu1nH3KBs3OP4Vc1j2yxvUZb9O13mA+xvLQGYDhh\nl6qLPG+gFcrEFGbHtMNXN2i6FxRhtPtsdGybY4i1+muLivUJ8wkaChGKuoxcm2S0dWtN8t+P32Pw\nzmQ0BnDjlTc4Jg0LZJNAoATqShL2jWLJ4E4VuKQcHdMUzquTZlrUNA1dE1+MFBQzkIR0o17W8fmV\nrahbI05iHGEKJuyXxqcuujdFMD/47r/PG++8RdmM5WvfAtBaYcGVrMS74HncJVDpoyiU6+ZJJ2/8\nsMOxffvXfi1vvP+D3+bGn30IYNzlHVwvcjrdVEhddUA/yfZn1GbYf8ZB6AzJ9FuayW9fZmPZ6wL4\nXKV2K/p5fPMbZJrPOtzXG3HyX7zIyKlSWSQL89FUxZqSOFtB6nj/7l/887xxYZEk9+Gde9pG5Ysq\nHLGaMm+a59oANha4DmPktNqQt0cVgj3l2RQl8VxTblxF984K/RaDEICnMEVbAhooEWdPa/mdvj6R\nU6KnB4g9Sfp9Nvb2ewBu36P2w1NVoj0JZ2E5ODicGbgHloODw5nBHEpowTtjRQ/1B30cUXFbXKZ5\nOZ6qzs0BXXJLiySA1SaNxsGIZGG/QxOx0aSd3F5axxHXQF/6B909dmD3GWOskj3W0VmpsA/md3vz\nW9/KG+sXr7F77SUADSXZWLcNPZGg5zt0sphQXJhY0gkvLUgjzEPjAiO2Rpv0YKaqs7K/xVyQktLf\nV+VD8ZV+NNwno/nD/34HwIMv6Ve68SbLmt68SYYbFeWrPSHlPo2kjSdKOJ2yD/kmzYUWtLN24dgO\n+hz2eMqGr8nwXKL+mTInlDxD32JdSTCeifqHcyYSB+EcZfM2n5LOf/A+xd6u7n0JYOUic2jqy7yD\nhbJYvIhmIjWL7n1OiaUlMr7lS5yNzU95aefXUgCh5uT5Fjn15TaHq+XzIIWUV3quxkG+qvWKN29y\nlt46r1yT7j0AkwGjlv7X+yzue/8u/Yam5Hf9IudGNeWUHnTJFvvR/MK81Sa9xpMh1zpeef1t9Y0z\n+aUrdBd+R9IOS8rIqYnohZr2aSMEUBXLyxXyAECZcNAKCWCN7IX/gEiLAaka+Vgk2mQw4Y+6d8Bu\nxzpRScInoSmIaBodKn5zT4Wmdva7ADZVoXl6yo8OzsJycHA4Q3APLAcHhzODOZZ8KkoIeZRqpTKA\nmiLoaorAHCVSIxjThBtILt2XrHXFp9tiqU1Jv+oCLf+UkYrK/xiwM6MejcnhNklEecD0DvlS8NIb\nlEa4cJnxdYWApnte7snkEKwQqaUXBOJWq0uLulALkdUlyy0yc7S9iPVV7tvp0Rg+UMZStqtwvpDX\nfl52slXfDKpirE++BPDVe+SVj+/QbxV9m3n2L73JjH9TYhgpsPbOPVWclaR3PSS/ay8uAohjchNT\nwktklg+lNlcq8JO+SOLWJqlNJiJgTqIwCAEUdbQjBbpO9RLeG5GV+Bt/Im88vcMAy1LyOYDhQSf/\nc6FJJ1FVSTbFEnlfxaeCSO9LbjwecuPb36PLL9ojR3vlRgrgrZDTNU4ZUTyYSEWjw+m0tM6b+GvX\n6YD7Bca0IghI54uq
Dnf/848BTPcYsFpIqFR38IiNN8+18sZbl8lS1yJ2qX6ek/+LPczFpvLGviGf\n42/8/X/MTorlZdCwm66hMXGLAtV9mIQAkMoiGc14npK3tGkxOv4QsCjjI1Gu3D32MgATiY4cdFWt\n9hn7vytNl+FzevqeK3z62RYZXzz7SfHUO4f7APyKNAun84kznIXl4OBwhjDHwmooGWVB9kyerGNJ\nKvY0r6WKtS/RKOspXmY01eKusivKdb4k+z2edDzuAIim+rPPx/ZUAU3lmI/kZY8P8msvU8Dola/z\nVVhc4rsxVORXmGU48nIw1eBmjfaOr1inWaVsvamOJAlzd1spPIb29ZfzxuIqX+Cejl9b5gKqr7zu\nUBZWqhIl2hZb4xGAsV5ZQzkEvvzwe3lj5cpLeWOgsqwf/oiavPdv8wV+440388ZLl1p5oxxMAezs\nag1e5ueizGSLOJtKKenuQ9prkULPGo26drcwIg9AYjlbsOE6tfLlaMprP7/K5Hn0/0L+/+7OdwEk\nIeu59w8lKKzwrsUVjmR1+S12WwFN9TKHfTSSdVnnMnmltQqgUlFVGJ+fP7zDtfZag6O/ss4SMt2M\n8XF1WfdP7tNA6E65ph5mmwB+5SbNg2/eZPhSqEqxpQanYqKfVTBkAtmtFbpQrmxQdPgf4QVYKvJX\nKkjz4ce8uZc2aJ1NC5rA0tG2LDpf5pJNVw8+gETz2ZbJk5nOmm7ZhFc0HUsMWs6WiX4gVh+rNI0A\nJJozA7EWM8/70nHeG8qjpbjOZeXeTcThbn95J2/cefQVgEpVy//+qQa7s7AcHBzODNwDy8HB4cxg\nDiU0Edf1dRrkSTwFsCd9gsNDUrY0koCRTNNEO0dhi58opqY3ob0XK/PeSxMAB/19bcCjTSR1tBDx\nRC+tsSfvfusX80brKonSSJkxRSlJqTNWbOb4qrl9UiioOovCfEwuapa5Pqtj+gIWzjOhJBXlLFdU\nAlMEKtbquBWtgVSWWtq4VisDCLTK2F6nD+EX/9xfyRsbUgd778dcq955StHn1UV6NoqeBUxpSdUH\ngFZNpTdVYyaI6BmYKCDo7kOu9z9VCcy60jsKM8UIEYo0w5E1eJPxGg0OcQqWF5R0olXw+tU/y1Nv\n3gEw2aXIdavMq1hvk+4trJLNBec4Aa5JvzuoKKFHdD5NdafiMYA441WME3VglZ9Mn9HhUD1Hp0Ts\ncZZ6PW68qFnTWCbjW2hGALIBl9j7u6Qzwym5T6SfwzBWUFWf9/3RT/8jr3H9K8xDI5PglNaq/9k/\n/Sd54zf/9t/JG80N9rag8Z/Bs98UP6jPoqsAwFOml6/krcRGq6C1lCIPUpa094IOEsQ842R7G0B3\nnwOYjdgo6NrrWi5P2qTkEwv9U6bdnirOfqzUqMNRF8Ak1Y9CM/AknIXl4OBwZuAeWA4ODmcGcyhh\nIhG4ziE9Js16BcBgxD/zqiQ4Ii0wHNCvNB5JLaB+ThvTmBwoQ2Y85vHjIQB4xjuK5B1VCeC9ukIT\n8eKSSreuLquTquNi2QOmKucXAPh6FpsDK02PN8zZYfLKpVLpxDbz47BSeVKstuvEtBNYDmbmsjEX\nWrEkQ1zlhVKvCCCOO/mfVzbIJl6+Qt9QlDAWaU9M8NwG+UtJwVzJRNoSqkyTBlUABb2PYgkwDCTS\n8HxPOtcqylITSzXWYDFoqXh3/necTbWlBIstQO4kJpw2qc9L3pTYX7D6DoDNByyntNMnOS3VxaBb\n8heP/gvPGHPn4ZBfeRVzxSotrFwCEBY5Z4IK9XaDpuSDh+QthQ16+s63OaRRX5e8xEG4coMlSMOg\nACDu0ae595AKIg8/o9O295xMpx/RL9ZPOSzPP/0wb8TKwTqGb79N9/dYl15QgdhiTLa4LMqWSmXw\n4JC/xGhWl1fqmygAmJo/V9GInpZBAk2OoZKBJqbCKDnDsRQXxhJjmHpjAIncfJabZT
GbqX4OfRWp\nivUrOOjyKfHpJ5/wKxOGrIYAIi2hJMmpRZicheXg4HBm4B5YDg4OZwZzKOGClvcD+YaiNAVQVdZ4\nWKDV11Ie/ELMRq/PJ+DuQOn1XZp523IkJQol9YYAMIhJoEZKr79ckgx8WXUiFelXMm+byEhk0nRW\nTBQ+5tWJCZT7P/PiqSdjJZ3HcnuFlugz30mIRAw0MA/jiZA8i7e0M5o7clogU+6nHoCayN3OPgnU\n+x+8nzds2Nsq/2Mk1CSxKwpMHUo/uzdKABREBHoDWvvjITmCJSoVNQXMi2fk2qL3ZhfLMypMUVR6\nPm3ODys6GUvPfl/6ba1r3wCw16fj7O5tZnV8+BEb1x+R912rkU2EyUB94pQrqHJn4PGrkpfiSMlS\nr1DVdXHLis81h517jNL0NQoF8KtITsaDB3/AU4c1AJ6YVKiUpvoiyfuW8lQeHna4r/TyYw1UTy6z\nY3hzkUsBqX509it78Lu/kzfu/QG5Z1LhFX3/8y/yxpe7HKhYv4ulegtAf8gxMeX+WOQxVXmbYlFX\npFm61FYRrCJHzPzp7VYdQFnihVYvOVQcsjeTgTeBEPK+Tz4lE3zyhLlfibhtnsPWll5gaMrzJ+As\nLAcHhzMD98BycHA4M5hDCTOfH5oqQzQdAmgv0UM3FYFK5Q5rS6Cu2SZ/6T+StMCeuZx42HhAM3W4\nuwOgmzCDLJtwl7qixpZVWKyUkldOpKhdatP1k0nyvL3IdLBo0gewv02z0+plGU6Gklrgq0WBphZp\neVpdpsBiTaELlLUvSmipWEWZ1uZVKUuDrVhdAAAlSyaVVt64+4CZfYvt5rFdLELS1BQmSvUaKRUx\ne7oJINTgDOUcTCTKWJKRbwLtNgiZeZzkLgxnUbgZjtAKk5C3ENmTqEjGPk7F5nocqAfPnwHYFbkb\nVRUYqUq3BwG/eh5Q4aMsYftqoZU3xhmv6Nk2k+9CL8ER93Gs+zJQ0KxFJS5UOexFBU9GyqcrNzil\nFyNO6Web9wDsTeX01GFrKpJ6qBWGu12eup+IiEloIfMXMQ/jWek2dTthT0rKrAx1pWNFaS5Jqb0y\nkGNaCb+TaQqgJK33Yrl27GhaKcHmNv2eI2mujHZIMI3oJaqcer7dAuBZqWBNsCDgosTiErsUe9xl\n/4AjticvasHj7iXR3lqzAqAgqfjwtIUYZ2E5ODicIfwsieTEikx7AYCJBHlLqpFTLvLlGSi7OtHb\n8rDfyRubW1KJCrjx8wdUI9q+90MA3ZSm0MY5ZrTXFmnKNWRThLEyP1TqPVhV+U+pBYV6meQGoAVS\nWcKEwfQbLEAq1KPd8pIiXXvm/YwFZeBIsJJn6e8yoyyrx7aRRYJ6TXph7QaAMOQuC6otVFIYjik8\n11Sg3BbDIwkLjWRhNTUayBIAoz4Nrln9m1nhe3MIaA6o+77uZijDMFJ4V141x0zI/ohWQHj6KD3e\nY+ZKuciNRwroe3T3fQCezxf7tVVqJzR9mvBXljkTmgrr8wJVdtFy70gzbaFI4yV3oVQUVpYqrvDB\nQxabGUqi98YaDbeRJsnAH+loNPeKi9RXeLRVBXBvQhdBEqnkrbxGnSnHbTfiVE80r1Jbqs/qmIfi\nZaYcWYJXqcT7UpJDoC5HUCJ14ycf0qi8fom9vf1AwmpPtgDUZHBZNtVYbpmKBEPaC7SJVle58N/W\noru5j7ryBbWrVQC7u7SVQs1J35cZ6POTorrdvMhgt2uXX9Je8kHpGr2SB2BfOstl6Tj/3v/4EC/C\nWVgODg5nBu6B5eDgcGYwb9F9qhXcAU3cUqUMoFynaWdibVMlfCRaFM8UZ1Rp0KiLYy25qarovUc/\nyhuDze8BSAJa7N1AvEOG+oFS2C/UZd7L4u30dGrJNjzZpsbDvbt3AOyp9EsglteTxthjKdX1lYtg\nwVy+VhlDK03qzxkiHJUEnukEe8e+siV2I2Imc1
wqvlAidGmJPgRLVLKUoKKyhQx2NNMWLCs9xVKL\n8l4FIqWZjhZZnU6rgaqjFLUMX5I+XKo4r0guiH6vB6AmFT3TSs5OSWDCkUq0kxEXB3yN2KX1iwDC\ngPs2RRDqKe9LbRbIRpKSanF3NBUTD8iCa2Xp50UdAP2+Frx1gRWfIhBBVdPV55Qeail9kPKK9nYl\nSODxq6y0DqDUtNKhHV6XCix5psAB+kmmCeeeLQ6UgirmobC6zi0VFQUtSPsSWihqJqcq4rt2jpdc\nnrBXd+5wbi+3lgFM5QxRHBhKRWlJaoVnIhGFu7fJdq+/xN9UUTNhVfNzoVIBUBUfz6QOEusnZAFf\nVkrZkzqLObImKUn09EC/Yj8DMFTUWBjKP3MCzsJycHA4M3APLAcHhzODOXzHhPoW5CwIS0UABdn/\ntoafykvlyUlUKNMYXhvIFp1SSLurcI/BIZ01qTfGkfqde89J4j7+ijZk7TpdZo9VcHGoXPkPHv5+\n3tjyaLWOPfZq77ALYKowFsuGMXdYbDEv4lYWAFJUY6rApWk0P8LInvQezFA/+ZU+MU10sSFLe1ps\nLAGQJv5s30rFsh9EruW4tIAp0zkbyz9lp8xl101NwQ5i4tyhRV2J4VeV35PpKIMxqVkmKfre4QGA\noi8WphyRrHDqm29liS6zyVjBXOKP1SA+2pOCzltXYduKOaEqnIoH4i++/Kphhf3vyRvcCT0AEwUJ\nbu1wXhV1/OtX6FO7oLDBptjilzuc/HcOuPtC0UKZ6gAmSj4bS6t+nJE8BsptWluiCmNQo+PS05qJ\nXdH7+Lc4ggS60kx5V6JsibjVaMpfQa93POFmd5v+NV/6B+vnNgA8fEQ5wNGA6yEvqyrC3h6v1KRW\nOp2DYxs3Gvw5ryxy/GulAo64Aj1FVAXyEpocu1Z6MDxUyV555y2GcSYeb5Q130UCFSfhLCwHB4cz\nA/fAcnBwODOYQwm3tliZ8tETuuRuvHITwLl1uiQsQDKU5EBZroeCQtoKHtlE/+BB3th9xMqX8YCR\nhJ4f4khGS5bQNP3kMe3YQ8m0RxFNxOdP2aWuIv3G5oDLLEsgBJAZ5Rlzy1qdFnujSd+QOUFiOUaN\nP9a1sZUCO+we4AhMe8xEIFLTMFAIpjkoE12jJbUYWyyGFQDDZ6TJU8ld119nQU1zZVqR14aiQy2r\nxkJ8Y23DkD8Rz5GlHOnMIw2LlTizTs3osPJ4LFtnYWkZQElM0HKPjHKeRCDOXpSDbCoRviCtASiW\nlKMv35PReSv7mo0VkBzxkqdKFhl1SJS2+hr28jKAUahY0DqvfZqSm9wZ8HZ3E5FQRcY+PJB/aonL\nEZmENDq9LoBDKexV6hRpCOqMbt054HV5Y/Kj9QVSwvUmKVWjcNznmyOVm28q8bxI8aiZMnIy5TZ1\npITX7SrLbcAfiFHzna2nACYjbmnzbWdbfnbFYEeKgK3KcVywgOdYlFO0uusDgO+Zm5j33RdJtKBQ\nczrPnhfWCX0QKaA6QoYjcarJKbX14CwsBwcHBwcHBwcHBwcHBwcHBwcHBwcHBwcHBwcHBwcHhz+2\n+D9big2JqKJg6QAAAABJRU5ErkJggg==\n",
  1219. "text/plain": [
  1220. "<PIL.Image.Image image mode=RGB size=400x100 at 0x7FC0FD0C9278>"
  1221. ]
  1222. },
  1223. "execution_count": 37,
  1224. "metadata": {},
  1225. "output_type": "execute_result"
  1226. }
  1227. ],
  1228. "source": [
  1229. "dataiter = iter(trainloader)\n",
  1230. "images, labels = next(dataiter) # returns 4 images and their labels\n",
  1231. "print(' '.join('%11s'%classes[labels[j]] for j in range(4)))\n",
  1232. "show(tv.utils.make_grid((images+1)/2)).resize((400,100))"
  1233. ]
  1234. },
  1235. {
  1236. "cell_type": "markdown",
  1237. "metadata": {},
  1238. "source": [
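  1239. "#### Define the network\n",
  1240. "\n",
  1241. "Copy the LeNet network defined above and change the first argument of `self.conv1` to 3, because CIFAR-10 images are 3-channel color images."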
  1242. ]
  1243. },
  1244. {
  1245. "cell_type": "code",
  1246. "execution_count": 38,
  1247. "metadata": {},
  1248. "outputs": [
  1249. {
  1250. "name": "stdout",
  1251. "output_type": "stream",
  1252. "text": [
  1253. "Net(\n",
  1254. " (conv1): Conv2d (3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
  1255. " (conv2): Conv2d (6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
  1256. " (fc1): Linear(in_features=400, out_features=120)\n",
  1257. " (fc2): Linear(in_features=120, out_features=84)\n",
  1258. " (fc3): Linear(in_features=84, out_features=10)\n",
  1259. ")\n"
  1260. ]
  1261. }
  1262. ],
  1263. "source": [
  1264. "import torch.nn as nn\n",
  1265. "import torch.nn.functional as F\n",
  1266. "\n",
  1267. "class Net(nn.Module):\n",
  1268. "    def __init__(self):\n",
  1269. "        super(Net, self).__init__()\n",
  1270. "        self.conv1 = nn.Conv2d(3, 6, 5)    # 3 input channels (RGB), 6 output channels, 5x5 kernel\n",
  1271. "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
  1272. "        self.fc1 = nn.Linear(16*5*5, 120)  # 16 feature maps of size 5x5 after the second pooling\n",
  1273. "        self.fc2 = nn.Linear(120, 84)\n",
  1274. "        self.fc3 = nn.Linear(84, 10)       # one output per CIFAR-10 class\n",
  1275. "\n",
  1276. "    def forward(self, x):\n",
  1277. "        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
  1278. "        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
  1279. "        x = x.view(x.size()[0], -1)  # flatten to (batch_size, 16*5*5)\n",
  1280. "        x = F.relu(self.fc1(x))\n",
  1281. "        x = F.relu(self.fc2(x))\n",
  1282. "        x = self.fc3(x)\n",
  1283. "        return x\n",
  1284. "\n",
  1285. "\n",
  1286. "net = Net()\n",
  1287. "print(net)"
  1288. ]
  1289. },
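Where does `16*5*5` (the 400 `in_features` of `fc1` in the printed model) come from? The following pure-Python sketch traces the spatial size of a 32x32 CIFAR-10 image through the layers, using the standard output-size formula for convolution and pooling. `conv_out` is a hypothetical helper; note that `F.max_pool2d` uses a stride equal to its kernel size by default.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size of a convolution or pooling along one spatial dimension
    return (size + 2 * padding - kernel) // stride + 1

size = 32                        # CIFAR-10 images are 3x32x32
size = conv_out(size, 5)         # conv1, 5x5 kernel -> 28
size = conv_out(size, 2, 2)      # 2x2 max pooling   -> 14
size = conv_out(size, 5)         # conv2, 5x5 kernel -> 10
size = conv_out(size, 2, 2)      # 2x2 max pooling   -> 5
print(16 * size * size)          # 16 channels * 5 * 5 = 400
```

This is also why the network only accepts 32x32 inputs as written: a different image size changes the flattened length and therefore the required `in_features` of `fc1`.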
  1290. {
  1291. "cell_type": "markdown",
  1292. "metadata": {},
  1293. "source": [
  1294. "#### Define the loss function and optimizer"
  1295. ]
  1296. },
  1297. {
  1298. "cell_type": "code",
  1299. "execution_count": 39,
  1300. "metadata": {},
  1301. "outputs": [],
  1302. "source": [
  1303. "from torch import optim\n",
  1304. "criterion = nn.CrossEntropyLoss() # cross-entropy loss\n",
  1305. "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
  1306. ]
  1307. },
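For a single sample, `nn.CrossEntropyLoss` combines a log-softmax with the negative log-likelihood, so the loss is `-log(softmax(logits)[target])`; for a batch it averages this over the samples by default. The pure-Python sketch below computes the per-sample value with the log-sum-exp trick for numerical stability; `cross_entropy` and the example scores are hypothetical.

```python
import math

def cross_entropy(logits, target):
    # -log(softmax(logits)[target]), computed with the log-sum-exp trick
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# Hypothetical scores for the 10 CIFAR-10 classes, all equal; true class 3
logits = [0.0] * 10
print(cross_entropy(logits, 3))   # log(10), about 2.303
```

A network that cannot yet tell the classes apart therefore starts with a loss near ln(10) ≈ 2.303, which is consistent with the first logged training loss of 2.238.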
  1308. {
  1309. "cell_type": "markdown",
  1310. "metadata": {},
  1311. "source": [
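  1312. "#### Train the network\n",
  1313. "\n",
  1314. "The training procedure is much the same for every network; the following steps are repeated over and over:\n",
  1315. "\n",
  1316. "- Feed in the input data\n",
  1317. "- Forward pass + backward pass\n",
  1318. "- Update the parameters\n"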
  1319. ]
  1320. },
  1321. {
  1322. "cell_type": "code",
  1323. "execution_count": 40,
  1324. "metadata": {},
  1325. "outputs": [
  1326. {
  1327. "name": "stdout",
  1328. "output_type": "stream",
  1329. "text": [
  1330. "[1, 2000] loss: 2.238\n",
  1331. "[1, 4000] loss: 1.916\n",
  1332. "[1, 6000] loss: 1.703\n",
  1333. "[1, 8000] loss: 1.582\n",
  1334. "[1, 10000] loss: 1.544\n",
  1335. "[1, 12000] loss: 1.465\n",
  1336. "[2, 2000] loss: 1.425\n",
  1337. "[2, 4000] loss: 1.377\n",
  1338. "[2, 6000] loss: 1.364\n",
  1339. "[2, 8000] loss: 1.330\n",
  1340. "[2, 10000] loss: 1.331\n",
  1341. "[2, 12000] loss: 1.298\n",
  1342. "Finished Training\n"
  1343. ]
  1344. }
  1345. ],
  1346. "source": [
  1347. "t.set_num_threads(8)\n",
  1348. "for epoch in range(2):\n",
  1349. "\n",
  1350. "    running_loss = 0.0\n",
  1351. "    for i, data in enumerate(trainloader, 0):\n",
  1352. "\n",
  1353. "        # input data\n",
  1354. "        inputs, labels = data\n",
  1355. "        inputs, labels = Variable(inputs), Variable(labels)\n",
  1356. "\n",
  1357. "        # zero the gradients\n",
  1358. "        optimizer.zero_grad()\n",
  1359. "\n",
  1360. "        # forward + backward\n",
  1361. "        outputs = net(inputs)\n",
  1362. "        loss = criterion(outputs, labels)\n",
  1363. "        loss.backward()\n",
  1364. "\n",
  1365. "        # update the parameters\n",
  1366. "        optimizer.step()\n",
  1367. "\n",
  1368. "        # print log information\n",
  1369. "        running_loss += loss.item()  # written loss.data[0] in PyTorch < 0.4\n",
  1370. "        if i % 2000 == 1999: # print the training status every 2000 batches\n",
  1371. "            print('[%d, %5d] loss: %.3f' \\\n",
  1372. "                  % (epoch+1, i+1, running_loss / 2000))\n",
  1373. "            running_loss = 0.0\n",
  1374. "print('Finished Training')"
  1375. ]
  1376. },
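Why does each epoch print exactly six log lines, the last one at batch 12000? A quick arithmetic check, assuming the 50,000-image CIFAR-10 training set and the batch size of 4 seen in the `DataLoader` example above:

```python
import math

train_size = 50000   # CIFAR-10 training set size
batch_size = 4       # assumption: the 4 images per batch shown earlier
log_every = 2000     # the loop prints when i % 2000 == 1999

iters_per_epoch = math.ceil(train_size / batch_size)
log_points = list(range(log_every, iters_per_epoch + 1, log_every))
print(iters_per_epoch)   # 12500 batches per epoch
print(log_points)        # batches after which a loss line is printed
```

Each printed value is the mean loss over the preceding 2000 batches. The final 500 batches of an epoch are added to `running_loss` but never printed, because `running_loss` is reset at the start of the next epoch.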
  1377. {
  1378. "cell_type": "markdown",
  1379. "metadata": {},
  1380. "source": [
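  1381. "Here the network was trained for only 2 epochs (one full pass over the dataset is called an epoch). Let's check whether it has learned anything: feed test images into the network, compute their predicted labels, and compare them with the ground-truth labels."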
  1382. ]
  1383. },
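The network outputs one score per class, and the predicted label is the index of the highest score. The pure-Python sketch below shows this argmax step and a simple accuracy count on hypothetical scores. It assumes the standard CIFAR-10 class order; in PyTorch the same prediction is typically obtained with `outputs.max(1)` or `outputs.argmax(1)`.

```python
# Assumption: standard CIFAR-10 class order
cifar10_classes = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

def predict(scores):
    # Index of the highest score (argmax over the 10 class scores)
    return max(range(len(scores)), key=lambda i: scores[i])

# Hypothetical network outputs for a batch of 2 images (10 scores each)
batch_scores = [
    [0.1, 0.0, 0.3, 2.5, 0.2, 1.1, 0.0, 0.0, 0.4, 0.0],  # highest: index 3
    [0.9, 0.2, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 1.7, 0.3],  # highest: index 8
]
true_labels = [3, 0]

predicted = [predict(s) for s in batch_scores]
print([cifar10_classes[p] for p in predicted])   # ['cat', 'ship']
correct = sum(p == t for p, t in zip(predicted, true_labels))
print(correct / len(true_labels))                # 0.5
```

Summing `correct` over all test batches and dividing by the number of test images gives the overall test accuracy.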
  1384. {
  1385. "cell_type": "code",
  1386. "execution_count": 41,
  1387. "metadata": {},
  1388. "outputs": [
  1389. {
  1390. "name": "stdout",
  1391. "output_type": "stream",
  1392. "text": [
  1393. "实际的label: cat ship ship plane\n"
  1394. ]
  1395. },
  1396. {
  1397. "data": {
  1398. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAABkCAIAAAAnqfEgAAA0bklEQVR4nO19WZMc6XXdqcysvbq6\nem/0ABgAg2UwxAyHo5mhRIkSJdohWrbssOWwFXaEIxzhFz/4wb9DP8ARDlNW2H6wZYclh+RNFmmS\nokmKnN2zYkCgATS6G71UV1fXmpWLH/KcW9Xd1SNRCke45e8+ALezsjK//PKrzHvuci7gxIkTJ06c\nOHHixIkTJ06cOHHixIkTJ06cOHHixImT/78kd3rTb/3TV/RZkimFIACQ87zszzAcZkqUjLhDoZAp\nccKvpEnKg3hxpujbSKOqjh8DyBcG2Z8+An0l1dGiTBlFPGyS2IADjYFbhlJy3DPR0XIaNkcbRzqR\nLtADBxnqWx2eGb2QH/3Gv7uPCdnf3+cAIu6ay02ZzJ9UfrKDpCeV8QYv+5MbvNQ7uUdO8yMlhU0g\nd05T2/tPGKTtuby8fOKj3/rWOrWYE7W/u50pw8EAwLXnrmd/NmbrmZL3OYBC3qdiW7SMgpwWSdTP\nlFo1r6/nAAQ+B+l7PMjBQTNTZmZmuGc+r6NxH1stURJmiq1b/pnj371uj98NuJxKpVKmhCG/G+mX\nUi6VdXyeqDFTmjzs13/zn/EqFm/yKz5/U/WZWqYcDbkUu+19jU2/C93XQMMtB0UAJT/QuHUr7dZp\nQ5zEJ7Yk2jI+rK7R83xMWwC5nP3e7acan9qH3yoWi5lS8Io6dRFArsDJ6e1/lCk//8u/duIgHpw4\nceLknEhwelNoL9iE7y4kCYAiaBl54IMwCE5aTzJZkAu4aWhvG71bgoQfZU9/7Yic7DVEQ51IT/rE\n19j4XooDPpvDUB9Fno4TA8jJOisV9E7WdXmBvZx1RnDnFPZa4Nsg8KY/033fn7r9zyl/NjMtp7fZ\n2CLycgASe5+mGm0qM0qvXDMzJ779Z7ewTkutwjvlpVxswy63JGEPQKnAo1XL3CHQ4W0BFLVKyrqb\nnoY9jG0fro1C3gMnAACCQGaa7DUvd/Lai4IIsuTQ7Y10IkoGI1Itfk8nyMv6MHttNBzqQjRs2RQ4\n4/4mKQcf+XM8SJ4/t9inheXlZWH1O5mSxl2dmscZptxn5CUABpo3/VwQjghoPC3gfo8/c1vSdiEG\nSjyPSpqEADwzeDVvUaQVaE+AnD0lOD9zc7y0YnlGh+WNSLwUQK7I88adGs4QZ2E5ceLk3Ih7YDlx\n4uTcyBRImApMAcPJLTnhskQoz68Ihcm0Nh+fudwKBZp5UZLXR/7kPmZM5uSn9/QYzXk0OFOPBnM/\noWm7vUcbtRPyW53OSJcUA6iVBAQ0ttkKHZ/lEi8w8eRYFXTyZfDndSFhMh3sGAj6DDT0Z5A/zdHG\niMx2Hhvi9kl2IQLmI15yYOAh1i3LnT5jcmrLnyCfMewgx1Mbviv4PH7eiwEUPYF32y5/+bBPx7bv\n876XAt7E0VDQxuB8xC1pLgAQC+EW8vyKIUEIN1l4IZY7otfjGfd3dzNlZZFAJvPH+wWuDF/HtwnM\n6+0fCC0OFWewCMBoZD+uY+Kl3B5rbLGCIXGO11Wa4akXnl3htw4PMqXWI0gMB/zNxrUSgGS2kf05\nI9xtJ/IshjYMNQk8Y6nE2R5PmO5vtvZsBdpBIl1XYmtHS7EQcMmVy4o8wCA5pz1BDCAx++lsx4iz\nsJw4cXJuxD2wnDhxcm5kCiQMElqVlrvhJSNMmO4TURyFbE6F0iJDUhZMUSrK6sqtTGm39gDs7dMI\nzwcFnU4RQCVM9VHJlA/XaaijtMB9PH4U1ggbm+0mgI2nrezPWpEHSba55fIqT7RQM6BhKWC8xoIs\n2/hUOgn3lNX650y/+nMhSp05NnyqZLQoSQCMBLc/vc8kspVV5kkZrl+aJ+QpKUaT/ORD+oxJKCgR\nL4nkSRAQyHsJgLz+9GKuhEJekMSP
tWcoRXczJziv5RoNFC70qwAGusCKXAG+BQ4Nt+hKuwPCrjff\nfCtTRkKjc/XXeNiiB0DYDjmD4VohniGd1KLbgqsWxk2mQ8IICpyBizMRQB4qwutLqSrmV6/olr31\no0wJd4kNL7x4C0Bulz+KYY4xx5ou4KjPCGNJwy6mPJq3oLikooQWPB1WSgCCkeDwSEercraLh4eZ\nElx6IVN6jVkOUpg91o0oJbzYXJoC8GLFauMzDSlnYTlx4uTciHtgOXHi5NzIFEhoSCMXNKh4OQhl\nAPCEm0IZ+YUCbdR4nD92MgOzoCDKF//SX86UN7/3fQCbB3vZn10BwCiiRfpwYydT7j95kimluQuZ\ncnH1CrcUWM8RClEWassAogFt472dzUypNOYzZaPzNFMGuqKVGo3hiooD4pCI4Kwn+uko4f+l0pzP\nxIwKbuZVGqW80H5nCKB1SLP/6R6rUsozhAYLKk+xWhMLmVmxztTxjc/6p5OCHAipLi1vZSLxEICv\nMF8uJrjLK1Y7MhAhqOvXDUQo6Vc1NElkML4IoNNuZX/VKoRFnmbSKmYCBYNbCg4221TKSrkMheHC\nUQIgKNh9VxQv5kgi/Rysdq0gV0OqlZbE0z0M4x+dRfFSqyRT0qeAWU7YbZBTjVGscrclQv7eUQhg\ndP8Tjk2ek0QVQd1AF6axFSKVFj1WmnGo+i0Fjge1EgB/wD8DXjGGFzik/pbqn3JL/O7sIi9EJxp5\nFlflVSdpAsCX9yDwzlzzzsJy4sTJuZEpFtbQ42P7sMenchwNAczV+EiuqyInkGfd/KnjysrkZLpH\nr8uckW/+/n/KlKcHAwBPO9zh4SZ3WN98zOOXaGpFPs2oWp1P66BS0z58LZT08C15FQB7MpHWLl7O\nlIFsrvv3aWE1W8rlWePRri5TyQd6t5zhJbXKjDT5CQyOdDxBx7aPE1tOWVixJtXKuX29aa1yYne/\nnSntLq+oP4wBdHsqciryVnb7vFO1iswNjaQwHsyfcBU/kS1ZzFliEWcyr/rYLJdqnEiV6HbkVKPj\nncxj8nOqEZE5ZlNppfgxRgA6R5yTR5axJaPJjKNLdU6LZV298+57mfL5O5/LlMSSwuIQQCm1dELO\nZL8nnKE1E42UPhbw+CNVyA+HPUyTWJZXovy41IwJ/cpCy9XSiWaPNBvLzMwqLz/LMaSH2Zi4w+Iq\nR5tXPfM2K6ihipyuwmLpCiNaeVXRDQSYqjNVAOERr2KoyQnK8pdrBQYLtPVyeZmiKU3FGS0fX4Zb\nlMsDyHlKEsSZdW/OwnLixMm5EffAcuLEybmRKZBwt0d7rBk1MuXb/+vbAG7foGPvl+4Ql81ZsbWl\nooiSwVNFjhU9yLeLBw+ZE9TsFQGkFfrCvZqyP+aPMqWkDI5QaTKh1dnMced6jaN6urWVKe1WE0Bd\nxnBJPtdHAoD5Ou3nna2HmVJ7yjNeqKt8x7MIgDE6HJNuz6gsZGPLtDYuMF88AaYYbZBhQy859s6w\nIiEDZh0hGvO+l+XKHageYkuQcOeASsbTMBLe6x0RDu/I+77xhNP1wo1rmfLclYscrdKIxv5+o9PK\nTfw7Ubrhne2I9+U4TwSUPDkQ+odtABBKSkUJ4Iu2oWCEazaBI4YRYsNWsT4ae/dDAN0uE4KePuWe\n1XpNJxI21EyGHe5TUvhot9XKlLfeJ0isFn0A169xugJB0WGPi6csFpBkyLURKw4QG9YZtDFVLKXO\nuKjGmYz6SLAxL5RdvPcpj/rGdzIlev2L+lYRQJoSkxYEHgfglda2eIG+mCSSqiqWUsVwRvzWzEKD\np36yDwAdLqf8Ct1HeEyAGWiSB7ucN1/em+QmM7MG4njwcubvzwEItFzTs6M+zsJy4sTJuRH3wHLi\nxMm5kWmlOY2rmdLb4+NsVFgC0BRU7IUEWfWCUmDGoTRDQ7RFBxFB1q5M+N22Ig6NBQBzy4zidRJa\n
y4sqxLEIYKh8j4Gq0vuKAT2riEZPGHAnHAAIlJbVaurEGmRfBq1f4Imethmg3GoTvzy7KGx7hnXa\n6jNKVauI1zAwFGxsENrbwiKGBMdEesffGaeyura3mIM2P0/sXC7x0oYDXlqlyC2rS0TrGd9xt8fL\nqcoIDwdia9OFdcQ2F42LjRRaGqeA2UeTVzNJDoGzpGRsedrJIGExjQHUFGadNW48pY8VhY9KBo+E\nxD1d+5hmNxYXdjsEMFPl9jnN24MNUjPff0zlk3t/mCmtvVamdAYcWy98P1MCKLuq1wZw5yYpjP/G\nX/tapjyjFTgscbSDLscfdnmieqqkpP4RpkneV1mMJsHChYkcL8ZAWTvg8aMN5hjWC/ylHG3yjGFp\nFkAqwsvcFlMaq89wuYZ14S9wkZQ7Sh9rcZADVU1Fe3QgFAYhgKhNuF1sMnw/6guPlwmZWw8Y6y+U\nCQlnLjCC6SsXLFXi1RApgEgLL0zOxITOwnLixMm5EffAcuLEybmRKZDw1ksMNGz8gHn9tdklAK9/\nidsr3nqmhIqPGBrKiV8tRiNTZoT43nmPEY1agyb0M1fuAEhl0hcMYA5YrNNTZYBnTTs04A/fezdT\nZkvcUlGQsVapAdjcpm1sBRueQOK80gVbB7R4D5pU7m/R1l1bZqJdoFGdkKDOq4iF6UaKkEJBHFMs\n9GPVIYaP0uMppOPooRSrIzEOAMO2DdXZjFQ9DyGLSm0GE5Aw5xv/gTqXlHXLrE+M4rjjGM2pwWSQ\nP3/y88/ChI/X1zVIzuRRm8smHg0BPFHd1YHoIrod4v3lBaK5WpUowlfScmiUhAXx8+n+dgc9AAMb\ntLjkH21yXT3YYKi0F4quQ4FjVHh8IxW3aq3tR58A2BTm+qPvfDdTbt98LlOWGsRH/U6LI1F7m9Ft\nMpR0RLl3QorCdKnuIIw0RZDZk9JRlVvn1c9nSj34KV7REed25Gcs6UZEqQhjmSfqxqK7kCtgJI6E\nvFZyXwz6lsfZj2MAvQ7PUtXRBtqzqJ/h/AxZQGI9HDpaclDyankk/r9cbuJCMTp7OTkLy4kTJ+dG\n3APLiRMn50amQMLKLMHOs9cYEMmKpZ69yoaXiyNihtZ9Jl6OBFLiiGjr9V/4m5ly+dqrmXL1xfVM\nefNtorm52iqAzR0a6oHYvEoKaakmHB3l9R0e0MaeU9dMsx2tFnxxcRHAQFX2e8oAtFKyWk1RSGWH\nhgo53X+8kSnLcwQaNy4qNe64fP1f/Rse1hJHZfrOqEfm9auEw6+99ILOyK9bcmkWiUsNv8g+jzSl\nFuQqFIUajABDWY4Lc8pZtQ5shQIm2AKQl+muoraWAqMtMa4dHbYyZWQ5sQrwLShv8Mb1awDyVqFm\n3TknQOMJ+c73fpApRtVvRZG9QQfA+vYT7UCxWZpT5nBVgdGizpNXKmmgvEdPbb56gxBAoLasqeDw\nVlPE5wrfVmoNnVMEJB2r9eOZrAS1XqsD+OlXX8r+7B42tQNx96NHnNJ79+7xIwXPH+5zSvsKIJ6Q\napXrLdKVjmK7C0RzRpeSEwour3B+2urqunvI0eZ8H0CoZmUFC8C1uGck5F9Ujndba7JkHQ2MLlE+\njWFWnaq2DId9zZvwa0V1jjMXL2WKbx6GcWc53eBxPnIKjNdTcnbmqLOwnDhxcm5kioXlF+k2e7L9\nYaZ84dXXAFRnaZj4R8biIONCr9wfb9Ab93NzTOZChQUfM1VVPwQ8frlQwUQxhPmS19bo8P7wHot4\nCnJJtuVTvHqJ1t/N51lV32yqg0g9B2Bzm4knnt4SDfFhtcSUZDZXudLIlL7K0D99pPKgwvRn+kD+\n7LCv8nSZM51DXbq2xLef57dSI/YVL22hjAlTZUx2LFNrdp4pPGMiB+t3YvwNskmtACrhvzzausqh\nnuxwWpr7tFX7fdWRDPW2FKODUQtcvESf9OVLFwFUC7ZsLHRwpo
X1zl2euqISDbvRg6gHoDHP3DG7\ny6GMmp2O5laXXCtx7UUyFT1rrarWSl5QBVDoqhvoiC78ZrN5Yth2a0MRRbR1RqsGu7zIZTM/fwET\nFT/7B5zJhQbP++rnuRQ3NmmntwecqI9UuXKaTJwXKD96eYYX2FHKYaBVGltClipaPC2nRMliOV+x\nCM/HhLN8JPKSstomGbwwW9V87bHm1nrwRCqJy5dzABKlvBnJnfE65CM1NrZMQ323FNsql+VpnNXI\nYeJ25M5eTs7CcuLEybkR98By4sTJuZEpkDBfordyMDBoMAKQVyFLpWquUDr/iqJbnQloQ/7WP/8X\nmfKrf/ef8LAqUygUzYaPAFy99kz2506TzteBHJ+rywQLRlw7VCuUa9cZAXjuOrHh4dvsd9I96gBo\ni/Q2iszFSyO/oXyZpEWre3ZO3Axy1fseL2RjcwfT5O/8rV/jkOSirp7qE1kWdDLO4XZbbAoigcgH\nJQCB8llS2ed9ZS2liXLQhCby8u4HZsznrdDnGKK0fJaBaA+MsWCu0ciUWCyAJZ/jb+0T9Ww8Wc+U\n64q3+F6ACdzqC6V+RmnOkSXCmatb2LDilQBcvMQ8pnDIkext8yt7TQZkVpfJBldapCu32drXUblz\nfY64tVScAzAQy0Yv4pyXKrrvEe/7uLerHPbm3Ih6HO3rP3UnU24+uwZgENJr/uDH/Mq9Tz7IlJ95\n7cVMuXSZS/rRe4xKmb88OaPopKBsr4LyChPR3ZUVMInEgHjUVutTEYSUZolbV6qKEaUJjrUsFQOi\nbBRf3oNxZOaUpCoPMkgY+ykmGBA9KQVDnzrsUOSLRtMS6NpjTfu49VQSYKJwzSgqT4uzsJw4cXJu\nxD2wnDhxcm5kijWYU3FAT9Bs0OsDyKu95dG+akRUiJMHQcSFBi3DTz9kKsrmBhX0iPgePl7PlC+s\nfhHAM88yJri2Q6V7jzvMFxqZUm8QG96//4AnWqPV3RLIGsl8fbqzDyAxqgRZvD3F9bxTDAxW1mMB\nrIJ4zsL9bUyTRMloYxtbH9UKrJgplzhjfdG29UacuvX7vMZCoQzg8lUWsj94zPr73/uv38iUSNGc\nko5WMUVAslEn2GnMEhF84QsvAVhaZHnEcxc5XV5OnIKy1C0SZGGj/jLxxdqFBpVn2Kwo45DrKbtn\njILPfvHlRcy/uMwxWOB1b+8xgE5XBAYqzbBksYaYyNeu3ciU+iyvqL5IkLin6HAi7JxVofTUKLSn\ncFsYKrNJJAQFY3kMeMsKYmpfXuWULs1RKeU9AEsCnnWlL+0/JO57+OP1TFlV3LO1/X0edp6jDc/A\nX4F4C3w1iC3pZ9jaYXCz2SFlwu4Wo5BzM0yZvPMC0ai1K874D0aKx1lU2parNSUwV0NuDPC5czwO\nR1o8L/vIvqtqm/F31VBHZ7QlZzvnlRmXt2BgCgCeEG58dlqfs7CcOHFybsQ9sJw4cXJuZJqBarUm\nCg1cWFzABBL55rtEeXMKAN2Yp7FXKljYhfhrd4cgLhm2MuXydVJ8+aUigEqdRv7iClNM91VC0VJw\n0PgBl5eZRRkIn1oJjpXvZ9FAK22xDMOBIoyRWMYXBSvgqQmryMxKinFEYhM8Ib/7e3/AsYn32lPy\nXk3h1BkhtSs3eGlLC8RHCxdYtTO/uAygJDaC1kfEF+9//ChT+sKvBibsvsyIrv765SuZ8qUvvsLj\nV2cAVH3V0MjEDjVdkdpk9awiRw1ByzpsoyG+/G02RtvbawIoq45kZZUTWKko+/eUNATnJwqhRMIH\nD0BT5HnttlIljfNbKO/RBgdQb/O79dkGdxYdXE/5rshFmKwvqfB2lCpWxGMAhzNZ0z6B2pdeWuC1\nG1tDt90CEAlgGp/9VcHVjz76cabcvPW8js/Z3lQqaUmFVifE4Jh1BkiE1I6ULL27S+/EQZNH++S9\nH2bKx+8Se16/ziKwK9dvA5
hbFAuFQJaxSxpPv6Ev3+hGtC0Y9yI41mtuoh2sgo/a08LFpzsNm4yD\nj2POkuws9lOd3lsPzsJy4sTJOZJpeVh6WNZretPOlDHRUrQtsqC9Fp93i3UepypPZKQOKOvK5VmZ\nb2TKs3oJZJkyP3zzo+zPJ1v0ns7UaHPllYHywaePNDorPVG6hx7GnS7fvXML8wAi7bCtGp1anQMI\n5HSv6L1qRSEImfiT9DiY1eXpxc8/evt/Z0pZNEzDkJ71vJzKP/3Tr2fKwyeki92n2xR3PscyjkK5\nBKA3pHWQlxn7yiukOhqoGap5iG9cY9nT58SytLbIS6tXaPskgxDA4232B905EAf0Hrd0O/RJt1Qc\nHo7UKV4nsnJrq8EajSIAlQbn5A54FbOz02cJQHC8JhkTL8msXNnCI4FqtmxLocTDLi7R61+r8QJL\nCjgEGmSQ543IctBSFYJY36NZ5aB51u1JnFCB1bgMlZqnMus04rTE8RBAqNKTvi6nMsO0xIfb9I5/\ncJ/Wt9X3jFT2lLbPTHrKxEyVkvjBn5e9dv02oxa9I973D95i7uFbb9DC+s531jPlww/fB3Dr9svZ\nnzdu3c6UxlwjU2w5+f5Jw0qVXZNbtACSGBNZhCZWrBPLmE/GKWBnypgVLudjooouSs7M63MWlhMn\nTs6NuAeWEydOzo1MY2uQg+3C8gXt5AFIlLBz4SIhyY821zOlBTVrCegmbyzSLTdbp6Fu+ThXBAlr\nswsA/uVv/uvsz56O3+6z6qKn8hrzN682RJXV5Km7RTsRvaQffbwJYEe0BObKbXgcZL0haCAWpCCk\nMR/0mAY1XxGOKE03aXcfE6XOy8a+eJEe6Bc+z2qhvGDF++/8MccvO78mkqOdvS0A1TphxUKdO/z1\nr/08B6kcp9lZ7rO4wOybZpMT9eAh6acPW4Sl7cMjAEeKWhyIhqmpfieRQhAFebgLgvNGYlGvc/xW\nxzO3PAOgaFC6LGoBUVaclnmRTSehebh5oiTqAyiIZWF5ZS1Tcqo9KiiryMBpSZUrvgZptBbG/pzl\nBFmiWa+rQhxjgJI/PhU27B1yJp+scyabyhFqqKvrykIDQEl0EeYYTgOi+EClP3tqZnNxjUuupmtv\nD6a7k61kx1oRp55tkWNbmVmNBdYn/dxXuOSuX+dP8rvf+lam3F/fANB7WywUYih58SW6Gi5d4kEC\nRWbiyBi9rZBI12jO9DTFRD9gIxCx5k/GdTXuA2ttay29y+qTxk53D0CSnsSVp8VZWE6cODk34h5Y\nTpw4OTcyBRIa8W59jsZ8FAcAijJ9b4r590dvEFsdFljNn4Dm98pFHvmDDxm/+Nmv/MNM+b44c7vd\nNoCRAnM7WydDgZ1I8SNht4ZH7PZMmdjncJc2fOQ3MmV1pYEJa9YqcgbiQe70xMWsjqrRgGVDywFD\njWuiUR5GVs9xTJ7cZY1+W7GnX/3lf5wpX/vaVzPlD7/JaNGy+CGW1XW1rFSgUi4BsCI+3xkpJSVD\nRbLGDRZFSmPZ/oTDfrTDNKVQ7XOCUhXAzAyzfpYFZEbhyfhOXkjQSuRNmZlhkK5en9FHOQAdEfI+\nfcp7Z3N7WioCSpGnsJqSzhr1ZQDJmAaS96Vc4+lSq+oQbElSbTlFs5uaggRApBsXxRxbe19k3Hbt\ngoSdQwZPN9XCZ3Ve1U5Vpv5lPZwSQdFIh7Fw5DMCWbduMtPw5Reo3L3PMPHb732EaZITEvTEZeyJ\n+CTvW6GMsqIUxfMUGL1xk8TNiX4ym9v/AUBzj+A0UXHY0ycfZ8pzNxg3vP05fnd5RS4g/dKjkfia\nlcwYpzEm7ssUamzh7tMkfGOWx/HF2pdSYIwwxxU/p8RZWE6cODk34h5YTpw4OTcyBRJWa4Qtc4uM\ncWQ97weqXynVZC0rePToEYsGvvw62c4GHSVn1mlsb22wnuDTu3d52CjEuDEHOuJdqM8Tilpp
TkMp\nrLdu8fg/fIeW7Vsfr/PUX/mVTMmIBu/f+/TEQSzXdKDqisurRHMl5VvOzwuMiJIwCqfnsA16jLu9\n+HkWyv/SV38pUxbUKfZnv6hIn6DHjCqK6ppkv1DCRDdQi1sZS7c1CqrLUE9EDHFNs7F8kXHJ5gHn\ncKbRADASWskJLxlvt4WlrOlLR9G0VC1SjFb88RYTXgf9HoCRUHasEo1K9czSHIPktQrn1vDdzu4+\ngLYyVy1f9PotJkYa3bufNzRExXBxqIYtPVHr9Yc9AJHyeD2VHCVD7lkTCjaa/3JBJV+KfzXkE5gV\nyXo4HALoaZBGN+ipoGROcL4iisqNxyy0EqrD556/gWlihP3+WJErwOqIrHQmORZcAxAK6V+8dCVT\nrly5CuANaycszsKdnRYVocWPPmIXq6tXObbnnqOyssJU1RklxyKXBzBQW9ZYv4684LyFAi1x1Cpz\nUuOxHIutzxwmi4Qcp7sTJ07+Aoh7YDlx4uTcyBRImETEULPzREzdfgygJ3xhUaTLl0lCcPd9orzD\nnpIDq4wkXiZhN9Y/Wc+UzU3ii5/50msAusId9TXmDc6vMbbyqEnc11NL1UKVNvzsMo//Sp1j2N1j\nDGh9fQNAR7X7LbWWXF6i2V8HjeErNWK35brI0UEcESrGVD2DS+za8y9nyq//g3/EQcYEGp/cY8wu\nyYnEQpHEkTLimi3Vuyc9ALG6ZipGhATEL0dtFuv7T2n2byondihUkigdsaoo5P1PNwA8eKTAq1Ix\nFxYX9F2l6aqR6p4mEEogHDMdSqlVygAaJZ7FOAX7nemxVEzkozb3OOz7YmqPkiGARoOlo2trpBYI\nVao2Cgknk5RDaguJ9/rG5DHUaIWh8h4mcF9J3BJl5YuaTyBRuK0qBkdDZAVV2Nlqz8KpRi6Y80/G\n7Eai4d/YZ+Vmr9vKFCuoXL1wEdPEF1wyBToRcgrsjtMsT9X66SOrQKzP1IHJuk0pRrGvyHu7yfvy\n9h7x4wfv/ihT5lX/u7rKn9vq2hUApZLynBcYWFxaoRvH0nftlkXyMFjr1nHiqOWdJh4mWBzSM5jv\n4SwsJ06cnCOZYmEdiVKgLA/xcBBCnS0wkZi/NM/X9V15zneafAHu6Z3cmOGj9/aLdEnef8iclIzA\nypziN2/Qc3zjKq2y9c1WpmSl5wD294xfQd1f9G7ceJ/m2NZ+G0BOIQJfFf+rF2m4XdFT+rJ8+cZ+\nNZQplyR5DXJ6LcWv/f2/xwGs8p357vuMKpgHNBy3CVG9hVy25lbM+prE9m6xHp/jVwm3hHo37u3R\ngrNUI7OEGmKJylzRzX01Rpe/dm9PjULFFxzJ6W7FOtY5pqK26SUlH3mRDyC0jjRqf1JWatVpaal+\naGuThm1VlT3Pv/AigAWxklUU+hiI3fjggGl3xiTRE61CRXlqs3Wu0qp61pcLeQCBbKVYTvcsyANg\nJKLqgXV2GXP+iqVXeEKZbQj8AoBULVcHQyr7uzQY9/YZX7JqMGPCMF6QokiNT0guNQuLW8xFnZOp\nYtwGExUxVMzn3e/QHt/e2gKwtUWjqX2oCjkZhjMK+9RklJXEO2IUchvbXNJ319kNdzCIAUQxD7K4\nRFR05w7r7W7eYDLa0hJva32WkZNimU+AFFot+oHQpjfabud0d+LEyV8AcQ8sJ06cnBuZAgnv36P5\nd1nJ+yUvBJAIRARmQ5qHT07lmkiBn3+eqTR/+Af/JVN6LVqnlQX6Vu9t7AC4eJH+vKu3SO9bFCR5\n7ln2kmk1W5ny4Yf07icCIxsHtPPbfdn5cRFAu0WkubxKG/VRk1vmLzYyZV+QB+qV0pK/OVVDoEEy\nxDR55503MuW9997JFE+GrifT2koczOcKGCMCjeqg4GFiJgtjogLx+SpFy0v5Ub1IL7UnAozIt2tX\n+lgKAAUhkUgsgL2WRRV0XVasIxQaynsdqW1SRzioUggA
LIvALxAuK5xZSoH5Zd7ueWEEYwHOFtKR\nCqSOOup4WsxraOLVkxv+mRVGToq6d771jlUxVnfQBzBQsKIlXGmQbTDgGW/fJjdeXhmFE3zBJ1v4\nDLtHADa26dDY2aWvOhSUNnIRyywryFXS0TV+4xvfwFRRMldiOVaR6mOEFq0PVM5X0pMglS83/Ltv\nvckztnYALCiJ7PEWr72uZLG8VrgF2epKPQvEjlIITnpgOl4HwH6LgZr1ByxQax1wWt56QwtYpJiX\nL9MVsyZa8Atr/EmurXBLtTYHIFcW5YN3Zlqfs7CcOHFybsQ9sJw4cXJuZAokfOdTBqEu3yEleYIu\ngJzFy2S1ttXPo9VioGRh/uVM+ZWv/WKmvPx5Wt2//R9/J1NyKvWenZ0D8Mwao2zGue5HDBLNr3J4\na9eICFrCIG+/806mbHXEvZ2nrTu7ughg8Tr/9AXHYpnUd9UI59620rv03LY6la6uNUpsir6FCfmj\nb/+PTOmJGq2Q52HLqkGx6fVTVfZbG8u8QcIcgFLxJMouiF8hUOpZSW1lDWgodgevZC8ehV3CEMBA\nZTEjBcgs88heVcF4i65UkHy2RujRqHBLreIDKAT8Sl4pQrl4OnDGRGcUS9oKBHuTDOxYDYoyniz1\nrSTcN+hy/P1DLrm+uq8GBZtSEcXFEYBPPiRaebi+zpEI+BuSWrvANKJ5kSP2BetMOZA7otnaB9AL\nLf/L6EC4xXr62s2oCFttqbZpW7UyJ2QkhG4h5lwk2gZDi9o5VQqVhRQ7Cg4O+jzOrZsvAHjl5dey\nP994jy0I/vhHzLE6FKl/rLWxvMqQ35e//OVMCXTL1tUs9vs/+D6Az71ALv+6XEA7uq5tJQlaTHZV\nJBBXr17hGRUT7x4d6opSAHm1sx2c4hQxcRaWEydOzo24B5YTJ07OjUyBhHfbBCN7sagL8gMAXij7\nLRH/lrLs1tRQ88tfYqSvlGfc6uqzLPj+q3/71zPl3//Of86U3e1DAJuHRhvA/qwFWbzNHpV7YoOA\nIjLpIpHm3DJHa2AnSxlNSrZdJGRCsoeRijbGiZG0rbsezfuR4l5pMt06XVmiMbzVZ/wljluZUlez\nzEClOe091moctWmHj2KLfw0xtRZBHGaFMufWMG805njj+6ai1q1VkazHWVauHVb8ATkBqJJwX1mT\nMK8U3EtSLq4xJCcgjsHgCICndrOBMEmjXj45fsndTz7MlBfuEEfYtGej8xSaS1TDYbCiJwb6QY8R\nagsXWqPca6IzX15mgmJW+REoVjsr9kQ7ryXlWvLnRx9/kilGUGHOAcuizBZYR0mhfXEWGiQM1avN\nOOMfPeXasAzS+IwGVuO2o2P2dP5vJHlCzEgEEi2oWVY4+Mtf+ao+8TDB137zZbp3XvwpKgqujuff\negVcu8bM7UAzduUGSf7WLt8CUC7zdlufARu/9Rkw3Le8xNRxo3zwhZQ9eWniZAhgpCtNctNnCc7C\ncuLEyTmSKRbWJ2qP+rvfpaPuC88uAlgtqHm3XiAXVvnsvLDIl9hz11TbqRKKrV0+cb/+b2lYvfU2\nX7lZxc9E6YucpnLXxSUeNjY3s/zl1ic18tSI/PilDELrPmJ9t2kn+LI7UtUMR7LO8lY6Y0lJ4fQq\ngXSkEvEq30JH1jUz5kv4+du0KZILtLl29zgbO6Lr7bRiTLylYyVSJRGPVg34Xnr+JfJQb8q5uyt/\nfz88+drPSn+KKq6q6pY1qpyuJTX7WV3jTbyu2uOVEqeuo8SofdXHZh7uquIAtRm+aRcW5nCGjJT0\nNOhwtJ6sy8yKMHos63h67y7tnSMLaOidnFdQwhrfJ1aqbWW9cQpgcYGDNHuqN7aeqDx69PjEPqZY\np/ieCrAPWy0A3T21yw1s2Lwco+jqKk0pUo1RPO7tPt126PdpQvpKHwtEBh3qpxQp9zDSldphjd3M\nqneiOMJEM5tQ1uva
5au6QhWHSfFEmvbgETPX+qGhFrFmz16dPN3BofpOaTaq9Su6UNX5H/LSNp82\nNVqOsqj6uayyKFdTdfrBmU2YnIXlxImTcyPugeXEiZNzI1MgYUd22jfeYh3Mp/fuA/grr7Ig+7k1\ngpQH90lD/POvkau3lKer+EiI7Lf/G/M+3v6Axfq9SAUxQQmAJzfwuJeknO6pZz45fjQUZBtpS07J\nNUMdNjM3g+AkuKtUZH/KtI4NQ2ge7ESR2mQWlB12QvY3Wcgej2i+9mXt96zHqjpfLolAKj8kZCuL\nYKHvpwDS1ICxsIP8jr0+weOXXyfAvHObpMyPHjE7Zu+ATn0rE8kc2oabSjrdkiBVQ8xZsc64vcej\nfSJeJMjnWl8mvKrM1gFUZvjdebFr1eR8PS1l3YhQiMxCHFnQxpMzuSDcak16LDJQk1PZU2ZQRRcS\nKWfn7sek62g39wG0VA2TWLtcHS3QkiiJ5MD4LnrC9bsi7TJI6HsBgDmth1B79pQSFokEIhkDwJO0\nCrncdBPh29/+nxx8RMLiipKSElE/G/nHGIQmlhopwmgLESQRAE9IbSBwl4z5sKz+RlGXBmMstZqu\nUT147Dv00Ht2B5UEZwmGenpY0MMbE4+cuvZxUmAMAFUdRIGs0+IsLCdOnJwbcQ8sJ06cnBuZAgkX\nFmkZNptEJVutAwDfU6OaePSs9qXVtyQSO/i02H/4Bin3fv+b38uUYVLVOcVq4B17XMaWYyVD0Zqh\nGieslddYjCZnBSXGkeD5mMj1mDH227H5OjpxtEQkCmZar64S49TrVN7AMVlV4G/jEbFhNLTsGCoP\nFO06VJ6UXXBX6V3daAQgiQ0SiodaIGI4IOJ467v/PVN+scoruqMr6os+wUJmWR3VwCJcAhE7exzt\nQxEW7/UY9hoIHpUEAOeVXlea5fj9cgEChgCKwpU5f8pC4iVrkAZGPNVmZaMd6AItx6pkeTpSLMAX\nNulYeGw0x8ZZbKHeoICJkqygZEfjkELh/Y4oPSxuaB1hLTZc0vhH/RDASIwCFpC1AJ/5NCxzKlKi\nYhob7J0edC5phURCgoFqwoLCrEaiuKQ5T8b83epVYyCRa83CiOHx7ZOXaI2UbA9ROap3VDhQ6dXx\n32ykJrgGzC030NNoPZzEjyZhx259BGCgz0vBHs4QZ2E5ceLk3Ih7YDlx4uTcyBRLPpAdmy+IQmxQ\nBPBgx+okmPn5C6+Qpa/cYEF2W5zov/HHhFAD2agjGfxFMXtlJrTlTJr4MiZPW8/F00jQdlYtTrlU\nxkQmm5GyH6nhipVHDAVSZhus6li9QKUmHNETI8UJuXyTJGTtLiFVd8PsWHG/Ceg1daKCqmpChQXj\njLHbYLAdIrW4Erfce4/x1sdHnMklT+1XlS4Yy+rueAmA7ZRo5Z6ikxtiBegZAcNl1uivXiGbWkm1\nLDCgJ8q9Wq0GoKIonqfE1PSM4BeAtpg8ekoc3dkUB8MgBBArRdYoJUaCbHZd1kI0L6KIcRRYit3x\nbMKMm2HQURxZRAtHh1QsNludUVKxJjAdKTAtFsMsr/VQ3YYMCcbKyTRi+OTU3TSCitwYsh0TyxPu\ndBjwrcjFYceyTsAWCgwjG5syLT2LG0YAQmPpEPeDJZ2Ow4WG2cdI08alORT+zb5lZXATFWXxCcUa\nqXqnfsf2USAgORpFAHpzXFcXLs7gDHEWlhMnTs6NuAeWEydOzo1MbaRqPT5lKwYlAGFEu3ynQ/vz\nzY8ZsvmVHm28o5QAarNJpaQgXdSzHDYRhFcqAAIZq1Yfn9OoLKxgMcFUHAbGhGfFZZ2Qww6jLgQM\nMVH+bgCwO6ChWxMSnFtmPZ2RiH/8gCHRfGK27jGpNxhKW1phKG1LkNAsYKvMH8pOtp5RsXo3xTgJ\nH04M2w43Egbp7jGtzis2MsUX68CmTvQOhgDuCUB1apy32kUW/S2qbe2irr2o5MwxF5
+gTVEM9H7g\nA/CtyaiF87TltGw//FQHO1kBl0XTAjG4Ww/OnHUzzRMWVUSOmDuFXyIVhHYiRRKHEYBElXFjAjz1\n+yoUGYlbfoaT0BVcbR9QsaZn6TgKmcMEgZ+xOKTpyTtl2DBvRAu6y73edA/DxiPSDX66zfNWldQa\nCEVG45XFGYv1UaKgc36chj2CKgoB6NLHLgZrEGtd+8YxR9tH99dmO2OkSOKT8VBPvo6cGErG5PRa\nRafmCR3l9MZzFQDPvMQmEnUlFJwWZ2E5ceLk3Mi09JkxZY96cnh5AIn1mJSZs77D18XXf5utcb76\nlVcz5f4mrYDeONdJvnxVV/iFAoCK3pkF2Up9FVWYvzyVcZQv8dS+3vnmoLUt2aO9b3k6uhzboSHj\naGGVsYLdPdaRt1SV0npEu+D6NVW3H5eSOtYUdTn2covlr5XfHFHu5JSOi/azneztY/vpLZdK6egt\n95Fe8rNqqPPxgKzW74tdulmvAJi/xMGvXSErWWON116ocPzWhHWksQWya3wpgd722Rt1bCLl7AV7\n5pvPT5SmFJu7VzZLdjRL2EntLc3vDsW8HIkbI9GcTvAfUMzpnnUVNesgkKkVaxWViqpYUu3RwR7t\nmq5iLHmtdt+6ew6HmOhhYybweBK0kq3jaUlLriPaiV73ENPEg1aRLYTYiKRlAdkk+5pJGVDWHtWK\nsbI5tilNlftmk5sadDCCCuvBI/7uWJ+NzJTz8wBS61Rk5F1mnVnb1/H8KI6h8EgkMuu6mEIuvngT\nQJDj7WjdfR9niLOwnDhxcm7EPbCcOHFybmQKJJxXU0mrmehGIYCCOi9aKoenRK3v/PC9TFnfpBv+\nsEcbuykPvTJCUBUYySoMijqI4Y6S/OW+zHL7yGzUSEAvZ749mbhxOMJEBkpZSHNxntQC80tEgqHA\nwlB1/H1j7xWgyLpynhYrlO+qWH+mwRMNugQyxv0QyyqOxwa/xm8W9HFJhX1SJUN1lWLzXXFVPxKF\n9H5FuUgrzA5bvbgE4NoiowoLCi94mvyuAKDVQwTywtaMOVo7B0qdK5UrAIqa0rzIOT5DzNU9ZgFW\n+lOa5ACkikSMkaa+ay722Nz8QqnFohwLWiTG+pDy4OZv1u1Q1CJU+lhP6UXd4zUiAHIFHnagPMFs\n/FoyY0xvkNC2GBtEGvLUB/vE7KPwjOVkpJXaYSSsbh9BxTqWg5gIf3maWyPqS9IIkzBcnpmCrt3w\nZZIeQ+iYyMMajeSrNy97mmKiFe6YhcI8C6nc/7mTP9WRqC7nbpKC+eJVlvQNnu4A+LH4NsoipDwt\nzsJy4sTJuRH3wHLixMm5kSmQcCgQpE4rGCYjAHlxM0RmFRv/gUJm65vkALA6+yg0y9a646hZaa+D\nidiKlexUraFLhdjQk8FZMP42oRWrvN9tijEaESZKN+aU1LG60KCyykhZS9it3WI9REfdTRrqfLO3\nM71wPFQxhF+gxTu3xBON1H80UrhwZJE4494WJMyuzDJ3cqeCg1BVRyDeu1FZpS2zHORzsysaNqtq\navUAwIygYlGVRgOr6rB4pdHaiT9vjBY0hrwgeRZpzWtPS8hKz6AqBzAIrfTfIlbH0nw8XaCRu9uS\nmIB7wiCWPWSwKzm5wLJamZHSCX2t55FwX6zDVrUUDQl6RpLRV7HL8T43yal4riVkBQLINi3Np/w5\njIaM3uZOQn+JBQBF5+ApXphXTA3x+IfHnRV5H5M2mIshzQEoCdg26mK4twvRsM25YT+Zgm732Pmj\n72WRRPtK58hi8Tqs7mbbgs6LPPXlmzczZX6eDoqNj9goa//efUxknJUKZ02Ts7CcOHFyfsQ9sJw4\ncXJuZCokpDFcFAbJCvsTxS9yFqSwou2xchIJpslJ0z0dl3onmLD/Dw6I6ZpiSa/XWJAxK4BW92hM\nJurtGSVDXYkgQMkHMBTlmDGIB7KWo566MPW4T6
fF7luJyAyM7XsQTH+mB0oTbSwQnNbESB0PxWom\nKGgNoNIxmZlRC3iYQCLWm9aI0AJBzrKyEOs1VZbUSO1WE1NFTbA6s+pDcfJ1NNqeAQHjNRcrQEGI\nzACgAbEx/kpTAKGK7AsFKUo1PC35ovE1KnPYPAmehwmmh3FwcJxmezKwCEUSLQJrlWSRQloZsX1f\nSDDuq5hGUcKqvlKeZeDY+OdGKtvyToE3onWL/I4bnlKrCq522/QwtJUvaojZ7jtUlcLt5mcx8nWV\nSKXim/RVkWOKDXLM22d1NrkUE5yIvaCtAYxBob4rX42mJQjtbnqnvnVMIo3Nzmvp5fVlFYHduqZj\n8UQf//AHmTLc4e/Oj2NMVAslZ7SbhbOwnDhx4sSJEydOnDhx4sSJEydOnDhx4sSJEydOnDhx4sSJ\nEyf/z8r/Ab+8NWulkQjIAAAAAElFTkSuQmCC\n",
"text/plain": [
"<PIL.Image.Image image mode=RGB size=400x100 at 0x7FC0FD0B8EF0>"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataiter = iter(testloader)\n",
"images, labels = next(dataiter) # each batch contains 4 images\n",
"print('Ground truth: ', ' '.join(\\\n",
" '%8s'%classes[labels[j]] for j in range(4)))\n",
"# undo Normalize(mean=0.5, std=0.5) before display: x = x_norm / 2 + 0.5\n",
"show(tv.utils.make_grid(images / 2 + 0.5)).resize((400,100))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, compute the labels predicted by the network:"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted: cat ship plane ship\n"
]
}
],
"source": [
"# compute the score of each image for every class\n",
"outputs = net(Variable(images))\n",
"# take the class with the highest score\n",
"_, predicted = t.max(outputs.data, 1)\n",
"\n",
"print('Predicted: ', ' '.join('%5s'\\\n",
" % classes[predicted[j]] for j in range(4)))"
]
},
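{
"cell_type": "markdown",
"metadata": {},
"source": [
"The network already shows some effect: 2 of these 4 images (50%) are classified correctly. This is only a handful of images, though, so let's look at its performance on the entire test set."
]
},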
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy on the 10000 test images: 54 %\n"
]
}
],
"source": [
"correct = 0 # number of correctly predicted images\n",
"total = 0 # total number of images\n",
"for data in testloader:\n",
" images, labels = data\n",
" outputs = net(Variable(images))\n",
" _, predicted = t.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" # .item() converts the 0-dim count tensor into a Python int\n",
" correct += (predicted == labels).sum().item()\n",
"\n",
"print('Accuracy on the 10000 test images: %d %%' % (100 * correct / total))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy on the test set (54%) is far better than random guessing (10%, since there are 10 classes), which shows that the network has indeed learned something."
]
},
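{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Training on the GPU\n",
"Just as a Tensor can be moved from the CPU to the GPU, a model can be moved to the GPU in the same way. The cell below moves the network and one batch of data to the GPU and computes the loss there."
]
},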
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"# move the model and the current batch to the GPU, then compute the loss\n",
"if t.cuda.is_available():\n",
" net.cuda()\n",
" images = images.cuda()\n",
" labels = labels.cuda()\n",
" output = net(Variable(images))\n",
" loss = criterion(output, Variable(labels))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If running on the GPU turns out to be not much faster than on the CPU, this is most likely because the network is small, so the GPU cannot show its full potential."
]
},
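{
"cell_type": "markdown",
"metadata": {},
"source": [
"This concludes the basic introduction to PyTorch. To summarize, this section covered the following:\n",
"\n",
"1. Tensor: a data structure similar to a Numpy array, with a similar interface; the two can easily be converted into each other.\n",
"2. autograd/Variable: a Variable wraps a Tensor and provides automatic differentiation (since PyTorch 0.4, Variable has been merged into Tensor).\n",
"3. nn: an interface designed specifically for neural networks, providing many useful building blocks such as layers and loss functions; optimizers are provided by the companion package `torch.optim`.\n",
"4. Neural network training: using CIFAR-10 classification as an example, we walked through the full training workflow, including data loading, network construction, training, and testing.\n",
"\n",
"After this section, readers should have a feel for PyTorch's simple interface and flexibility. Starting from the next chapter, this book explains each part of PyTorch in depth and systematically."
]
}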
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
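The evaluation cell in the notebook above accumulates correct predictions batch by batch. Below is a minimal sketch of the same loop written for PyTorch 0.4 and later. In these versions wrapping inputs in `Variable` is unnecessary, and `t.no_grad()` disables gradient tracking during inference. The trained `net` and the CIFAR-10 `testloader` are not reproduced here. The sketch therefore assumes hypothetical stand-ins: a tiny linear classifier and a random `TensorDataset` of 3x32x32 "images" with 10 classes. The printed accuracy is meaningless (roughly chance level), and only the mechanics carry over. The sketch also shows how to invert `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, assuming that is the normalization used as in the official CIFAR-10 tutorial. That normalization computes x_norm = (x - 0.5) / 0.5, so pixels are mapped back to [0, 1] with x_norm / 2 + 0.5.

```python
import torch as t
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

t.manual_seed(0)

# Hypothetical stand-in for the CIFAR-10 test set: 100 random "images"
# already normalized to [-1, 1], with integer labels in 0..9.
images = t.rand(100, 3, 32, 32) * 2 - 1
labels = t.randint(0, 10, (100,))
testloader = DataLoader(TensorDataset(images, labels), batch_size=4)

# Hypothetical stand-in for the trained `net`: flatten, then one linear layer.
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))


def evaluate(net, loader):
    """Return (correct, total) over every batch in `loader`."""
    correct = 0
    total = 0
    net.eval()  # put layers such as dropout/batchnorm into inference mode
    with t.no_grad():  # no computation graph is needed for inference
        for batch_images, batch_labels in loader:
            outputs = net(batch_images)       # shape: (batch, 10)
            _, predicted = t.max(outputs, 1)  # index of the highest score
            total += batch_labels.size(0)
            # .item() turns the 0-dim count tensor into a Python int
            correct += (predicted == batch_labels).sum().item()
    return correct, total


correct, total = evaluate(net, testloader)
accuracy = 100 * correct / total
print('Accuracy on %d test images: %d %%' % (total, accuracy))


def unnormalize(x):
    """Invert Normalize(mean=0.5, std=0.5), i.e. x_norm = (x - 0.5) / 0.5."""
    return x / 2 + 0.5


# Pixel values mapped back from [-1, 1] to [0, 1], ready for display.
pixels = unnormalize(images)
print('pixel range: [%.3f, %.3f]' % (pixels.min().item(), pixels.max().item()))
```

Because `correct` stays a plain Python int, `100 * correct / total` is ordinary Python division rather than tensor arithmetic. This avoids integer-versus-float tensor division differences across PyTorch versions.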

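The GPU cell in the notebook above moves the model and a single batch only when CUDA is available. It computes the loss but does not update the weights. A common device-agnostic pattern in current PyTorch picks a `torch.device` once and moves both the model and every batch with `.to(device)`, so the same code also runs on CPU-only machines. The sketch below illustrates this pattern under stated assumptions. The tiny linear model, the random batch of 4 images, and the SGD settings are hypothetical stand-ins for the notebook's `net`, `images`/`labels`, and optimizer. The sketch performs one full optimization step.

```python
import torch as t
from torch import nn, optim

# Use the GPU when one is available, otherwise fall back to the CPU.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

t.manual_seed(0)
# Hypothetical stand-ins for the notebook's `net`, `criterion` and optimizer.
net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# One hypothetical batch of 4 normalized images and their labels (on the CPU).
images = t.rand(4, 3, 32, 32) * 2 - 1
labels = t.randint(0, 10, (4,))


def train_step(images, labels):
    """Move one batch to `device`, then run forward, backward and update."""
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()            # clear gradients from the previous step
    outputs = net(images)            # forward pass on `device`
    loss = criterion(outputs, labels)
    loss.backward()                  # backpropagate
    optimizer.step()                 # update the parameters
    return loss.item()


loss_value = train_step(images, labels)
print('device: %s, loss: %.4f' % (device, loss_value))
```

`Module.to()` moves the parameters in place. `Tensor.to()` returns a tensor on the target device, or the same tensor if it is already there, which is why the batch is reassigned inside `train_step`.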
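Machine learning is increasingly applied to fields such as aircraft and robotics. Its goal is to use computers to achieve human-like intelligence, making equipment intelligent and unmanned. This course aims to help students master the basic knowledge, typical methods, and techniques of machine learning. It uses concrete application cases to spark students' interest in the subject, and it encourages them to analyze and solve the problems and challenges faced by aircraft and robots from the perspective of artificial intelligence. The course covers Python programming basics, machine learning models, and the fundamentals and implementation of unsupervised learning, supervised learning, and deep learning. It also teaches how to use machine learning to solve practical problems, thereby comprehensively improving students' overall abilities.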
