
6-batch-normalization.ipynb

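{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Normalization\n",
    "Before we move on to building and training models, let's first discuss data preprocessing and batch normalization. Training a model is not easy, and very complex models in particular often fail to converge to a good solution. Preprocessing the data and using batch normalization can greatly improve convergence, and batch normalization is one of the main reasons convolutional networks can be trained with very many layers."
   ]
  },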
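  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data Preprocessing\n",
    "The two most common data preprocessing methods are `centering` and `standardization`.\n",
    "* **Centering** moves the center of the data to the origin. It is very simple to implement: subtract the corresponding mean from each feature dimension, which yields zero-mean features.\n",
    "* **Standardization** is also simple. Once the data have zero mean, we rescale each feature dimension so that all features have a comparable scale. We can either divide by the standard deviation, which gives every feature unit variance (roughly a standard normal distribution if the feature is approximately Gaussian), or use the minimum and maximum values to map the feature into the range -1 to 1.\n",
    "\n",
    "The figure below shows an example of this processing:\n",
    "\n",
    "![](images/data_normalize.png)"
   ]
  },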
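  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the two preprocessing steps above in PyTorch, assuming the data are stored as a tensor of shape `(n_samples, n_features)`. The toy data and the variable names (`X`, `X_centered`, `X_std`, `X_minmax`) are made up for illustration; all statistics are computed per feature, i.e. along `dim=0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Made-up data: 1000 samples, 3 features with very different offsets and scales\n",
    "X = torch.randn(1000, 3) * torch.tensor([1.0, 10.0, 100.0]) + torch.tensor([5.0, -20.0, 300.0])\n",
    "\n",
    "# Centering: subtract the mean of each feature\n",
    "X_centered = X - X.mean(dim=0, keepdim=True)\n",
    "\n",
    "# Standardization, option 1: divide by the standard deviation of each feature\n",
    "X_std = X_centered / X_centered.std(dim=0, keepdim=True)\n",
    "\n",
    "# Standardization, option 2: map each feature linearly onto [-1, 1] using its min and max\n",
    "x_min = X.min(dim=0, keepdim=True).values\n",
    "x_max = X.max(dim=0, keepdim=True).values\n",
    "X_minmax = 2 * (X - x_min) / (x_max - x_min) - 1\n",
    "\n",
    "print(X_std.mean(dim=0), X_std.std(dim=0))  # approximately 0 and 1 per feature\n",
    "print(X_minmax.min(dim=0).values, X_minmax.max(dim=0).values)  # -1 and 1 per feature"
   ]
  },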
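  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Batch Normalization\n",
    "During data preprocessing we try to make the input features uncorrelated and close to a standard normal distribution, which generally helps the model perform well. In a very deep network, however, the nonlinear layers make their outputs correlated, and those outputs no longer follow a standard $\\mathcal{N}(0, 1)$ distribution; their center may even drift away from zero. This makes training difficult, especially for deep models.\n",
    "\n",
    "Batch normalization, proposed by Sergey Ioffe and Christian Szegedy in a 2015 paper, addresses this problem. In short, it normalizes the output of each layer so that, within a mini-batch, every feature has zero mean and unit variance. The next layer therefore receives well-scaled inputs, which makes training easier and speeds up convergence."
   ]
  },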
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Batch normalization is very simple to implement. For a given mini-batch of data $B = \\{x_1, x_2, \\cdots, x_m\\}$, the algorithm computes:\n",
    "\n",
    "$$\n",
    "\\mu_B = \\frac{1}{m} \\sum_{i=1}^m x_i\n",
    "$$\n",
    "$$\n",
    "\\sigma^2_B = \\frac{1}{m} \\sum_{i=1}^m (x_i - \\mu_B)^2\n",
    "$$\n",
    "$$\n",
    "\\hat{x}_i = \\frac{x_i - \\mu_B}{\\sqrt{\\sigma^2_B + \\epsilon}}\n",
    "$$\n",
    "$$\n",
    "y_i = \\gamma \\hat{x}_i + \\beta\n",
    "$$"
   ]
  },
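  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* The first and second formulas compute the mean and variance of the data in the batch.\n",
    "* The third formula normalizes every data point in the batch; $\\epsilon$ is a small constant added for numerical stability, typically $10^{-5}$.\n",
    "* The last formula scales and shifts the normalized values to produce the output. $\\gamma$ and $\\beta$ are a learnable scale and shift: they are network parameters learned together with the other weights during training. Because they are learned, the network can even undo the normalization if that is optimal, e.g. $\\gamma = \\sqrt{\\sigma^2_B + \\epsilon}$ and $\\beta = \\mu_B$ give back $y_i = x_i$.\n",
    "\n",
    "Below we demonstrate the one-dimensional case, i.e. inputs of shape `(batch, features)` as they appear in fully connected networks. The mean and variance are computed separately for each feature, over the batch dimension."
   ]
  },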
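  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the four formulas, the sketch below normalizes a tiny batch of a single feature by hand and then checks that choosing $\\gamma = \\sqrt{\\sigma^2_B + \\epsilon}$ and $\\beta = \\mu_B$ maps $\\hat{x}_i$ back to $x_i$. The values in `x` are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "eps = 1e-5\n",
    "x = torch.tensor([2.0, 4.0, 6.0, 8.0])  # a made-up batch of m = 4 values of one feature\n",
    "\n",
    "mu = x.mean()  # formula 1: batch mean -> 5.0\n",
    "var = ((x - mu) ** 2).mean()  # formula 2: (biased) batch variance -> 5.0\n",
    "x_hat = (x - mu) / torch.sqrt(var + eps)  # formula 3: normalize\n",
    "print(mu.item(), var.item())\n",
    "print(x_hat)  # zero mean, (almost) unit variance\n",
    "\n",
    "# formula 4 with gamma = sqrt(var + eps) and beta = mu recovers the input\n",
    "gamma = torch.sqrt(var + eps)\n",
    "beta = mu\n",
    "y = gamma * x_hat + beta\n",
    "print(torch.allclose(y, x))  # True"
   ]
  },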
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-23T06:50:51.579067Z",
     "start_time": "2017-12-23T06:50:51.575693Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')  # so that utils.py in the parent directory can be imported\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-23T07:14:11.077807Z",
     "start_time": "2017-12-23T07:14:11.060849Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def simple_batch_norm_1d(x, gamma, beta):\n",
    "    eps = 1e-5\n",
    "    x_mean = torch.mean(x, dim=0, keepdim=True)  # keep the dimension so it broadcasts\n",
    "    x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)\n",
    "    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)\n",
    "    return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's check that the output is normalized for an arbitrary input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-23T07:14:20.610603Z",
     "start_time": "2017-12-23T07:14:20.597682Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before bn: \n",
      "\n",
      " 0 1 2\n",
      " 3 4 5\n",
      " 6 7 8\n",
      " 9 10 11\n",
      " 12 13 14\n",
      "[torch.FloatTensor of size 5x3]\n",
      "\n",
      "after bn: \n",
      "\n",
      "-1.4142 -1.4142 -1.4142\n",
      "-0.7071 -0.7071 -0.7071\n",
      " 0.0000 0.0000 0.0000\n",
      " 0.7071 0.7071 0.7071\n",
      " 1.4142 1.4142 1.4142\n",
      "[torch.FloatTensor of size 5x3]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "x = torch.arange(15, dtype=torch.float32).view(5, 3)  # float dtype: torch.mean needs floating point\n",
    "gamma = torch.ones(x.shape[1])\n",
    "beta = torch.zeros(x.shape[1])\n",
    "print('before bn: ')\n",
    "print(x)\n",
    "y = simple_batch_norm_1d(x, gamma, beta)\n",
    "print('after bn: ')\n",
    "print(y)"
   ]
  },
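  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an extra cross-check, the hand-written function can be compared with PyTorch's functional `torch.nn.functional.batch_norm` in training mode, which also normalizes with the biased batch variance. The function is re-defined here so the cell runs on its own; the random input and the values of `gamma` and `beta` are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def simple_batch_norm_1d(x, gamma, beta):\n",
    "    eps = 1e-5\n",
    "    x_mean = torch.mean(x, dim=0, keepdim=True)\n",
    "    x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)\n",
    "    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)\n",
    "    return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x = torch.randn(8, 3) * 5 + 2  # arbitrary batch: 8 samples, 3 features\n",
    "gamma = torch.rand(3)\n",
    "beta = torch.randn(3)\n",
    "\n",
    "ours = simple_batch_norm_1d(x, gamma, beta)\n",
    "# running_mean / running_var may be None when training=True\n",
    "ref = F.batch_norm(x, None, None, weight=gamma, bias=beta, training=True, eps=1e-5)\n",
    "print(torch.allclose(ours, ref, atol=1e-5))"
   ]
  },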
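  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here there are 5 data points with 3 features, and each column holds the values of one feature across the data points. After batch normalization, every column has zero mean and unit variance.\n",
    "\n",
    "This raises a question: should batch normalization also be used at test time?\n",
    "\n",
    "Yes. The network was trained with batch normalization, so dropping it at test time would change the inputs of every later layer and bias the results. But if a test batch contains only a single sample, its mean is just that sample and its variance is 0, so the normalized output carries no information; more generally, the output would depend on which other samples happen to share the test batch. Therefore **at test time the mean and variance are not computed from the test data; instead, the moving averages of the mean and variance accumulated during training are used**.\n",
    "\n",
    "Next, let's implement a batch normalization function that distinguishes between training mode and test mode."
   ]
  },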
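  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sketch of the problem described above. Normalizing a single sample with its own statistics always yields zeros. PyTorch's built-in `nn.BatchNorm1d` (used here instead of the hand-written functions) refuses a batch of size 1 in training mode, while in evaluation mode (`bn.eval()`) it normalizes with its running statistics instead. The input values are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "x = torch.tensor([[1.5, -2.0, 7.0]])  # a made-up batch with a single sample and 3 features\n",
    "\n",
    "# Normalizing with the sample's own statistics: the mean equals x and the variance is 0\n",
    "eps = 1e-5\n",
    "mean = x.mean(dim=0, keepdim=True)\n",
    "var = ((x - mean) ** 2).mean(dim=0, keepdim=True)\n",
    "print((x - mean) / torch.sqrt(var + eps))  # all zeros, whatever x is\n",
    "\n",
    "bn = nn.BatchNorm1d(3)\n",
    "bn.train()\n",
    "try:\n",
    "    bn(x)\n",
    "except ValueError as err:  # training mode needs more than one value per channel\n",
    "    print('train mode:', err)\n",
    "\n",
    "bn.eval()  # normalize with running_mean / running_var instead\n",
    "print('eval mode:', bn(x))  # running stats start at mean 0 and variance 1, so this is close to x"
   ]
  },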
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-23T07:32:48.025709Z",
     "start_time": "2017-12-23T07:32:48.005892Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def batch_norm_1d(x, gamma, beta,\n",
    "                  is_training,\n",
    "                  moving_mean, moving_var, moving_momentum=0.1):\n",
    "    eps = 1e-5\n",
    "    x_mean = torch.mean(x, dim=0, keepdim=True)  # keep the dimension so it broadcasts\n",
    "    x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)\n",
    "    if is_training:\n",
    "        x_hat = (x - x_mean) / torch.sqrt(x_var + eps)\n",
    "        # update the moving averages in place (moving_momentum is the weight of the old value);\n",
    "        # detach() keeps the running statistics out of the autograd graph\n",
    "        moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean.detach()\n",
    "        moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var.detach()\n",
    "    else:\n",
    "        # test mode: use the moving averages accumulated during training\n",
    "        x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)\n",
    "    return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)"
   ]
  },
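  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on conventions: in `batch_norm_1d` above, `moving_momentum` is the weight of the *old* moving value, whereas PyTorch's built-in `nn.BatchNorm1d`/`nn.BatchNorm2d` use `momentum` as the weight of the *new* batch statistic, i.e. `running = (1 - momentum) * running + momentum * batch`. The sketch below, assuming PyTorch's convention, replays the running-mean update by hand next to `nn.BatchNorm1d` on made-up data whose features are drawn from $N(3, 2^2)$; both running means should end up close to 3 and agree with each other."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "def update_running_mean(running_mean, batch, momentum=0.1):\n",
    "    # PyTorch convention: momentum is the weight of the new batch statistic\n",
    "    return (1 - momentum) * running_mean + momentum * batch.mean(dim=0)\n",
    "\n",
    "bn = nn.BatchNorm1d(4, momentum=0.1)\n",
    "bn.train()\n",
    "running_mean = torch.zeros(4)  # same starting value as bn.running_mean\n",
    "for _ in range(200):\n",
    "    batch = 3.0 + 2.0 * torch.randn(256, 4)  # made-up data: every feature ~ N(3, 2^2)\n",
    "    with torch.no_grad():\n",
    "        bn(batch)  # in training mode this also updates bn.running_mean in place\n",
    "    running_mean = update_running_mean(running_mean, batch)\n",
    "\n",
    "print(running_mean)\n",
    "print(bn.running_mean)  # agrees with the manual update up to float rounding"
   ]
  },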
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's reuse the example from the previous lesson, a deep neural network that classifies the MNIST dataset, to test whether batch normalization helps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from torchvision.datasets import mnist  # torchvision's built-in MNIST dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def data_tf(x):\n",
    "    x = np.array(x, dtype='float32') / 255\n",
    "    x = (x - 0.5) / 0.5  # preprocessing: scale pixel values to [-1, 1]\n",
    "    x = x.reshape((-1,))  # flatten to a 784-dimensional vector\n",
    "    x = torch.from_numpy(x)\n",
    "    return x\n",
    "\n",
    "# Download MNIST with the built-in class (if needed) and apply the transform defined above\n",
    "train_set = mnist.MNIST('../../data/mnist', train=True,\n",
    "                        transform=data_tf, download=True)\n",
    "test_set = mnist.MNIST('../../data/mnist', train=False,\n",
    "                       transform=data_tf, download=True)\n",
    "train_data = DataLoader(train_set, batch_size=64, shuffle=True)\n",
    "test_data = DataLoader(test_set, batch_size=128, shuffle=False)"
   ]
  },
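  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to see what `data_tf` does without downloading anything: feed it a made-up 28×28 `uint8` array standing in for an MNIST image (`np.array` accepts a NumPy array as well as a PIL image). Pixel value 0 maps to -1, 255 maps to 1, and the image is flattened to 784 values; this is the min/max scaling to [-1, 1] from section 1, using the known pixel range 0 to 255. The function is re-defined so the cell runs on its own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "def data_tf(x):\n",
    "    x = np.array(x, dtype='float32') / 255\n",
    "    x = (x - 0.5) / 0.5  # [0, 1] -> [-1, 1]\n",
    "    x = x.reshape((-1,))  # flatten\n",
    "    x = torch.from_numpy(x)\n",
    "    return x\n",
    "\n",
    "fake_img = np.zeros((28, 28), dtype=np.uint8)  # hypothetical stand-in for an MNIST image\n",
    "fake_img[:14] = 255  # top half white, bottom half black\n",
    "out = data_tf(fake_img)\n",
    "print(out.shape, out.min().item(), out.max().item())  # torch.Size([784]) -1.0 1.0"
   ]
  },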
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class multi_network(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(multi_network, self).__init__()\n",
    "        self.layer1 = nn.Linear(784, 100)\n",
    "        self.relu = nn.ReLU(True)\n",
    "        self.layer2 = nn.Linear(100, 10)\n",
    "\n",
    "        # learnable scale and shift, initialized from a random Gaussian\n",
    "        self.gamma = nn.Parameter(torch.randn(100))\n",
    "        self.beta = nn.Parameter(torch.randn(100))\n",
    "\n",
    "        # moving averages: registered as buffers so they are saved in state_dict()\n",
    "        # and moved by .to()/.cuda(), but never updated by the optimizer\n",
    "        self.register_buffer('moving_mean', torch.zeros(100))\n",
    "        self.register_buffer('moving_var', torch.zeros(100))\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.layer1(x)\n",
    "        # self.training is set by net.train() / net.eval(), so the moving averages\n",
    "        # are used whenever the network is in evaluation mode\n",
    "        x = batch_norm_1d(x, self.gamma, self.beta,\n",
    "                          self.training, self.moving_mean, self.moving_var)\n",
    "        x = self.relu(x)\n",
    "        x = self.layer2(x)\n",
    "        return x"
   ]
  },
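  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, roughly the same network can be built with PyTorch's built-in `nn.BatchNorm1d` layer, which stores $\\gamma$ and $\\beta$ as `weight` and `bias` together with its running statistics, and switches between batch statistics and running statistics according to `train()`/`eval()`. This sketch only checks shapes and the train/eval switch on a random stand-in batch; note that the built-in layer initializes $\\gamma$ to 1 and $\\beta$ to 0 rather than randomly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "bn_net = nn.Sequential(\n",
    "    nn.Linear(784, 100),\n",
    "    nn.BatchNorm1d(100),  # gamma = weight (init 1), beta = bias (init 0)\n",
    "    nn.ReLU(True),\n",
    "    nn.Linear(100, 10),\n",
    ")\n",
    "\n",
    "x = torch.randn(64, 784)  # a random stand-in for one MNIST batch\n",
    "bn_net.train()\n",
    "out_train = bn_net(x)  # normalizes with this batch's statistics\n",
    "bn_net.eval()\n",
    "out_eval = bn_net(x)  # normalizes with the running statistics\n",
    "print(out_train.shape, out_eval.shape)  # torch.Size([64, 10]) torch.Size([64, 10])\n",
    "print(torch.allclose(out_train, out_eval))"
   ]
  },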
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = multi_network()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# define the loss function\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), 1e-1)  # stochastic gradient descent, learning rate 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For convenience, the training function `train` is defined in the separate file `utils.py`; it performs the same steps as the training loops written earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train Loss: 0.308139, Train Acc: 0.912797, Valid Loss: 0.181375, Valid Acc: 0.948279, Time 00:00:07\n",
      "Epoch 1. Train Loss: 0.174049, Train Acc: 0.949910, Valid Loss: 0.143940, Valid Acc: 0.958267, Time 00:00:09\n",
      "Epoch 2. Train Loss: 0.134983, Train Acc: 0.961587, Valid Loss: 0.122489, Valid Acc: 0.963904, Time 00:00:08\n",
      "Epoch 3. Train Loss: 0.111758, Train Acc: 0.968317, Valid Loss: 0.106595, Valid Acc: 0.966278, Time 00:00:09\n",
      "Epoch 4. Train Loss: 0.096425, Train Acc: 0.971915, Valid Loss: 0.108423, Valid Acc: 0.967563, Time 00:00:10\n",
      "Epoch 5. Train Loss: 0.084424, Train Acc: 0.974464, Valid Loss: 0.107135, Valid Acc: 0.969838, Time 00:00:09\n",
      "Epoch 6. Train Loss: 0.076206, Train Acc: 0.977645, Valid Loss: 0.092725, Valid Acc: 0.971420, Time 00:00:09\n",
      "Epoch 7. Train Loss: 0.069438, Train Acc: 0.979661, Valid Loss: 0.091497, Valid Acc: 0.971519, Time 00:00:09\n",
      "Epoch 8. Train Loss: 0.062908, Train Acc: 0.980810, Valid Loss: 0.088797, Valid Acc: 0.972903, Time 00:00:08\n",
      "Epoch 9. Train Loss: 0.058186, Train Acc: 0.982309, Valid Loss: 0.090830, Valid Acc: 0.972310, Time 00:00:08\n"
     ]
    }
   ],
   "source": [
    "from utils import train\n",
    "train(net, train_data, test_data, 10, optimizer, criterion)"
   ]
  },
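  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here $\\gamma$ and $\\beta$ are trained as network parameters and initialized from a random Gaussian distribution. `moving_mean` and `moving_var` are both initialized to 0; they change during training, but they are not parameters updated by gradient descent. After training for 10 epochs, let's look at the values the moving mean has been updated to."
   ]
  },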
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      " 0.5505\n",
      " 2.0835\n",
      " 0.0794\n",
      "-0.1991\n",
      "-0.9822\n",
      "-0.5820\n",
      " 0.6991\n",
      "-0.1292\n",
      " 2.9608\n",
      " 1.0826\n",
      "[torch.FloatTensor of size 10]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# print the first 10 entries of moving_mean\n",
    "print(net.moving_mean[:10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, these values were updated during training. At test time we no longer need to compute the mean and variance; we simply use the moving mean and moving variance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, let's look at the results without batch normalization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train Loss: 0.402263, Train Acc: 0.873817, Valid Loss: 0.220468, Valid Acc: 0.932852, Time 00:00:07\n",
      "Epoch 1. Train Loss: 0.181916, Train Acc: 0.945379, Valid Loss: 0.162440, Valid Acc: 0.953817, Time 00:00:08\n",
      "Epoch 2. Train Loss: 0.136073, Train Acc: 0.958522, Valid Loss: 0.264888, Valid Acc: 0.918216, Time 00:00:08\n",
      "Epoch 3. Train Loss: 0.111658, Train Acc: 0.966551, Valid Loss: 0.149704, Valid Acc: 0.950752, Time 00:00:08\n",
      "Epoch 4. Train Loss: 0.096433, Train Acc: 0.970732, Valid Loss: 0.116364, Valid Acc: 0.963311, Time 00:00:07\n",
      "Epoch 5. Train Loss: 0.083800, Train Acc: 0.973914, Valid Loss: 0.105775, Valid Acc: 0.968058, Time 00:00:08\n",
      "Epoch 6. Train Loss: 0.074534, Train Acc: 0.977129, Valid Loss: 0.094511, Valid Acc: 0.970728, Time 00:00:08\n",
      "Epoch 7. Train Loss: 0.067365, Train Acc: 0.979311, Valid Loss: 0.130495, Valid Acc: 0.960146, Time 00:00:09\n",
      "Epoch 8. Train Loss: 0.061585, Train Acc: 0.980894, Valid Loss: 0.089632, Valid Acc: 0.974090, Time 00:00:08\n",
      "Epoch 9. Train Loss: 0.055352, Train Acc: 0.982892, Valid Loss: 0.091508, Valid Acc: 0.970431, Time 00:00:08\n"
     ]
    }
   ],
   "source": [
    "no_bn_net = nn.Sequential(\n",
    "    nn.Linear(784, 100),\n",
    "    nn.ReLU(True),\n",
    "    nn.Linear(100, 10)\n",
    ")\n",
    "\n",
    "optimizer = torch.optim.SGD(no_bn_net.parameters(), 1e-1)  # stochastic gradient descent, learning rate 0.1\n",
    "train(no_bn_net, train_data, test_data, 10, optimizer, criterion)"
   ]
  },
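  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although both runs end up with about the same final accuracy, the early epochs show that the network with batch normalization converges faster: after the first epoch its validation accuracy is 0.948 versus 0.933 without it, and it then rises almost monotonically, whereas without batch normalization the validation accuracy drops noticeably in epochs 2 and 7. Because this is only a small network, it converges with or without batch normalization; for deeper networks, batch normalization usually makes a much bigger difference to how quickly training converges."
   ]
  },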
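  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The convergence comparison can be made concrete with a little arithmetic on the validation accuracies copied from the two training logs above (no new training is run): the first-epoch accuracy, the final accuracy, and the largest drop between consecutive epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation accuracies copied from the two training logs above (epochs 0 to 9)\n",
    "with_bn = [0.948279, 0.958267, 0.963904, 0.966278, 0.967563,\n",
    "           0.969838, 0.971420, 0.971519, 0.972903, 0.972310]\n",
    "without_bn = [0.932852, 0.953817, 0.918216, 0.950752, 0.963311,\n",
    "              0.968058, 0.970728, 0.960146, 0.974090, 0.970431]\n",
    "\n",
    "def largest_drop(accs):\n",
    "    # biggest decrease in validation accuracy from one epoch to the next\n",
    "    return max(prev - cur for prev, cur in zip(accs, accs[1:]))\n",
    "\n",
    "for name, accs in [('with BN', with_bn), ('without BN', without_bn)]:\n",
    "    print(f'{name}: epoch 0 = {accs[0]:.4f}, final = {accs[-1]:.4f}, '\n",
    "          f'largest drop = {largest_drop(accs):.4f}')\n",
    "# with BN: epoch 0 = 0.9483, final = 0.9723, largest drop = 0.0006\n",
    "# without BN: epoch 0 = 0.9329, final = 0.9704, largest drop = 0.0356"
   ]
  },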
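  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So far we have implemented batch normalization ourselves for 2-D inputs of shape `(batch, features)`. The 4-D case used in convolutional networks, with inputs of shape `(batch, channels, height, width)`, works the same way, except that the mean and variance are computed for each channel, over the batch and both spatial dimensions. Implementing batch normalization by hand is tedious, though, and PyTorch of course provides built-in layers: `torch.nn.BatchNorm1d()` for the one-dimensional case and `torch.nn.BatchNorm2d()` for the two-dimensional (image) case. Like our implementation, PyTorch learns $\\gamma$ and $\\beta$ as parameters (named `weight` and `bias`), while the running mean and variance (`running_mean` and `running_var`) are not trained by gradient descent: they are stored as buffers and updated with a moving average during training."
   ]
  },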
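  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the 4-D case, the sketch below normalizes a random `(N, C, H, W)` tensor by hand, with one mean and one variance per channel computed over dimensions 0, 2 and 3, and compares the result with `nn.BatchNorm2d` in training mode. It also lists which tensors the built-in layer registers as parameters and which as buffers. The tensor sizes are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x = torch.randn(8, 3, 5, 5) * 4 + 1  # arbitrary (N, C, H, W) input\n",
    "\n",
    "# One mean and one variance per channel, over the batch and both spatial dimensions\n",
    "mean = x.mean(dim=(0, 2, 3), keepdim=True)\n",
    "var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)\n",
    "manual = (x - mean) / torch.sqrt(var + 1e-5)  # gamma = 1, beta = 0\n",
    "\n",
    "bn = nn.BatchNorm2d(3)  # defaults: eps=1e-5, weight initialized to 1, bias to 0\n",
    "bn.train()\n",
    "with torch.no_grad():\n",
    "    builtin = bn(x)\n",
    "print(torch.allclose(manual, builtin, atol=1e-5))\n",
    "\n",
    "print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']\n",
    "print([name for name, _ in bn.named_buffers()])  # ['running_mean', 'running_var', 'num_batches_tracked']"
   ]
  },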
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's try batch normalization in a convolutional network and see how it performs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def data_tf(x):\n",
    "    x = np.array(x, dtype='float32') / 255\n",
    "    x = (x - 0.5) / 0.5  # preprocessing: scale pixel values to [-1, 1]\n",
    "    x = torch.from_numpy(x)\n",
    "    x = x.unsqueeze(0)  # add a channel dimension: (1, 28, 28)\n",
    "    return x\n",
    "\n",
    "# reload the datasets with the new transform\n",
    "train_set = mnist.MNIST('../../data/mnist', train=True, transform=data_tf, download=True)\n",
    "test_set = mnist.MNIST('../../data/mnist', train=False, transform=data_tf, download=True)\n",
    "train_data = DataLoader(train_set, batch_size=64, shuffle=True)\n",
    "test_data = DataLoader(test_set, batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# With batch normalization\n",
    "class conv_bn_net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(conv_bn_net, self).__init__()\n",
    "        self.stage1 = nn.Sequential(\n",
    "            nn.Conv2d(1, 6, 3, padding=1),\n",
    "            nn.BatchNorm2d(6),\n",
    "            nn.ReLU(True),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            nn.Conv2d(6, 16, 5),\n",
    "            nn.BatchNorm2d(16),\n",
    "            nn.ReLU(True),\n",
    "            nn.MaxPool2d(2, 2)\n",
    "        )\n",
    "\n",
    "        self.classify = nn.Linear(400, 10)  # 16 channels x 5 x 5 feature map\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.stage1(x)\n",
    "        x = x.view(x.shape[0], -1)\n",
    "        x = self.classify(x)\n",
    "        return x\n",
    "\n",
    "net = conv_bn_net()\n",
    "optimizer = torch.optim.SGD(net.parameters(), 1e-1)  # stochastic gradient descent, learning rate 0.1"
   ]
  },
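  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input size of the final `nn.Linear(400, 10)` follows from the shapes in `stage1`: a 28×28 image stays 28×28 after the padded 3×3 convolution, becomes 14×14 after the first pooling, 10×10 after the unpadded 5×5 convolution, and 5×5 after the second pooling, so the flattened feature map has 16 × 5 × 5 = 400 values. The sketch below traces these shapes on a random batch, assuming the same layer configuration as `conv_bn_net`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "stage1 = nn.Sequential(\n",
    "    nn.Conv2d(1, 6, 3, padding=1),\n",
    "    nn.BatchNorm2d(6),\n",
    "    nn.ReLU(True),\n",
    "    nn.MaxPool2d(2, 2),\n",
    "    nn.Conv2d(6, 16, 5),\n",
    "    nn.BatchNorm2d(16),\n",
    "    nn.ReLU(True),\n",
    "    nn.MaxPool2d(2, 2),\n",
    ")\n",
    "\n",
    "x = torch.randn(2, 1, 28, 28)  # a random batch of two 1x28x28 images\n",
    "for layer in stage1:\n",
    "    x = layer(x)\n",
    "    print(type(layer).__name__, tuple(x.shape))  # shape after each layer\n",
    "\n",
    "print(x.view(x.shape[0], -1).shape)  # torch.Size([2, 400])"
   ]
  },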
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train Loss: 0.160329, Train Acc: 0.952842, Valid Loss: 0.063328, Valid Acc: 0.978441, Time 00:00:33\n",
      "Epoch 1. Train Loss: 0.067862, Train Acc: 0.979361, Valid Loss: 0.068229, Valid Acc: 0.979430, Time 00:00:37\n",
      "Epoch 2. Train Loss: 0.051867, Train Acc: 0.984625, Valid Loss: 0.044616, Valid Acc: 0.985265, Time 00:00:37\n",
      "Epoch 3. Train Loss: 0.044797, Train Acc: 0.986141, Valid Loss: 0.042711, Valid Acc: 0.986056, Time 00:00:38\n",
      "Epoch 4. Train Loss: 0.039876, Train Acc: 0.987690, Valid Loss: 0.042499, Valid Acc: 0.985067, Time 00:00:41\n"
     ]
    }
   ],
   "source": [
    "train(net, train_data, test_data, 5, optimizer, criterion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Without batch normalization\n",
    "class conv_no_bn_net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(conv_no_bn_net, self).__init__()\n",
    "        self.stage1 = nn.Sequential(\n",
    "            nn.Conv2d(1, 6, 3, padding=1),\n",
    "            nn.ReLU(True),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            nn.Conv2d(6, 16, 5),\n",
    "            nn.ReLU(True),\n",
    "            nn.MaxPool2d(2, 2)\n",
    "        )\n",
    "\n",
    "        self.classify = nn.Linear(400, 10)  # 16 channels x 5 x 5 feature map\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.stage1(x)\n",
    "        x = x.view(x.shape[0], -1)\n",
    "        x = self.classify(x)\n",
    "        return x\n",
    "\n",
    "net = conv_no_bn_net()\n",
    "optimizer = torch.optim.SGD(net.parameters(), 1e-1)  # stochastic gradient descent, learning rate 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train Loss: 0.211075, Train Acc: 0.935934, Valid Loss: 0.062950, Valid Acc: 0.980123, Time 00:00:27\n",
      "Epoch 1. Train Loss: 0.066763, Train Acc: 0.978778, Valid Loss: 0.050143, Valid Acc: 0.984375, Time 00:00:29\n",
      "Epoch 2. Train Loss: 0.050870, Train Acc: 0.984292, Valid Loss: 0.039761, Valid Acc: 0.988034, Time 00:00:29\n",
      "Epoch 3. Train Loss: 0.041476, Train Acc: 0.986924, Valid Loss: 0.041925, Valid Acc: 0.986155, Time 00:00:29\n",
      "Epoch 4. Train Loss: 0.036118, Train Acc: 0.988523, Valid Loss: 0.042703, Valid Acc: 0.986452, Time 00:00:29\n"
     ]
    }
   ],
   "source": [
    "train(net, train_data, test_data, 5, optimizer, criterion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When we study some well-known network architectures later, the importance of batch normalization will gradually become clear. With PyTorch, adding a batch normalization layer is very easy."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
