linear-regression-gradient-descend.ipynb 116 kB


  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
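  7. "# Linear Models and Gradient Descent\n",
  8. "This is the first lesson on neural networks. We will study a very simple model, linear regression, together with an optimization algorithm, gradient descent, which we use to fit the model. Linear regression is one of the simplest models in supervised learning, and gradient descent is the most widely used optimization algorithm in deep learning, so this is where our deep learning journey begins."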
  9. ]
  10. },
  11. {
  12. "cell_type": "markdown",
  13. "metadata": {},
  14. "source": [
  15. "\n"
  16. ]
  17. },
  18. {
  19. "cell_type": "markdown",
  20. "metadata": {},
  21. "source": [
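  22. "## Univariate Linear Regression\n",
  23. "The univariate linear model is very simple. Suppose we have inputs $x_i$ and targets $y_i$, where each $i$ indexes one data point, and we want to build the model\n",
  24. "\n",
  25. "$$\n",
  26. "\\hat{y}_i = w x_i + b\n",
  27. "$$\n",
  28. "\n",
  29. "Here $\\hat{y}_i$ is the prediction, and we want $\\hat{y}_i$ to fit the target $y_i$. Put simply, we look for the $w$ and $b$ that make the error as small as possible, that is, we minimize the mean squared error\n",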
  30. "\n",
  31. "$$\n",
  32. "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
  33. "$$"
  34. ]
  35. },
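To make the model and its error concrete, here is a minimal sketch in plain Python. The data points and the two candidate parameter pairs are made up for illustration, and the helper names `predict` and `mean_squared_error` are assumptions rather than part of the notebook.

```python
# Minimal sketch of the model y_hat = w * x + b and its mean squared error.
# The data and the candidate (w, b) pairs are illustrative assumptions.

def predict(xs, w, b):
    """Return the predictions w * x + b for every x in xs."""
    return [w * x + b for x in xs]

def mean_squared_error(y_hat, ys):
    """Return (1/n) * sum((y_hat_i - y_i)^2)."""
    n = len(ys)
    return sum((p - y) ** 2 for p, y in zip(y_hat, ys)) / n

xs = [1.0, 2.0, 3.0, 4.0]
ys = [3.0, 5.0, 7.0, 9.0]  # generated from y = 2x + 1

good = mean_squared_error(predict(xs, 2.0, 1.0), ys)  # parameters that fit exactly
bad = mean_squared_error(predict(xs, 1.0, 0.0), ys)   # parameters that fit poorly

print(good, bad)
```

The better the pair $(w, b)$, the smaller the error, which is why minimizing this quantity is a sensible way to choose the parameters.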
  36. {
  37. "cell_type": "markdown",
  38. "metadata": {},
  39. "source": [
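  40. "So how do we minimize this error?\n",
  41. "\n",
  42. "This is where **gradient descent** comes in. It is the first optimization algorithm we meet: very simple, yet very powerful and used heavily in deep learning. Let us start from a simple example to understand how gradient descent works."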
  43. ]
  44. },
  45. {
  46. "cell_type": "markdown",
  47. "metadata": {},
  48. "source": [
  49. "## Gradient Descent\n",
  50. "To understand gradient descent, we first need to be clear about what a gradient is, and then see how to use it to descend."
  51. ]
  52. },
  53. {
  54. "cell_type": "markdown",
  55. "metadata": {},
  56. "source": [
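  57. "### Gradient\n",
  58. "For a function of one variable, the gradient is simply the derivative. For a function of several variables, the gradient is the vector of its partial derivatives. For example, for a function $f(x, y)$ the gradient of $f$ is\n",
  59. "\n",
  60. "$$\n",
  61. "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n",
  62. "$$\n",
  63. "\n",
  64. "It is written grad f(x, y) or $\\nabla f(x, y)$. The gradient at a specific point $(x_0,\\ y_0)$ is $\\nabla f(x_0,\\ y_0)$.\n",
  65. "\n",
  66. "The figure below shows the gradient of $f(x) = x^2$ at $x = 1$\n",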
  67. "\n",
  68. "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarbuh2j3j30ba0b80sy.jpg)"
  69. ]
  70. },
  71. {
  72. "cell_type": "markdown",
  73. "metadata": {},
  74. "source": [
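  75. "What does the gradient mean? Geometrically, the gradient at a point gives the direction in which the function changes fastest. Specifically, for a function $f(x, y)$ at the point $(x_0, y_0)$, the function increases fastest along the direction of $\\nabla f(x_0,\\ y_0)$. Moving along the gradient therefore heads toward a maximum of the function as quickly as is locally possible, while moving against the gradient heads toward a minimum."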
  76. ]
  77. },
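The definition and its geometric meaning can be checked numerically. The sketch below approximates gradients with central finite differences in plain Python. The helper `numerical_gradient`, the step size `h`, and the two test functions are illustrative assumptions, not part of the notebook.

```python
# Minimal sketch: approximate a gradient with central finite differences.
# numerical_gradient, h, and the test functions are illustrative assumptions.

def numerical_gradient(f, point, h=1e-5):
    """Approximate the gradient of f at `point` (a list of floats)."""
    grad = []
    for i in range(len(point)):
        forward = list(point)
        backward = list(point)
        forward[i] += h
        backward[i] -= h
        grad.append((f(forward) - f(backward)) / (2 * h))
    return grad

def square(p):
    return p[0] ** 2                   # f(x) = x^2

def bowl(p):
    return p[0] ** 2 + 3 * p[1] ** 2   # f(x, y) = x^2 + 3 y^2

grad_square = numerical_gradient(square, [1.0])   # analytic gradient: [2]
grad_bowl = numerical_gradient(bowl, [1.0, 2.0])  # analytic gradient: [2, 12]

# A small step against the gradient lowers the function value.
step = 0.1
moved = [1.0 - step * grad_bowl[0], 2.0 - step * grad_bowl[1]]
print(grad_square, grad_bowl, bowl([1.0, 2.0]), bowl(moved))
```

The last line illustrates the point made above: stepping against the gradient reduces the function value, provided the step is small enough.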
  78. {
  79. "cell_type": "markdown",
  80. "metadata": {},
  81. "source": [
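  82. "### Gradient Descent\n",
  83. "Now that we understand the gradient, we can understand how gradient descent works. Above, we wanted to minimize the error, that is, to find the point where the error is smallest. Repeatedly moving against the gradient lets us approach that minimum.\n",
  84. "\n",
  85. "Here is an intuitive picture. Suppose we are standing somewhere on a large mountain and do not know the way down, so we decide to take it one step at a time. At each position we compute the gradient there and take one step in the negative gradient direction, which is the steepest way downhill from where we stand. Then we compute the gradient at the new position and again step in the steepest downhill direction. We continue step by step until we believe we have reached the foot of the mountain. Of course, walking this way we may never reach the foot of the mountain and may instead end up in a local low point.\n",
  86. "\n",
  87. "In our problem, this means repeatedly adjusting $w$ and $b$ against the gradient until we find a pair of values that makes the error as small as possible.\n",
  88. "\n",
  89. "At each update we have to decide how far to move. In the mountain example this is the length of each downhill step. This step size is called the learning rate and is written $\\eta$. The learning rate matters a great deal, and different values lead to different results: if it is too small, the descent is very slow; if it is too large, the iterates jump around noticeably. See the example below\n",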
  90. "\n",
  91. "![](https://ws2.sinaimg.cn/large/006tNc79ly1fmgn23lnzjg30980gogso.gif)\n",
  92. "\n",
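  93. "In the upper animation the learning rate is reasonable, while in the lower one it is too large, so the iterates keep jumping back and forth\n",
  94. "\n",
  95. "The final update rule is\n",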
  96. "\n",
  97. "$$\n",
  98. "w := w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n",
  99. "b := b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n",
  100. "$$\n",
  101. "\n",
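  102. "By iterating these updates with a suitable learning rate, we eventually arrive at a good pair $w$ and $b$. That is the principle of gradient descent.\n",
  103. "\n",
  104. "The following figure illustrates the method\n",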
  105. "\n",
  106. "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarxsltfqj30gx091gn4.jpg)"
  107. ]
  108. },
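The effect of the learning rate is easy to reproduce on a one-dimensional problem. The sketch below runs the update rule on $f(x) = x^2$, whose gradient is $2x$. The starting point, the number of steps, and the four learning rates are illustrative assumptions chosen to show the different regimes.

```python
# Minimal sketch of gradient descent on f(x) = x^2, whose gradient is 2x.
# The starting point, step count, and learning rates are illustrative assumptions.

def gradient_descent(x0, eta, steps):
    """Apply x := x - eta * f'(x) repeatedly and return every iterate."""
    xs = [x0]
    x = x0
    for _ in range(steps):
        grad = 2 * x
        x = x - eta * grad
        xs.append(x)
    return xs

slow = gradient_descent(x0=5.0, eta=0.01, steps=20)      # too small: barely moves
good = gradient_descent(x0=5.0, eta=0.3, steps=20)       # reasonable: converges quickly
jumpy = gradient_descent(x0=5.0, eta=0.95, steps=20)     # too large: overshoots and flips sign
diverging = gradient_descent(x0=5.0, eta=1.1, steps=20)  # far too large: moves away from 0

for name, path in [("slow", slow), ("good", good), ("jumpy", jumpy), ("diverging", diverging)]:
    print(name, path[-1])
```

For this particular function each step multiplies $x$ by $1 - 2\eta$, which explains all four behaviours: a factor close to 1 is slow, a negative factor makes the iterate jump across the minimum, and a factor larger than 1 in magnitude diverges.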
  109. {
  110. "cell_type": "markdown",
  111. "metadata": {},
  112. "source": [
  113. "\n"
  114. ]
  115. },
  116. {
  117. "cell_type": "markdown",
  118. "metadata": {},
  119. "source": [
  120. "That covers the theory. Next we work through an example to learn more about the linear model"
  121. ]
  122. },
  123. {
  124. "cell_type": "code",
  125. "execution_count": 1,
  126. "metadata": {},
  127. "outputs": [
  128. {
  129. "data": {
  130. "text/plain": [
  131. "<torch._C.Generator at 0x7fe7240c4cd0>"
  132. ]
  133. },
  134. "execution_count": 1,
  135. "metadata": {},
  136. "output_type": "execute_result"
  137. }
  138. ],
  139. "source": [
  140. "import torch\n",
  141. "import numpy as np\n",
  142. "from torch.autograd import Variable\n",
  143. "\n",
  144. "torch.manual_seed(2017)"
  145. ]
  146. },
  147. {
  148. "cell_type": "code",
  149. "execution_count": 8,
  150. "metadata": {},
  151. "outputs": [],
  152. "source": [
  153. "# Load the data x and y\n",
  154. "x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],\n",
  155. " [9.779], [6.182], [7.59], [2.167], [7.042],\n",
  156. " [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)\n",
  157. "\n",
  158. "y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],\n",
  159. " [3.366], [2.596], [2.53], [1.221], [2.827],\n",
  160. " [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)"
  161. ]
  162. },
  163. {
  164. "cell_type": "code",
  165. "execution_count": 9,
  166. "metadata": {},
  167. "outputs": [
  168. {
  169. "data": {
  170. "text/plain": [
  171. "[<matplotlib.lines.Line2D at 0x7fe6d149d0f0>]"
  172. ]
  173. },
  174. "execution_count": 9,
  175. "metadata": {},
  176. "output_type": "execute_result"
  177. },
  178. {
  179. "data": {
  180. "image/png": "\n",
  181. "text/plain": [
  182. "<Figure size 432x288 with 1 Axes>"
  183. ]
  184. },
  185. "metadata": {
  186. "needs_background": "light"
  187. },
  188. "output_type": "display_data"
  189. }
  190. ],
  191. "source": [
  192. "# Plot the data\n",
  193. "import matplotlib.pyplot as plt\n",
  194. "%matplotlib inline\n",
  195. "\n",
  196. "plt.plot(x_train, y_train, 'bo')"
  197. ]
  198. },
  199. {
  200. "cell_type": "code",
  201. "execution_count": 10,
  202. "metadata": {},
  203. "outputs": [],
  204. "source": [
  205. "# Convert to Tensors\n",
  206. "x_train = torch.from_numpy(x_train)\n",
  207. "y_train = torch.from_numpy(y_train)\n",
  208. "\n",
  209. "# Define the parameters w and b\n",
  210. "w = Variable(torch.randn(1), requires_grad=True) # random initialization\n",
  211. "b = Variable(torch.zeros(1), requires_grad=True) # initialize to 0"
  212. ]
  213. },
  214. {
  215. "cell_type": "code",
  216. "execution_count": 11,
  217. "metadata": {},
  218. "outputs": [],
  219. "source": [
  220. "# Build the linear regression model\n",
  221. "x_train = Variable(x_train)\n",
  222. "y_train = Variable(y_train)\n",
  223. "\n",
  224. "def linear_model(x):\n",
  225. "    return x * w + b"
  226. ]
  227. },
  228. {
  229. "cell_type": "code",
  230. "execution_count": 12,
  231. "metadata": {},
  232. "outputs": [],
  233. "source": [
  234. "y_ = linear_model(x_train)"
  235. ]
  236. },
  237. {
  238. "cell_type": "markdown",
  239. "metadata": {},
  240. "source": [
  241. "With the steps above the model is defined. Before updating any parameters, we can look at what the model currently outputs"
  242. ]
  243. },
  244. {
  245. "cell_type": "code",
  246. "execution_count": 13,
  247. "metadata": {},
  248. "outputs": [
  249. {
  250. "data": {
  251. "text/plain": [
  252. "<matplotlib.legend.Legend at 0x7fe6d1458128>"
  253. ]
  254. },
  255. "execution_count": 13,
  256. "metadata": {},
  257. "output_type": "execute_result"
  258. },
  259. {
  260. "data": {
  261. "image/png": "\n",
  262. "text/plain": [
  263. "<Figure size 432x288 with 1 Axes>"
  264. ]
  265. },
  266. "metadata": {
  267. "needs_background": "light"
  268. },
  269. "output_type": "display_data"
  270. }
  271. ],
  272. "source": [
  273. "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
  274. "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
  275. "plt.legend()"
  276. ]
  277. },
  278. {
  279. "cell_type": "markdown",
  280. "metadata": {},
  281. "source": [
  282. "**Question: the red points are the predictions, and they appear to lie along a straight line. Do they actually all lie on one straight line?**"
  283. ]
  284. },
  285. {
  286. "cell_type": "markdown",
  287. "metadata": {},
  288. "source": [
  289. "Next we compute the error function, that is\n",
  290. "\n",
  291. "$$\n",
  292. "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
  293. "$$"
  294. ]
  295. },
  296. {
  297. "cell_type": "code",
  298. "execution_count": 14,
  299. "metadata": {},
  300. "outputs": [],
  301. "source": [
  302. "# Compute the error\n",
  303. "def get_loss(y_, y):\n",
  304. "    return torch.mean((y_ - y) ** 2)\n",
  305. "\n",
  306. "loss = get_loss(y_, y_train)"
  307. ]
  308. },
  309. {
  310. "cell_type": "code",
  311. "execution_count": 15,
  312. "metadata": {},
  313. "outputs": [
  314. {
  315. "name": "stdout",
  316. "output_type": "stream",
  317. "text": [
  318. "tensor(94.9309, grad_fn=<MeanBackward1>)\n"
  319. ]
  320. }
  321. ],
  322. "source": [
  323. "# Print the loss to see its magnitude\n",
  324. "print(loss)"
  325. ]
  326. },
  327. {
  328. "cell_type": "markdown",
  329. "metadata": {},
  330. "source": [
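  331. "With the error function defined, we now need the gradients with respect to $w$ and $b$. Thanks to PyTorch's automatic differentiation we do not have to compute them by hand. Interested readers can derive them manually: writing the error as $f(w,\\ b)$, the gradients with respect to $w$ and $b$ are\n",
  332. "\n",
  333. "$$\n",
  334. "\\frac{\\partial f}{\\partial w} = \\frac{2}{n} \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n",
  335. "\\frac{\\partial f}{\\partial b} = \\frac{2}{n} \\sum_{i=1}^n (w x_i + b - y_i)\n",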
  336. "$$"
  337. ]
  338. },
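As a sanity check on the two gradient formulas above, the following sketch evaluates them in plain Python and compares the results with central finite differences of the loss. It uses only the first four data points of the notebook's data set, and the values of `w`, `b`, and `h` are illustrative assumptions.

```python
# Minimal sketch: evaluate the analytic gradients of the mean squared error
# and compare them with finite differences. The values of w, b, and h are
# illustrative assumptions.

def loss(w, b, xs, ys):
    """Mean squared error of the model w * x + b."""
    n = len(xs)
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n

def analytic_gradients(w, b, xs, ys):
    """Return (df/dw, df/db) computed from the closed-form expressions."""
    n = len(xs)
    dw = 2.0 / n * sum(x * (w * x + b - y) for x, y in zip(xs, ys))
    db = 2.0 / n * sum((w * x + b - y) for x, y in zip(xs, ys))
    return dw, db

xs = [3.3, 4.4, 5.5, 6.71]    # first four x values from the notebook's data
ys = [1.7, 2.76, 2.09, 3.19]  # first four y values from the notebook's data
w, b = 0.5, 0.0

dw, db = analytic_gradients(w, b, xs, ys)

h = 1e-6
dw_numeric = (loss(w + h, b, xs, ys) - loss(w - h, b, xs, ys)) / (2 * h)
db_numeric = (loss(w, b + h, xs, ys) - loss(w, b - h, xs, ys)) / (2 * h)

print(dw, dw_numeric)
print(db, db_numeric)
```

The analytic and numerical values should agree to several decimal places. Automatic differentiation in PyTorch is expected to return the same quantities for this loss.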
  339. {
  340. "cell_type": "code",
  341. "execution_count": 16,
  342. "metadata": {},
  343. "outputs": [],
  344. "source": [
  345. "# Automatic differentiation\n",
  346. "loss.backward()"
  347. ]
  348. },
  349. {
  350. "cell_type": "code",
  351. "execution_count": 17,
  352. "metadata": {},
  353. "outputs": [
  354. {
  355. "name": "stdout",
  356. "output_type": "stream",
  357. "text": [
  358. "tensor([-126.6150])\n",
  359. "tensor([-18.3376])\n"
  360. ]
  361. }
  362. ],
  363. "source": [
  364. "# Inspect the gradients of w and b\n",
  365. "print(w.grad)\n",
  366. "print(b.grad)"
  367. ]
  368. },
  369. {
  370. "cell_type": "code",
  371. "execution_count": 18,
  372. "metadata": {},
  373. "outputs": [],
  374. "source": [
  375. "# Update the parameters once\n",
  376. "w.data = w.data - 1e-2 * w.grad.data\n",
  377. "b.data = b.data - 1e-2 * b.grad.data"
  378. ]
  379. },
  380. {
  381. "cell_type": "markdown",
  382. "metadata": {},
  383. "source": [
  384. "After updating the parameters, we look at the model output again"
  385. ]
  386. },
  387. {
  388. "cell_type": "code",
  389. "execution_count": 19,
  390. "metadata": {},
  391. "outputs": [
  392. {
  393. "data": {
  394. "text/plain": [
  395. "<matplotlib.legend.Legend at 0x7fe6d14283c8>"
  396. ]
  397. },
  398. "execution_count": 19,
  399. "metadata": {},
  400. "output_type": "execute_result"
  401. },
  402. {
  403. "data": {
  404. "image/png": "\n",
  405. "text/plain": [
  406. "<Figure size 432x288 with 1 Axes>"
  407. ]
  408. },
  409. "metadata": {
  410. "needs_background": "light"
  411. },
  412. "output_type": "display_data"
  413. }
  414. ],
  415. "source": [
  416. "y_ = linear_model(x_train)\n",
  417. "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
  418. "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
  419. "plt.legend()"
  420. ]
  421. },
  422. {
  423. "cell_type": "markdown",
  424. "metadata": {},
  425. "source": [
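  426. "As the example shows, after the update the red predictions have moved below the blue ground-truth points and still do not fit them well, so we need to perform several more updates"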
  427. ]
  428. },
  429. {
  430. "cell_type": "code",
  431. "execution_count": 20,
  432. "metadata": {},
  433. "outputs": [
  434. {
  435. "name": "stdout",
  436. "output_type": "stream",
  437. "text": [
  438. "epoch: 0, loss: 1.9595526456832886\n",
  439. "epoch: 1, loss: 0.23876741528511047\n",
  440. "epoch: 2, loss: 0.20673297345638275\n",
  441. "epoch: 3, loss: 0.2059527039527893\n",
  442. "epoch: 4, loss: 0.20575186610221863\n",
  443. "epoch: 5, loss: 0.2055628001689911\n",
  444. "epoch: 6, loss: 0.20537473261356354\n",
  445. "epoch: 7, loss: 0.20518775284290314\n",
  446. "epoch: 8, loss: 0.20500165224075317\n",
  447. "epoch: 9, loss: 0.2048165202140808\n"
  448. ]
  449. },
  450. {
  451. "name": "stderr",
  452. "output_type": "stream",
  453. "text": [
  454. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:11: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n",
  455. " # This is added back by InteractiveShellApp.init_path()\n"
  456. ]
  457. }
  458. ],
  459. "source": [
  460. "for e in range(10): # perform 10 updates\n",
  461. "    y_ = linear_model(x_train)\n",
  462. "    loss = get_loss(y_, y_train)\n",
  463. "\n",
  464. "    w.grad.zero_() # remember to zero the gradients\n",
  465. "    b.grad.zero_() # remember to zero the gradients\n",
  466. "    loss.backward()\n",
  467. "\n",
  468. "    w.data = w.data - 1e-2 * w.grad.data # update w\n",
  469. "    b.data = b.data - 1e-2 * b.grad.data # update b\n",
  470. "    print('epoch: {}, loss: {}'.format(e, loss.item()))"
  471. ]
  472. },
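The loop above relies on autograd for the gradients. For comparison, here is a sketch of the same update loop with the gradients written out by hand, in plain Python. Starting from `w = 0` and `b = 0` instead of a random `w` is an assumption made to keep the run reproducible, so the printed losses will differ from the notebook's output.

```python
# Minimal sketch: the same update loop with hand-written gradients, in plain Python.
# Starting from w = 0 and b = 0 is an assumption made for reproducibility.

xs = [3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
      7.042, 10.791, 5.313, 7.997, 3.1]
ys = [1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
      2.827, 3.465, 1.65, 2.904, 1.3]

def mse(w, b):
    """Mean squared error of w * x + b on the data above."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

w, b = 0.0, 0.0
eta = 1e-2  # learning rate used in the notebook
losses = [mse(w, b)]

for epoch in range(10):
    n = len(xs)
    grad_w = 2.0 / n * sum(x * (w * x + b - y) for x, y in zip(xs, ys))
    grad_b = 2.0 / n * sum((w * x + b - y) for x, y in zip(xs, ys))
    w = w - eta * grad_w  # w := w - eta * df/dw
    b = b - eta * grad_b  # b := b - eta * df/db
    losses.append(mse(w, b))
    print('epoch: {}, loss: {:.4f}'.format(epoch, losses[-1]))
```

Changing `eta` or the number of iterations here is a quick way to experiment with how the learning rate affects convergence.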
  473. {
  474. "cell_type": "code",
  475. "execution_count": 21,
  476. "metadata": {},
  477. "outputs": [
  478. {
  479. "data": {
  480. "text/plain": [
  481. "<matplotlib.legend.Legend at 0x7fe6d163f6d8>"
  482. ]
  483. },
  484. "execution_count": 21,
  485. "metadata": {},
  486. "output_type": "execute_result"
  487. },
  488. {
  489. "data": {
  490. "image/png": "\n",
  491. "text/plain": [
  492. "<Figure size 432x288 with 1 Axes>"
  493. ]
  494. },
  495. "metadata": {
  496. "needs_background": "light"
  497. },
  498. "output_type": "display_data"
  499. }
  500. ],
  501. "source": [
  502. "y_ = linear_model(x_train)\n",
  503. "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
  504. "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
  505. "plt.legend()"
  506. ]
  507. },
  508. {
  509. "cell_type": "markdown",
  510. "metadata": {},
  511. "source": [
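  512. "After 10 updates, the red predictions fit the blue ground-truth values reasonably well.\n",
  513. "\n",
  514. "You have now learned your first machine learning model. Keep going and complete the short exercise below."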
  515. ]
  516. },
  517. {
  518. "cell_type": "markdown",
  519. "metadata": {},
  520. "source": [
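  521. "**Exercise:**\n",
  522. "\n",
  523. "Restart the notebook and run the linear regression model above again, trying different numbers of training iterations and different learning rates to see how the results change"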
  524. ]
  525. },
  526. {
  527. "cell_type": "markdown",
  528. "metadata": {},
  529. "source": [
  530. "## Polynomial Regression"
  531. ]
  532. },
  533. {
  534. "cell_type": "markdown",
  535. "metadata": {},
  536. "source": [
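  537. "Next we go one step further and discuss polynomial regression. What is polynomial regression? It is quite simple. Recall the linear regression model above\n",
  538. "\n",
  539. "$$\n",
  540. "\\hat{y} = w x + b\n",
  541. "$$\n",
  542. "\n",
  543. "This is a first-degree polynomial in $x$. Such a model is simple and cannot fit more complicated data, so we can use a higher-degree model, for example\n",
  544. "\n",
  545. "$$\n",
  546. "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 + \\cdots\n",
  547. "$$\n",
  548. "\n",
  549. "This can fit more complicated data and is called a polynomial model, because it uses higher powers of $x$. Multivariate regression has the same form, except that in addition to $x$ it uses more input variables, such as $y$, $z$, and so on. The loss function for all of these models is the same as for simple linear regression."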
  550. ]
  551. },
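One way to see why the loss stays the same is that a polynomial model is still linear in its parameters: it is a linear model applied to the features $x, x^2, x^3, \dots$. The sketch below shows this in plain Python. The helper names and the sample coefficients are illustrative assumptions.

```python
# Minimal sketch: a polynomial model is a linear model on the features x, x^2, x^3, ...
# The helper names and the sample coefficients are illustrative assumptions.

def polynomial_features(x, degree):
    """Return [x, x^2, ..., x^degree] for a single input x."""
    return [x ** k for k in range(1, degree + 1)]

def polynomial_model(x, weights, bias):
    """Evaluate bias + weights[0] * x + weights[1] * x^2 + ..."""
    features = polynomial_features(x, len(weights))
    return bias + sum(w * f for w, f in zip(weights, features))

weights = [1.0, -2.0, 0.5]  # w_1, w_2, w_3
bias = 4.0                  # w_0

features_at_2 = polynomial_features(2.0, 3)
value_at_2 = polynomial_model(2.0, weights, bias)
print(features_at_2, value_at_2)
```

Because the prediction is a weighted sum of fixed features, the gradient descent procedure from the linear case carries over unchanged once the feature matrix has been built.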
  552. {
  553. "cell_type": "markdown",
  554. "metadata": {},
  555. "source": [
  556. "\n"
  557. ]
  558. },
  559. {
  560. "cell_type": "markdown",
  561. "metadata": {},
  562. "source": [
  563. "First we define the target function to fit, which is a cubic polynomial"
  564. ]
  565. },
  566. {
  567. "cell_type": "code",
  568. "execution_count": 20,
  569. "metadata": {},
  570. "outputs": [
  571. {
  572. "name": "stdout",
  573. "output_type": "stream",
  574. "text": [
  575. "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n"
  576. ]
  577. }
  578. ],
  579. "source": [
  580. "# Define the target cubic polynomial\n",
  581. "\n",
  582. "w_target = np.array([0.5, 3, 2.4]) # coefficients of x, x^2, x^3\n",
  583. "b_target = np.array([0.9]) # constant term\n",
  584. "\n",
  585. "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n",
  586. "    b_target[0], w_target[0], w_target[1], w_target[2]) # format the function as a string\n",
  587. "\n",
  588. "print(f_des)"
  589. ]
  590. },
  591. {
  592. "cell_type": "markdown",
  593. "metadata": {},
  594. "source": [
  595. "We can first plot this polynomial"
  596. ]
  597. },
  598. {
  599. "cell_type": "code",
  600. "execution_count": 21,
  601. "metadata": {},
  602. "outputs": [
  603. {
  604. "data": {
  605. "text/plain": [
  606. "<matplotlib.legend.Legend at 0x7ff9c871af28>"
  607. ]
  608. },
  609. "execution_count": 21,
  610. "metadata": {},
  611. "output_type": "execute_result"
  612. },
  613. {
  614. "data": {
  615. "image/png": "\n",
  616. "text/plain": [
  617. "<Figure size 432x288 with 1 Axes>"
  618. ]
  619. },
  620. "metadata": {
  621. "needs_background": "light"
  622. },
  623. "output_type": "display_data"
  624. }
  625. ],
  626. "source": [
  627. "# Plot the curve of this function\n",
  628. "x_sample = np.arange(-3, 3.1, 0.1)\n",
  629. "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n",
  630. "\n",
  631. "plt.plot(x_sample, y_sample, label='real curve')\n",
  632. "plt.legend()"
  633. ]
  634. },
  635. {
  636. "cell_type": "markdown",
  637. "metadata": {},
  638. "source": [
  639. "Next we build the dataset. We need x and y, and because the target is a cubic polynomial, we use $x,\\ x^2,\\ x^3$ as the input features"
  640. ]
  641. },
  642. {
  643. "cell_type": "code",
  644. "execution_count": 23,
  645. "metadata": {},
  646. "outputs": [],
  647. "source": [
  648. "# Build the data x and y\n",
  649. "# x is a matrix with columns [x, x^2, x^3]\n",
  650. "# y is the function value [y]\n",
  651. "\n",
  652. "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n",
  653. "x_train = torch.from_numpy(x_train).float() # convert to a float tensor\n",
  654. "\n",
  655. "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # convert to a float tensor"
  656. ]
  657. },
  658. {
  659. "cell_type": "markdown",
  660. "metadata": {},
  661. "source": [
  662. "Next we define the parameters to optimize, namely the $w_i$ in the function above"
  663. ]
  664. },
  665. {
  666. "cell_type": "code",
  667. "execution_count": 25,
  668. "metadata": {},
  669. "outputs": [],
  670. "source": [
  671. "# Define the parameters and the model\n",
  672. "w = Variable(torch.randn(3, 1), requires_grad=True)\n",
  673. "b = Variable(torch.zeros(1), requires_grad=True)\n",
  674. "\n",
  675. "# Wrap x and y as Variables\n",
  676. "x_train = Variable(x_train)\n",
  677. "y_train = Variable(y_train)\n",
  678. "\n",
  679. "def multi_linear(x):\n",
  680. "    return torch.mm(x, w) + b"
  681. ]
  682. },
  683. {
  684. "cell_type": "markdown",
  685. "metadata": {},
  686. "source": [
  687. "We can plot the model before any update against the true model"
  688. ]
  689. },
  690. {
  691. "cell_type": "code",
  692. "execution_count": 26,
  693. "metadata": {},
  694. "outputs": [
  695. {
  696. "data": {
  697. "text/plain": [
  698. "<matplotlib.legend.Legend at 0x7ff9c867b7b8>"
  699. ]
  700. },
  701. "execution_count": 26,
  702. "metadata": {},
  703. "output_type": "execute_result"
  704. },
  705. {
  706. "data": {
  707. "image/png": "\n",
  708. "text/plain": [
  709. "<Figure size 432x288 with 1 Axes>"
  710. ]
  711. },
  712. "metadata": {
  713. "needs_background": "light"
  714. },
  715. "output_type": "display_data"
  716. }
  717. ],
  718. "source": [
  719. "# Plot the model before updating\n",
  720. "y_pred = multi_linear(x_train)\n",
  721. "\n",
  722. "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
  723. "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
  724. "plt.legend()"
  725. ]
  726. },
  727. {
  728. "cell_type": "markdown",
  729. "metadata": {},
  730. "source": [
  731. "The two curves clearly differ. Let us compute the error between them"
  732. ]
  733. },
  734. {
  735. "cell_type": "code",
  736. "execution_count": 27,
  737. "metadata": {},
  738. "outputs": [
  739. {
  740. "name": "stdout",
  741. "output_type": "stream",
  742. "text": [
  743. "tensor(509.5237, grad_fn=<MeanBackward1>)\n"
  744. ]
  745. }
  746. ],
  747. "source": [
  748. "# Compute the error; it is the same as for the univariate linear model, and get_loss was defined earlier\n",
  749. "loss = get_loss(y_pred, y_train)\n",
  750. "print(loss)"
  751. ]
  752. },
  753. {
  754. "cell_type": "code",
  755. "execution_count": 28,
  756. "metadata": {},
  757. "outputs": [],
  758. "source": [
  759. "# Automatic differentiation\n",
  760. "loss.backward()"
  761. ]
  762. },
  763. {
  764. "cell_type": "code",
  765. "execution_count": 29,
  766. "metadata": {},
  767. "outputs": [
  768. {
  769. "name": "stdout",
  770. "output_type": "stream",
  771. "text": [
  772. "tensor([[ -64.6688],\n",
  773. " [ -84.8521],\n",
  774. " [-431.2343]])\n",
  775. "tensor([-16.0116])\n"
  776. ]
  777. }
  778. ],
  779. "source": [
  780. "# Inspect the gradients of w and b\n",
  781. "print(w.grad)\n",
  782. "print(b.grad)"
  783. ]
  784. },
  785. {
  786. "cell_type": "code",
  787. "execution_count": 30,
  788. "metadata": {},
  789. "outputs": [],
  790. "source": [
  791. "# Update the parameters\n",
  792. "w.data = w.data - 0.001 * w.grad.data\n",
  793. "b.data = b.data - 0.001 * b.grad.data"
  794. ]
  795. },
  796. {
  797. "cell_type": "code",
  798. "execution_count": 31,
  799. "metadata": {},
  800. "outputs": [
  801. {
  802. "data": {
  803. "text/plain": [
  804. "<matplotlib.legend.Legend at 0x7ff9c8640320>"
  805. ]
  806. },
  807. "execution_count": 31,
  808. "metadata": {},
  809. "output_type": "execute_result"
  810. },
  811. {
  812. "data": {
  813. "image/png": "\n",
  814. "text/plain": [
  815. "<Figure size 432x288 with 1 Axes>"
  816. ]
  817. },
  818. "metadata": {
  819. "needs_background": "light"
  820. },
  821. "output_type": "display_data"
  822. }
  823. ],
  824. "source": [
  825. "# Plot the model after one update\n",
  826. "y_pred = multi_linear(x_train)\n",
  827. "\n",
  828. "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
  829. "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
  830. "plt.legend()"
  831. ]
  832. },
  833. {
  834. "cell_type": "markdown",
  835. "metadata": {},
  836. "source": [
  837. "Because we have updated only once, the two curves still differ. We now run 100 iterations"
  838. ]
  839. },
  840. {
  841. "cell_type": "code",
  842. "execution_count": 32,
  843. "metadata": {},
  844. "outputs": [
  845. {
  846. "name": "stdout",
  847. "output_type": "stream",
  848. "text": [
  849. "epoch 20, Loss: 24.61406\n",
  850. "epoch 40, Loss: 5.92470\n",
  851. "epoch 60, Loss: 1.55844\n",
  852. "epoch 80, Loss: 0.53303\n",
  853. "epoch 100, Loss: 0.28755\n"
  854. ]
  855. },
  856. {
  857. "name": "stderr",
  858. "output_type": "stream",
  859. "text": [
  860. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:14: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n",
  861. " \n"
  862. ]
  863. }
  864. ],
  865. "source": [
  866. "# Perform 100 parameter updates\n",
  867. "for e in range(100):\n",
  868. "    y_pred = multi_linear(x_train)\n",
  869. "    loss = get_loss(y_pred, y_train)\n",
  870. "\n",
  871. "    w.grad.data.zero_()\n",
  872. "    b.grad.data.zero_()\n",
  873. "    loss.backward()\n",
  874. "\n",
  875. "    # Update the parameters\n",
  876. "    w.data = w.data - 0.001 * w.grad.data\n",
  877. "    b.data = b.data - 0.001 * b.grad.data\n",
  878. "    if (e + 1) % 20 == 0:\n",
  879. "        print('epoch {}, Loss: {:.5f}'.format(e+1, loss.item()))"
  880. ]
  881. },
  882. {
  883. "cell_type": "markdown",
  884. "metadata": {},
  885. "source": [
  886. "After the updates the loss is much smaller. We plot the updated curve for comparison"
  887. ]
  888. },
  889. {
  890. "cell_type": "code",
  891. "execution_count": 33,
  892. "metadata": {},
  893. "outputs": [
  894. {
  895. "data": {
  896. "text/plain": [
  897. "<matplotlib.legend.Legend at 0x7ff9c8603e10>"
  898. ]
  899. },
  900. "execution_count": 33,
  901. "metadata": {},
  902. "output_type": "execute_result"
  903. },
  904. {
  905. "data": {
  906. "image/png": "\n",
  907. "text/plain": [
  908. "<Figure size 432x288 with 1 Axes>"
  909. ]
  910. },
  911. "metadata": {
  912. "needs_background": "light"
  913. },
  914. "output_type": "display_data"
  915. }
  916. ],
  917. "source": [
  918. "# Plot the result after the updates\n",
  919. "y_pred = multi_linear(x_train)\n",
  920. "\n",
  921. "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
  922. "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
  923. "plt.legend()"
  924. ]
  925. },
  926. {
  927. "cell_type": "markdown",
  928. "metadata": {},
  929. "source": [
  930. "After 100 updates, the fitted curve and the true curve almost coincide"
  931. ]
  932. },
  933. {
  934. "cell_type": "markdown",
  935. "metadata": {
  936. "collapsed": true
  937. },
  938. "source": [
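  939. "**Exercise: the example above is a cubic polynomial. Try fitting it with a quadratic polynomial and see how well you can do.**\n",
  940. "\n",
  941. "**Hint: use the parameter `w = torch.randn(2, 1)` and rebuild the x dataset accordingly.**"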
  942. ]
  943. }
  944. ],
  945. "metadata": {
  946. "kernelspec": {
  947. "display_name": "Python 3",
  948. "language": "python",
  949. "name": "python3"
  950. },
  951. "language_info": {
  952. "codemirror_mode": {
  953. "name": "ipython",
  954. "version": 3
  955. },
  956. "file_extension": ".py",
  957. "mimetype": "text/x-python",
  958. "name": "python",
  959. "nbconvert_exporter": "python",
  960. "pygments_lexer": "ipython3",
  961. "version": "3.5.2"
  962. }
  963. },
  964. "nbformat": 4,
  965. "nbformat_minor": 2
  966. }

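Machine learning is increasingly applied in fields such as aircraft and robotics. Its goal is to use computers to achieve human-like intelligence and thereby make equipment intelligent and unmanned. This course guides students to master the basic knowledge, typical methods, and techniques of machine learning, uses concrete application cases to spark interest in the subject, and encourages students to analyze and solve the problems and challenges facing aircraft and robots from the perspective of artificial intelligence. The course covers Python programming basics, machine learning models, and the fundamentals and implementation of unsupervised learning, supervised learning, and deep learning. It also teaches how to apply machine learning to practical problems, with the aim of improving students' overall abilities.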