
10-regularization.ipynb 7.5 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regularization\n",
    "Earlier we covered data augmentation and dropout. In practice, modern networks often skip dropout in favor of another technique: regularization.\n",
    "\n",
    "Regularization is a classic machine-learning method. It comes in L1 and L2 variants, with L2 regularization being the more common choice today. Introducing it amounts to adding an extra term to the loss function, for example\n",
    "\n",
    "$$\n",
    "f = loss + \\lambda \\sum_{p \\in params} ||p||_2^2\n",
    "$$\n",
    "\n",
    "That is, we add the squared L2 norm of the parameters to the loss as a regularizer. During training we then minimize not only the loss but also the norm of the parameters; in other words, we constrain the parameters so they cannot grow too large."
   ]
  },
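  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the formula above (using a toy linear layer rather than the network trained below; `lam` stands for $\\lambda$), the penalty can be computed by hand and added to the loss:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: add the squared L2 norm of all parameters to the loss by hand.\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "layer = nn.Linear(4, 2)\n",
    "x = torch.randn(8, 4)\n",
    "y = torch.randn(8, 2)\n",
    "loss = nn.functional.mse_loss(layer(x), y)\n",
    "\n",
    "lam = 1e-4  # the lambda in the formula above\n",
    "l2_penalty = sum(p.pow(2).sum() for p in layer.parameters())\n",
    "f = loss + lam * l2_penalty  # f = loss + lambda * sum ||p||_2^2\n",
    "f.backward()"
   ]
  },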
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Taking the gradient of the new objective $f$ for gradient descent gives\n",
    "\n",
    "$$\n",
    "\\frac{\\partial f}{\\partial p_j} = \\frac{\\partial loss}{\\partial p_j} + 2 \\lambda p_j\n",
    "$$\n",
    "\n",
    "so the parameter update becomes\n",
    "\n",
    "$$\n",
    "p_j \\rightarrow p_j - \\eta (\\frac{\\partial loss}{\\partial p_j} + 2 \\lambda p_j) = p_j - \\eta \\frac{\\partial loss}{\\partial p_j} - 2 \\eta \\lambda p_j\n",
    "$$\n"
   ]
  },
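  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick numeric check of this update rule (a standalone sketch, not part of the training code below), one SGD step with `weight_decay` can be compared against the hand-written update. Note that PyTorch adds `weight_decay * p` directly to the gradient, so `weight_decay` plays the role of $2\\lambda$ when the penalty is written as $\\lambda ||p||_2^2$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "p = torch.tensor([1.0, -2.0], requires_grad=True)\n",
    "eta, wd = 0.1, 1e-2\n",
    "opt = torch.optim.SGD([p], lr=eta, weight_decay=wd)\n",
    "\n",
    "loss = (p ** 2).sum()  # dloss/dp = 2p\n",
    "# hand-written update: p - eta * (dloss/dp + wd * p)\n",
    "expected = p.detach() - eta * (2 * p.detach() + wd * p.detach())\n",
    "loss.backward()\n",
    "opt.step()\n",
    "print(torch.allclose(p.detach(), expected))  # True"
   ]
  },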
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The term $p_j - \\eta \\frac{\\partial loss}{\\partial p_j}$ is exactly the update we would make without the regularizer, while the extra term $2\\eta \\lambda p_j$ is the regularizer's contribution: at every step it shrinks each parameter toward zero, which is why this technique is also called weight decay. This is exactly how PyTorch applies the regularizer; to use it with stochastic gradient descent, `torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)` is all that is needed. The `weight_decay` coefficient plays the role of the $\\lambda$ in the formula above (strictly, PyTorch adds `weight_decay * p` directly to the gradient, so it corresponds to $2\\lambda$), which is very convenient.\n",
    "\n",
    "Note that the size of the regularization coefficient matters a great deal. If it is too large, it strongly suppresses parameter updates and causes underfitting; if it is too small, the regularization term contributes almost nothing. Choosing a suitable weight-decay coefficient is therefore important and has to be tuned for each problem; `1e-4` or `1e-3` are reasonable starting points.\n",
    "\n",
    "Below we add the regularization term while training on CIFAR-10."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-24T08:02:11.903459Z",
     "start_time": "2017-12-24T08:02:11.383170Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "from torchvision.datasets import CIFAR10\n",
    "from utils import train, resnet\n",
    "from torchvision import transforms as tfs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-24T08:02:13.120502Z",
     "start_time": "2017-12-24T08:02:11.905617Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def data_tf(x):\n",
    "    im_aug = tfs.Compose([\n",
    "        tfs.Resize(96),\n",
    "        tfs.ToTensor(),\n",
    "        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
    "    ])\n",
    "    x = im_aug(x)\n",
    "    return x\n",
    "\n",
    "train_set = CIFAR10('./data', train=True, transform=data_tf)\n",
    "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)\n",
    "test_set = CIFAR10('./data', train=False, transform=data_tf)\n",
    "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)\n",
    "\n",
    "net = resnet(3, 10)\n",
    "optimizer = torch.optim.SGD(net.parameters(),\n",
    "                            lr=0.01, weight_decay=1e-4)  # add the regularization term (weight decay)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-24T08:11:36.106177Z",
     "start_time": "2017-12-24T08:02:13.122785Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train Loss: 1.429834, Train Acc: 0.476982, Valid Loss: 1.261334, Valid Acc: 0.546776, Time 00:00:26\n",
      "Epoch 1. Train Loss: 0.994539, Train Acc: 0.645400, Valid Loss: 1.310620, Valid Acc: 0.554688, Time 00:00:27\n",
      "Epoch 2. Train Loss: 0.788570, Train Acc: 0.723585, Valid Loss: 1.256101, Valid Acc: 0.577433, Time 00:00:28\n",
      "Epoch 3. Train Loss: 0.629832, Train Acc: 0.780411, Valid Loss: 1.222015, Valid Acc: 0.609474, Time 00:00:27\n",
      "Epoch 4. Train Loss: 0.500406, Train Acc: 0.825288, Valid Loss: 0.831702, Valid Acc: 0.720332, Time 00:00:27\n",
      "Epoch 5. Train Loss: 0.388376, Train Acc: 0.868646, Valid Loss: 0.829582, Valid Acc: 0.726760, Time 00:00:27\n",
      "Epoch 6. Train Loss: 0.291237, Train Acc: 0.902094, Valid Loss: 1.499777, Valid Acc: 0.623714, Time 00:00:28\n",
      "Epoch 7. Train Loss: 0.222401, Train Acc: 0.925072, Valid Loss: 1.832660, Valid Acc: 0.558643, Time 00:00:28\n",
      "Epoch 8. Train Loss: 0.157753, Train Acc: 0.947990, Valid Loss: 1.255313, Valid Acc: 0.668117, Time 00:00:28\n",
      "Epoch 9. Train Loss: 0.111407, Train Acc: 0.963595, Valid Loss: 1.004693, Valid Acc: 0.724782, Time 00:00:27\n",
      "Epoch 10. Train Loss: 0.084960, Train Acc: 0.972926, Valid Loss: 0.867961, Valid Acc: 0.775119, Time 00:00:27\n",
      "Epoch 11. Train Loss: 0.066854, Train Acc: 0.979280, Valid Loss: 1.011263, Valid Acc: 0.749604, Time 00:00:28\n",
      "Epoch 12. Train Loss: 0.048280, Train Acc: 0.985534, Valid Loss: 2.438345, Valid Acc: 0.576938, Time 00:00:27\n",
      "Epoch 13. Train Loss: 0.046176, Train Acc: 0.985614, Valid Loss: 1.008425, Valid Acc: 0.756527, Time 00:00:27\n",
      "Epoch 14. Train Loss: 0.039515, Train Acc: 0.988411, Valid Loss: 0.945017, Valid Acc: 0.766317, Time 00:00:27\n",
      "Epoch 15. Train Loss: 0.025882, Train Acc: 0.992667, Valid Loss: 0.918691, Valid Acc: 0.784217, Time 00:00:27\n",
      "Epoch 16. Train Loss: 0.018592, Train Acc: 0.994985, Valid Loss: 1.507427, Valid Acc: 0.680281, Time 00:00:27\n",
      "Epoch 17. Train Loss: 0.021062, Train Acc: 0.994246, Valid Loss: 2.976452, Valid Acc: 0.558940, Time 00:00:27\n",
      "Epoch 18. Train Loss: 0.021458, Train Acc: 0.993926, Valid Loss: 0.927871, Valid Acc: 0.785898, Time 00:00:27\n",
      "Epoch 19. Train Loss: 0.015656, Train Acc: 0.995824, Valid Loss: 0.962502, Valid Acc: 0.782832, Time 00:00:27\n"
     ]
    }
   ],
   "source": [
    "train(net, train_data, test_data, 20, optimizer, criterion)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
