
4-deep-nn.ipynb 92 kB


  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
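  7. "# Deep Neural Networks\n",
  8. "The previous chapter briefly introduced some basics of neural networks and showed how to build a complex nonlinear binary classifier with one. Neural networks are more often suited to harder problems such as image classification. Below we use MNIST handwritten digit classification, the entry-level dataset of deep learning, to show how well a deeper neural network performs.\n",
  9. "\n",
  10. "## The MNIST Dataset\n",
  11. "MNIST is a very well-known dataset that many networks use as a standard benchmark. It comes from the National Institute of Standards and Technology (NIST) in the United States. The training set consists of digits handwritten by 250 different people, 50% of them high school students and 50% staff of the Census Bureau, for a total of 60000 images. The test set contains handwritten digits in the same proportion, for a total of 10000 images.\n",
  12. "\n",
  13. "Each image is a 28 x 28 grayscale image, as shown below.\n",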
  14. "\n",
  15. "![](https://ws3.sinaimg.cn/large/006tKfTcly1fmlx2wl5tqj30ge0au745.jpg)\n",
  16. "\n",
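  17. "So our task is: given an image, determine which of the 10 digits from 0 to 9 it shows.\n",
  18. "\n",
  19. "## Multi-class Classification\n",
  20. "We covered binary classification earlier. The problem here is harder: it is a 10-class problem, which falls under multi-class classification. For multi-class classification we use a more general loss function called cross-entropy.\n",
  21. "\n",
  22. "### softmax\n",
  23. "Before cross-entropy, let's first look at the softmax function. We have already seen the sigmoid function:\n",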
  24. "\n",
  25. "$$s(x) = \\frac{1}{1 + e^{-x}}$$\n",
  26. "\n",
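  27. "It maps any value into the range 0 ~ 1. For a binary classification problem this is enough: if a sample does not belong to the first class it must belong to the second, so a single value giving the probability of one class suffices. For multi-class classification this no longer works, because we need the probability of every class, and this is where the softmax function comes in.\n",
  28. "\n",
  29. "An illustration of the softmax function:\n",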
  30. "\n",
  31. "![](https://ws4.sinaimg.cn/large/006tKfTcly1fmlxtnfm4fj30ll0bnq3c.jpg)\n"
  32. ]
  33. },
  34. {
  35. "cell_type": "markdown",
  36. "metadata": {},
  37. "source": [
  38. "For the network outputs $z_1, z_2, \\cdots, z_k$, we first exponentiate each of them to get $e^{z_1}, e^{z_2}, \\cdots, e^{z_k}$, then divide each term by their sum, that is\n",
  39. "\n",
  40. "$$\n",
  41. "z_i \\rightarrow \\frac{e^{z_i}}{\\sum_{j=1}^{k} e^{z_j}}\n",
  42. "$$\n",
  43. "\n",
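  44. "After softmax, all the terms sum to 1, so each term can be read as the probability of belonging to the corresponding class.\n",
  45. "\n",
  46. "## Cross-entropy\n",
  47. "Cross-entropy measures how similar two distributions are. The loss function of the binary classification problem discussed earlier is a special case of it. The general formula is\n",
  48. "\n",
  49. "$$\n",
  50. "cross\\_entropy(p, q) = E_{p}[-\\log q] = - \\sum_{x} p(x) \\log q(x)\n",
  51. "$$\n",
  52. "\n",
  53. "For binary classification, averaged over $m$ samples, this can be written as\n",
  54. "\n",
  55. "$$\n",
  56. "-\\frac{1}{m} \\sum_{i=1}^m (y^{i} \\log sigmoid(x^{i}) + (1 - y^{i}) \\log (1 - sigmoid(x^{i})))\n",
  57. "$$\n",
  58. "\n",
  59. "This is the binary classification loss we discussed before. At that time we only gave the formula and argued that it was reasonable without deriving it; now we can see it is a special case of cross-entropy, which justifies choosing it as the loss function.\n",
  60. "\n",
  61. "Cross-entropy comes from information theory and is not covered in detail here. For more, see this [link](http://blog.csdn.net/rtygbwwwerr/article/details/50778098).\n",
  62. "\n",
  63. "Next we use MNIST as a direct example to look at deep neural networks."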
  64. ]
  65. },
  66. {
  67. "cell_type": "code",
  68. "execution_count": 1,
  69. "metadata": {},
  70. "outputs": [],
  71. "source": [
  72. "import numpy as np\n",
  73. "import torch\n",
  74. "from torchvision.datasets import mnist # import the MNIST dataset bundled with torchvision\n",
  75. "\n",
  76. "from torch import nn\n",
  77. "from torch.autograd import Variable"
  78. ]
  79. },
  80. {
  81. "cell_type": "code",
  82. "execution_count": 2,
  83. "metadata": {},
  84. "outputs": [],
  85. "source": [
  86. "# download the MNIST dataset with the built-in helper\n",
  87. "train_set = mnist.MNIST('../../data/mnist', train=True, download=True)\n",
  88. "test_set = mnist.MNIST('../../data/mnist', train=False, download=True)"
  89. ]
  90. },
  91. {
  92. "cell_type": "markdown",
  93. "metadata": {},
  94. "source": [
  95. "Let's look at what one sample looks like."
  96. ]
  97. },
  98. {
  99. "cell_type": "code",
  100. "execution_count": 3,
  101. "metadata": {},
  102. "outputs": [],
  103. "source": [
  104. "a_data, a_label = train_set[0]"
  105. ]
  106. },
  107. {
  108. "cell_type": "code",
  109. "execution_count": 4,
  110. "metadata": {},
  111. "outputs": [
  112. {
  113. "data": {
  114. "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABAElEQVR4nGNgGMyAWUhIqK5jvdSy/9/rGRgYGFhgEnJsVjYCwQwMDAxPJgV+vniQgYGBgREqZ7iXH8r6l/SV4dn7m8gmCt3++/fv37/Htn3/iMW+gDnZf/+e5WbQnoXNNXyMs/5GoQoxwVmf/n9kSGFiwAW49/11wynJoPzx4YIcRlyygR/+/i2XxCWru+vv32nSuGQFYv/83Y3b4p9/fzpAmSyoMnohpiwM1w5h06Q+5enfv39/bcMiJVF09+/fv39P+mFKiTtd/fv3799jgZiBJLT69t+/f/8eDuDEkDJf8+jv379/v7Ryo4qzMDAwMAQGMjBc3/y35wM2V1IfAABFF16Aa0wAOwAAAABJRU5ErkJggg==\n",
  115. "text/plain": [
  116. "<PIL.Image.Image image mode=L size=28x28 at 0x7FF1658A5278>"
  117. ]
  118. },
  119. "execution_count": 4,
  120. "metadata": {},
  121. "output_type": "execute_result"
  122. }
  123. ],
  124. "source": [
  125. "a_data"
  126. ]
  127. },
  128. {
  129. "cell_type": "code",
  130. "execution_count": 5,
  131. "metadata": {},
  132. "outputs": [
  133. {
  134. "data": {
  135. "text/plain": [
  136. "tensor(5)"
  137. ]
  138. },
  139. "execution_count": 5,
  140. "metadata": {},
  141. "output_type": "execute_result"
  142. }
  143. ],
  144. "source": [
  145. "a_label"
  146. ]
  147. },
  148. {
  149. "cell_type": "markdown",
  150. "metadata": {},
  151. "source": [
  152. "The data is loaded as a PIL image, which is easy to convert into a numpy array."
  153. ]
  154. },
  155. {
  156. "cell_type": "code",
  157. "execution_count": 6,
  158. "metadata": {},
  159. "outputs": [
  160. {
  161. "name": "stdout",
  162. "output_type": "stream",
  163. "text": [
  164. "(28, 28)\n"
  165. ]
  166. }
  167. ],
  168. "source": [
  169. "a_data = np.array(a_data, dtype='float32')\n",
  170. "print(a_data.shape)"
  171. ]
  172. },
  173. {
  174. "cell_type": "markdown",
  175. "metadata": {},
  176. "source": [
  177. "Here we can see that the image size is 28 x 28."
  178. ]
  179. },
  180. {
  181. "cell_type": "code",
  182. "execution_count": 7,
  183. "metadata": {},
  184. "outputs": [
  185. {
  186. "name": "stdout",
  187. "output_type": "stream",
  188. "text": [
  189. "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  190. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  191. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  192. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  193. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  194. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  195. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  196. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  197. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  198. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  199. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 3. 18.\n",
  200. " 18. 18. 126. 136. 175. 26. 166. 255. 247. 127. 0. 0. 0. 0.]\n",
  201. " [ 0. 0. 0. 0. 0. 0. 0. 0. 30. 36. 94. 154. 170. 253.\n",
  202. " 253. 253. 253. 253. 225. 172. 253. 242. 195. 64. 0. 0. 0. 0.]\n",
  203. " [ 0. 0. 0. 0. 0. 0. 0. 49. 238. 253. 253. 253. 253. 253.\n",
  204. " 253. 253. 253. 251. 93. 82. 82. 56. 39. 0. 0. 0. 0. 0.]\n",
  205. " [ 0. 0. 0. 0. 0. 0. 0. 18. 219. 253. 253. 253. 253. 253.\n",
  206. " 198. 182. 247. 241. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  207. " [ 0. 0. 0. 0. 0. 0. 0. 0. 80. 156. 107. 253. 253. 205.\n",
  208. " 11. 0. 43. 154. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  209. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 14. 1. 154. 253. 90.\n",
  210. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  211. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 139. 253. 190.\n",
  212. " 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  213. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 11. 190. 253.\n",
  214. " 70. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  215. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 35. 241.\n",
  216. " 225. 160. 108. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  217. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 81.\n",
  218. " 240. 253. 253. 119. 25. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  219. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  220. " 45. 186. 253. 253. 150. 27. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  221. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  222. " 0. 16. 93. 252. 253. 187. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  223. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  224. " 0. 0. 0. 249. 253. 249. 64. 0. 0. 0. 0. 0. 0. 0.]\n",
  225. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  226. " 46. 130. 183. 253. 253. 207. 2. 0. 0. 0. 0. 0. 0. 0.]\n",
  227. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 39. 148.\n",
  228. " 229. 253. 253. 253. 250. 182. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  229. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 24. 114. 221. 253.\n",
  230. " 253. 253. 253. 201. 78. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  231. " [ 0. 0. 0. 0. 0. 0. 0. 0. 23. 66. 213. 253. 253. 253.\n",
  232. " 253. 198. 81. 2. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  233. " [ 0. 0. 0. 0. 0. 0. 18. 171. 219. 253. 253. 253. 253. 195.\n",
  234. " 80. 9. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  235. " [ 0. 0. 0. 0. 55. 172. 226. 253. 253. 253. 253. 244. 133. 11.\n",
  236. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  237. " [ 0. 0. 0. 0. 136. 253. 253. 253. 212. 135. 132. 16. 0. 0.\n",
  238. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  239. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  240. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  241. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  242. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
  243. " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n",
  244. " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]\n"
  245. ]
  246. }
  247. ],
  248. "source": [
  249. "print(a_data)"
  250. ]
  251. },
  252. {
  253. "cell_type": "markdown",
  254. "metadata": {},
  255. "source": [
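  256. "We can print the array: 0 represents black and 255 represents white.\n",
  257. "\n",
  258. "The first layer of our neural network takes 28 x 28 = 784 inputs, so we must transform the data, using reshape to flatten each image into a one-dimensional vector."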
  259. ]
  260. },
  261. {
  262. "cell_type": "code",
  263. "execution_count": 23,
  264. "metadata": {},
  265. "outputs": [],
  266. "source": [
  267. "def data_tf(x):\n",
  268. "    x = np.array(x, dtype='float32') / 255\n",
  269. "    x = (x - 0.5) / 0.5 # normalize to [-1, 1]; this trick is covered later\n",
  270. "    x = x.reshape((-1,)) # flatten\n",
  271. "    x = torch.from_numpy(x)\n",
  272. "    return x\n",
  273. "\n",
  274. "train_set = mnist.MNIST('../../data/mnist', train=True, transform=data_tf, download=True) # reload the dataset, this time declaring the transform defined above\n",
  275. "test_set = mnist.MNIST('../../data/mnist', train=False, transform=data_tf, download=True)"
  276. ]
  277. },
  278. {
  279. "cell_type": "code",
  280. "execution_count": 24,
  281. "metadata": {},
  282. "outputs": [
  283. {
  284. "name": "stdout",
  285. "output_type": "stream",
  286. "text": [
  287. "torch.Size([784])\n",
  288. "tensor(5)\n"
  289. ]
  290. }
  291. ],
  292. "source": [
  293. "a, a_label = train_set[0]\n",
  294. "print(a.shape)\n",
  295. "print(a_label)"
  296. ]
  297. },
  298. {
  299. "cell_type": "code",
  300. "execution_count": 25,
  301. "metadata": {},
  302. "outputs": [],
  303. "source": [
  304. "from torch.utils.data import DataLoader\n",
  305. "# define data iterators with PyTorch's built-in DataLoader\n",
  306. "train_data = DataLoader(train_set, batch_size=64, shuffle=True)\n",
  307. "test_data = DataLoader(test_set, batch_size=128, shuffle=False)"
  308. ]
  309. },
  310. {
  311. "cell_type": "markdown",
  312. "metadata": {},
  313. "source": [
  314. "A data iterator like this is well worth using: when the dataset is too large to fit in memory all at once, a Python iterator can yield one batch at a time."
  315. ]
  316. },
  317. {
  318. "cell_type": "code",
  319. "execution_count": 28,
  320. "metadata": {},
  321. "outputs": [],
  322. "source": [
  323. "a, a_label = next(iter(train_data))"
  324. ]
  325. },
  326. {
  327. "cell_type": "code",
  328. "execution_count": 27,
  329. "metadata": {},
  330. "outputs": [
  331. {
  332. "name": "stdout",
  333. "output_type": "stream",
  334. "text": [
  335. "torch.Size([64, 784])\n",
  336. "torch.Size([64])\n"
  337. ]
  338. }
  339. ],
  340. "source": [
  341. "# print the shapes of one batch\n",
  342. "print(a.shape)\n",
  343. "print(a_label.shape)"
  344. ]
  345. },
  346. {
  347. "cell_type": "code",
  348. "execution_count": 38,
  349. "metadata": {},
  350. "outputs": [],
  351. "source": [
  352. "# define a 4-layer neural network with Sequential\n",
  353. "net = nn.Sequential(\n",
  354. "    nn.Linear(784, 400),\n",
  355. "    nn.ReLU(),\n",
  356. "    nn.Linear(400, 200),\n",
  357. "    nn.ReLU(),\n",
  358. "    nn.Linear(200, 100),\n",
  359. "    nn.ReLU(),\n",
  360. "    nn.Linear(100, 10)\n",
  361. ")"
  362. ]
  363. },
  364. {
  365. "cell_type": "code",
  366. "execution_count": 39,
  367. "metadata": {},
  368. "outputs": [
  369. {
  370. "data": {
  371. "text/plain": [
  372. "Sequential(\n",
  373. " (0): Linear(in_features=784, out_features=400, bias=True)\n",
  374. " (1): ReLU()\n",
  375. " (2): Linear(in_features=400, out_features=200, bias=True)\n",
  376. " (3): ReLU()\n",
  377. " (4): Linear(in_features=200, out_features=100, bias=True)\n",
  378. " (5): ReLU()\n",
  379. " (6): Linear(in_features=100, out_features=10, bias=True)\n",
  380. ")"
  381. ]
  382. },
  383. "execution_count": 39,
  384. "metadata": {},
  385. "output_type": "execute_result"
  386. }
  387. ],
  388. "source": [
  389. "net"
  390. ]
  391. },
  392. {
  393. "cell_type": "markdown",
  394. "metadata": {},
  395. "source": [
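  396. "Cross-entropy is built into PyTorch. A naive implementation of cross-entropy is numerically unstable, and the built-in function takes care of this for us. `nn.CrossEntropyLoss` expects the raw network outputs and applies log-softmax internally, which is why the network above has no softmax layer."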
  397. ]
  398. },
  399. {
  400. "cell_type": "code",
  401. "execution_count": 40,
  402. "metadata": {},
  403. "outputs": [],
  404. "source": [
  405. "# define the loss function and the optimizer\n",
  406. "criterion = nn.CrossEntropyLoss()\n",
  407. "optimizer = torch.optim.SGD(net.parameters(), 1e-1) # stochastic gradient descent with learning rate 0.1"
  408. ]
  409. },
  410. {
  411. "cell_type": "code",
  412. "execution_count": 42,
  413. "metadata": {
  414. "scrolled": true
  415. },
  416. "outputs": [
  417. {
  418. "name": "stderr",
  419. "output_type": "stream",
  420. "text": [
  421. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:22: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n",
  422. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:25: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n",
  423. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:41: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n",
  424. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:44: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n"
  425. ]
  426. },
  427. {
  428. "name": "stdout",
  429. "output_type": "stream",
  430. "text": [
  431. "epoch: 0, Train Loss: 0.166705, Train Acc: 0.947978, Eval Loss: 0.129106, Eval Acc: 0.959157\n",
  432. "epoch: 1, Train Loss: 0.117714, Train Acc: 0.962836, Eval Loss: 0.097123, Eval Acc: 0.969838\n",
  433. "epoch: 2, Train Loss: 0.092098, Train Acc: 0.970532, Eval Loss: 0.098194, Eval Acc: 0.969541\n",
  434. "epoch: 3, Train Loss: 0.074442, Train Acc: 0.975880, Eval Loss: 0.077213, Eval Acc: 0.975574\n",
  435. "epoch: 4, Train Loss: 0.062742, Train Acc: 0.979594, Eval Loss: 0.149892, Eval Acc: 0.955301\n",
  436. "epoch: 5, Train Loss: 0.052319, Train Acc: 0.983276, Eval Loss: 0.124755, Eval Acc: 0.961531\n",
  437. "epoch: 6, Train Loss: 0.045134, Train Acc: 0.985091, Eval Loss: 0.085263, Eval Acc: 0.975178\n",
  438. "epoch: 7, Train Loss: 0.038610, Train Acc: 0.987423, Eval Loss: 0.063986, Eval Acc: 0.980123\n",
  439. "epoch: 8, Train Loss: 0.033068, Train Acc: 0.988906, Eval Loss: 0.074201, Eval Acc: 0.977453\n",
  440. "epoch: 9, Train Loss: 0.029478, Train Acc: 0.990155, Eval Loss: 0.066254, Eval Acc: 0.980123\n",
  441. "epoch: 10, Train Loss: 0.024885, Train Acc: 0.992237, Eval Loss: 0.067818, Eval Acc: 0.979727\n",
  442. "epoch: 11, Train Loss: 0.020706, Train Acc: 0.993237, Eval Loss: 0.174131, Eval Acc: 0.958070\n",
  443. "epoch: 12, Train Loss: 0.019527, Train Acc: 0.993553, Eval Loss: 0.066838, Eval Acc: 0.982199\n",
  444. "epoch: 13, Train Loss: 0.016248, Train Acc: 0.994620, Eval Loss: 0.080457, Eval Acc: 0.978738\n",
  445. "epoch: 14, Train Loss: 0.017617, Train Acc: 0.994603, Eval Loss: 0.064320, Eval Acc: 0.982496\n",
  446. "epoch: 15, Train Loss: 0.012970, Train Acc: 0.995985, Eval Loss: 0.079791, Eval Acc: 0.979925\n",
  447. "epoch: 16, Train Loss: 0.012162, Train Acc: 0.995736, Eval Loss: 0.083829, Eval Acc: 0.979727\n",
  448. "epoch: 17, Train Loss: 0.011916, Train Acc: 0.996185, Eval Loss: 0.079493, Eval Acc: 0.981507\n",
  449. "epoch: 18, Train Loss: 0.008972, Train Acc: 0.997385, Eval Loss: 0.074135, Eval Acc: 0.981507\n",
  450. "epoch: 19, Train Loss: 0.008857, Train Acc: 0.997018, Eval Loss: 0.074056, Eval Acc: 0.983188\n"
  451. ]
  452. }
  453. ],
  454. "source": [
  455. "# start training\n",
  456. "losses = []\n",
  457. "acces = []\n",
  458. "eval_losses = []\n",
  459. "eval_acces = []\n",
  460. "\n",
  461. "for e in range(20):\n",
  462. "    train_loss = 0\n",
  463. "    train_acc = 0\n",
  464. "    net.train()\n",
  465. "    for im, label in train_data:\n",
  466. "        im = Variable(im) # Variable is a no-op wrapper since PyTorch 0.4\n",
  467. "        label = Variable(label)\n",
  468. "        # forward pass\n",
  469. "        out = net(im)\n",
  470. "        loss = criterion(out, label)\n",
  471. "        # backward pass\n",
  472. "        optimizer.zero_grad()\n",
  473. "        loss.backward()\n",
  474. "        optimizer.step()\n",
  475. "        # record the loss; item() converts a 0-dim tensor to a Python number\n",
  476. "        train_loss += loss.item()\n",
  477. "        # compute the classification accuracy\n",
  478. "        _, pred = out.max(1)\n",
  479. "        num_correct = (pred == label).sum().item()\n",
  480. "        acc = num_correct / im.shape[0]\n",
  481. "        train_acc += acc\n",
  482. "\n",
  483. "    losses.append(train_loss / len(train_data))\n",
  484. "    acces.append(train_acc / len(train_data))\n",
  485. "    # check performance on the test set\n",
  486. "    eval_loss = 0\n",
  487. "    eval_acc = 0\n",
  488. "    net.eval() # switch the model to evaluation mode\n",
  489. "    for im, label in test_data:\n",
  490. "        im = Variable(im)\n",
  491. "        label = Variable(label)\n",
  492. "        out = net(im)\n",
  493. "        loss = criterion(out, label)\n",
  494. "        # record the loss\n",
  495. "        eval_loss += loss.item()\n",
  496. "        # record the accuracy\n",
  497. "        _, pred = out.max(1)\n",
  498. "        num_correct = (pred == label).sum().item()\n",
  499. "        acc = num_correct / im.shape[0]\n",
  500. "        eval_acc += acc\n",
  501. "\n",
  502. "    eval_losses.append(eval_loss / len(test_data))\n",
  503. "    eval_acces.append(eval_acc / len(test_data))\n",
  504. "    print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'\n",
  505. "          .format(e, train_loss / len(train_data), train_acc / len(train_data),\n",
  506. "                  eval_loss / len(test_data), eval_acc / len(test_data)))"
  507. ]
  508. },
  509. {
  510. "cell_type": "markdown",
  511. "metadata": {},
  512. "source": [
  513. "Plot the loss curves and the accuracy curves."
  514. ]
  515. },
  516. {
  517. "cell_type": "code",
  518. "execution_count": 43,
  519. "metadata": {},
  520. "outputs": [],
  521. "source": [
  522. "import matplotlib.pyplot as plt\n",
  523. "%matplotlib inline"
  524. ]
  525. },
  526. {
  527. "cell_type": "code",
  528. "execution_count": 44,
  529. "metadata": {},
  530. "outputs": [
  531. {
  532. "data": {
  533. "text/plain": [
  534. "[<matplotlib.lines.Line2D at 0x7fa44b7aa588>]"
  535. ]
  536. },
  537. "execution_count": 44,
  538. "metadata": {},
  539. "output_type": "execute_result"
  540. },
  541. {
  542. "data": {
  543. "image/png": "\n",
  544. "text/plain": [
  545. "<Figure size 432x288 with 1 Axes>"
  546. ]
  547. },
  548. "metadata": {
  549. "needs_background": "light"
  550. },
  551. "output_type": "display_data"
  552. }
  553. ],
  554. "source": [
  555. "plt.title('train loss')\n",
  556. "plt.plot(np.arange(len(losses)), losses)"
  557. ]
  558. },
  559. {
  560. "cell_type": "code",
  561. "execution_count": 45,
  562. "metadata": {},
  563. "outputs": [
  564. {
  565. "data": {
  566. "text/plain": [
  567. "Text(0.5, 1.0, 'train acc')"
  568. ]
  569. },
  570. "execution_count": 45,
  571. "metadata": {},
  572. "output_type": "execute_result"
  573. },
  574. {
  575. "data": {
  576. "image/png": "\n",
  577. "text/plain": [
  578. "<Figure size 432x288 with 1 Axes>"
  579. ]
  580. },
  581. "metadata": {
  582. "needs_background": "light"
  583. },
  584. "output_type": "display_data"
  585. }
  586. ],
  587. "source": [
  588. "plt.plot(np.arange(len(acces)), acces)\n",
  589. "plt.title('train acc')"
  590. ]
  591. },
  592. {
  593. "cell_type": "code",
  594. "execution_count": 46,
  595. "metadata": {},
  596. "outputs": [
  597. {
  598. "data": {
  599. "text/plain": [
  600. "Text(0.5, 1.0, 'test loss')"
  601. ]
  602. },
  603. "execution_count": 46,
  604. "metadata": {},
  605. "output_type": "execute_result"
  606. },
  607. {
  608. "data": {
  609. "image/png": "\n",
  610. "text/plain": [
  611. "<Figure size 432x288 with 1 Axes>"
  612. ]
  613. },
  614. "metadata": {
  615. "needs_background": "light"
  616. },
  617. "output_type": "display_data"
  618. }
  619. ],
  620. "source": [
  621. "plt.plot(np.arange(len(eval_losses)), eval_losses)\n",
  622. "plt.title('test loss')"
  623. ]
  624. },
  625. {
  626. "cell_type": "code",
  627. "execution_count": 47,
  628. "metadata": {},
  629. "outputs": [
  630. {
  631. "data": {
  632. "text/plain": [
  633. "Text(0.5, 1.0, 'test acc')"
  634. ]
  635. },
  636. "execution_count": 47,
  637. "metadata": {},
  638. "output_type": "execute_result"
  639. },
  640. {
  641. "data": {
  642. "image/png": "\n",
  643. "text/plain": [
  644. "<Figure size 432x288 with 1 Axes>"
  645. ]
  646. },
  647. "metadata": {
  648. "needs_background": "light"
  649. },
  650. "output_type": "display_data"
  651. }
  652. ],
  653. "source": [
  654. "plt.plot(np.arange(len(eval_acces)), eval_acces)\n",
  655. "plt.title('test acc')"
  656. ]
  657. },
  658. {
  659. "cell_type": "markdown",
  660. "metadata": {},
  661. "source": [
  662. "We can see that our four-layer network reaches about 99.7% accuracy on the training set and about 98.3% accuracy on the test set."
  663. ]
  664. },
  665. {
  666. "cell_type": "markdown",
  667. "metadata": {},
  668. "source": [
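  669. "**Exercise: read through the training loop above and work out how the accuracy is computed, paying special attention to the max function.**\n",
  670. "\n",
  671. "**Implement a new network yourself, try changing the number of hidden layers and the activation function, and see what results you get.**"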
  672. ]
  673. }
  674. ],
  675. "metadata": {
  676. "kernelspec": {
  677. "display_name": "Python 3",
  678. "language": "python",
  679. "name": "python3"
  680. },
  681. "language_info": {
  682. "codemirror_mode": {
  683. "name": "ipython",
  684. "version": 3
  685. },
  686. "file_extension": ".py",
  687. "mimetype": "text/x-python",
  688. "name": "python",
  689. "nbconvert_exporter": "python",
  690. "pygments_lexer": "ipython3",
  691. "version": "3.6.9"
  692. }
  693. },
  694. "nbformat": 4,
  695. "nbformat_minor": 2
  696. }
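
The softmax and cross-entropy formulas in the notebook's markdown cells can be checked numerically. The following is a minimal sketch in plain Python, not the PyTorch implementation; the function names `softmax` and `cross_entropy` and the example scores are assumptions made for illustration. It subtracts the largest score before exponentiating, which is a common way to avoid overflow and is one reason a naive implementation is considered numerically unstable.

```python
import math

def softmax(logits):
    """Map raw scores z_1..z_k to probabilities that sum to 1."""
    # Subtracting the maximum leaves the result unchanged but avoids overflow in exp.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Cross-entropy between a one-hot target and softmax(logits)."""
    # With a one-hot p, -sum_x p(x) log q(x) reduces to -log q(target).
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

scores = [2.0, 1.0, 0.1]  # hypothetical network outputs for 3 classes
probs = softmax(scores)
print(probs, sum(probs))
print(cross_entropy(scores, 0))
```

Because the target distribution is one-hot, the loss for a single sample is just the negative log of the probability assigned to the correct class, so a confident correct prediction gives a loss near 0.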

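
The `data_tf` transform in the notebook scales pixels to [0, 1], shifts them to [-1, 1], and flattens the 28 x 28 image into 784 values. The sketch below reproduces that arithmetic with plain Python lists instead of numpy and torch; the function name `normalize_and_flatten` and the synthetic image are assumptions made for illustration.

```python
def normalize_and_flatten(image):
    """Scale 0..255 pixels to [-1, 1] and flatten the rows into one list."""
    flat = []
    for row in image:
        for pixel in row:
            x = pixel / 255               # scale to [0, 1]
            flat.append((x - 0.5) / 0.5)  # shift and scale to [-1, 1]
    return flat

# Hypothetical 28 x 28 image: all black except one white pixel.
image = [[0] * 28 for _ in range(28)]
image[5][12] = 255
vec = normalize_and_flatten(image)
print(len(vec), min(vec), max(vec))
```

Row-major flattening puts the pixel at row 5, column 12 at index 5 * 28 + 12 = 152 of the resulting vector.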
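Machine learning is increasingly applied in fields such as aircraft and robotics, with the goal of using computers to achieve human-like intelligence and thereby make equipment intelligent and unmanned. This course aims to guide students in mastering the basic knowledge, typical methods, and techniques of machine learning, to spark their interest in the subject through concrete application cases, and to encourage them to analyze and solve the problems and challenges facing aircraft and robots from an artificial intelligence perspective. The course covers Python programming basics, machine learning models, and the fundamentals and implementation of unsupervised learning, supervised learning, and deep learning, and teaches how to apply machine learning to practical problems, thereby improving students' overall abilities.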
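
In the notebook's training loop, `out.max(1)` returns the largest score in each row together with its index, and the index is taken as the predicted class. The sketch below mirrors that accuracy computation in plain Python; the helper names `argmax` and `batch_accuracy` and the example scores are assumptions made for illustration.

```python
def argmax(row):
    """Index of the largest score in one row of outputs."""
    return max(range(len(row)), key=lambda i: row[i])

def batch_accuracy(outputs, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    preds = [argmax(row) for row in outputs]
    num_correct = sum(int(p == y) for p, y in zip(preds, labels))
    return num_correct / len(labels)

# Hypothetical scores for a batch of 3 samples and 4 classes.
outputs = [
    [0.1, 2.3, -1.0, 0.4],
    [1.5, 0.2, 0.3, 0.1],
    [0.0, 0.1, 0.2, 3.0],
]
labels = [1, 2, 3]
acc = batch_accuracy(outputs, labels)
print(acc)
```

The notebook averages these per-batch accuracies over all batches. When the last batch is smaller than the others, that average can differ slightly from the accuracy computed over all samples at once.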