3-linear-regression-gradient-descend.ipynb 109 kB


  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Linear Models in PyTorch\n",
  8. "\n",
  9. "This section briefly reviews the linear regression model and shows how to build one in PyTorch and fit its parameters."
  10. ]
  11. },
  12. {
  13. "cell_type": "markdown",
  14. "metadata": {},
  15. "source": [
  16. "\n"
  17. ]
  18. },
  19. {
  20. "cell_type": "markdown",
  21. "metadata": {},
  22. "source": [
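  23. "## 1. Univariate Linear Regression\n",
  24. "Univariate linear regression is a simple model. Suppose we have inputs $x_i$ and targets $y_i$, where each index $i$ corresponds to one data point, and we want to build the model\n",
  25. "\n",
  26. "$$\n",
  27. "\\hat{y}_i = w x_i + b\n",
  28. "$$\n",
  29. "\n",
  30. "Here $\\hat{y}_i$ is the prediction, which should fit the target $y_i$. In other words, we look for the $w$ and $b$ that make the error between $\\hat{y}_i$ and $y_i$ as small as possible, that is, that minimize\n",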
  31. "\n",
  32. "$$\n",
  33. "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
  34. "$$"
  35. ]
  36. },
  37. {
  38. "cell_type": "markdown",
  39. "metadata": {},
  40. "source": [
  41. "So how do we minimize this error?"
  42. ]
  43. },
  44. {
  45. "cell_type": "markdown",
  46. "metadata": {},
  47. "source": [
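  48. "## 2. Gradient Descent\n",
  49. "\n",
  50. "To understand gradient descent, we first need the notion of a gradient. For a function of one variable the gradient is simply the derivative; for a function of several variables it is the vector of partial derivatives. For example, the gradient of a function $f(x, y)$ is\n",
  51. "\n",
  52. "$$\n",
  53. "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n",
  54. "$$\n",
  55. "\n",
  56. "which is written grad f(x, y) or $\\nabla f(x, y)$. The gradient at a specific point $(x_0,\\ y_0)$ is $\\nabla f(x_0,\\ y_0)$.\n"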
  57. ]
  58. },
  59. {
  60. "cell_type": "markdown",
  61. "metadata": {},
  62. "source": [
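  63. "What does the gradient mean? Geometrically, the gradient at a point gives the direction in which the function changes fastest. More precisely, for a function $f(x, y)$ at the point $(x_0, y_0)$, the function increases fastest along the direction of $\\nabla f(x_0,\\ y_0)$. Moving along the gradient therefore climbs toward a maximum most quickly, and moving against the gradient descends toward a minimum most quickly.\n",
  64. "\n",
  65. "For univariate linear regression, this means repeatedly adjusting $w$ and $b$ in the direction opposite to the gradient until we find the pair $w$, $b$ that minimizes the error.\n",
  66. "\n",
  67. "Each update also needs a step size, that is, how far to move downhill per step. This step size is called the learning rate and is written $\\eta$. The choice of learning rate affects the result: if it is too small, descent is very slow; if it is too large, the updates oscillate strongly.\n",
  68. "\n",
  69. "The update rule is therefore\n",
  70. "\n",
  71. "$$\n",
  72. "w := w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n",
  73. "b := b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n",
  74. "$$\n",
  75. "\n",
  76. "By repeating these updates, we eventually arrive at an optimal pair $w$ and $b$."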
  77. ]
  78. },
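  {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
  "As a hedged illustration of the update rule (the function, starting point, and learning rate here are made-up values, not taken from this notebook), consider the one-parameter loss $f(w) = (w - 3)^2$ with $\\eta = 0.1$ and $w_0 = 0$. Its derivative is $f'(w) = 2(w - 3)$, so two updates give\n",
  "\n",
  "$$\n",
  "w_1 = w_0 - \\eta f'(w_0) = 0 - 0.1 \\times (-6) = 0.6 \\\\\n",
  "w_2 = w_1 - \\eta f'(w_1) = 0.6 - 0.1 \\times (-4.8) = 1.08\n",
  "$$\n",
  "\n",
  "Each step moves $w$ toward the minimizer $w = 3$, and the steps shrink as the derivative shrinks."
  ]
  },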
  79. {
  80. "cell_type": "markdown",
  81. "metadata": {},
  82. "source": [
  83. "## 3. PyTorch Implementation\n",
  84. "\n",
  85. "That covers the theory. Next, we work through an example to learn more about linear models."
  86. ]
  87. },
  88. {
  89. "cell_type": "code",
  90. "execution_count": 1,
  91. "metadata": {},
  92. "outputs": [
  93. {
  94. "data": {
  95. "text/plain": [
  96. "<torch._C.Generator at 0x7f87041343f0>"
  97. ]
  98. },
  99. "execution_count": 1,
  100. "metadata": {},
  101. "output_type": "execute_result"
  102. }
  103. ],
  104. "source": [
  105. "import torch\n",
  106. "import numpy as np\n",
  107. "\n",
  108. "torch.manual_seed(2021)"
  109. ]
  110. },
  111. {
  112. "cell_type": "code",
  113. "execution_count": 2,
  114. "metadata": {},
  115. "outputs": [
  116. {
  117. "name": "stderr",
  118. "output_type": "stream",
  119. "text": [
  120. "Matplotlib is building the font cache; this may take a moment.\n"
  121. ]
  122. },
  123. {
  124. "data": {
  125. "text/plain": [
  126. "[<matplotlib.lines.Line2D at 0x7f85e9151580>]"
  127. ]
  128. },
  129. "execution_count": 2,
  130. "metadata": {},
  131. "output_type": "execute_result"
  132. },
  133. {
  134. "data": {
  135. "image/png": "\n",
  136. "text/plain": [
  137. "<Figure size 432x288 with 1 Axes>"
  138. ]
  139. },
  140. "metadata": {
  141. "needs_background": "light"
  142. },
  143. "output_type": "display_data"
  144. }
  145. ],
  146. "source": [
  147. "# Generate test data\n",
  148. "x_train = np.random.rand(20, 1)\n",
  149. "y_train = x_train * 3 + 4 + 3*np.random.rand(20,1)\n",
  150. "\n",
  151. "# Plot the data\n",
  152. "import matplotlib.pyplot as plt\n",
  153. "%matplotlib inline\n",
  154. "\n",
  155. "plt.plot(x_train, y_train, 'bo')"
  156. ]
  157. },
  158. {
  159. "cell_type": "code",
  160. "execution_count": 3,
  161. "metadata": {},
  162. "outputs": [],
  163. "source": [
  164. "# Convert to Tensors\n",
  165. "x_train = torch.from_numpy(x_train)\n",
  166. "y_train = torch.from_numpy(y_train)\n",
  167. "\n",
  168. "# Define the parameters w and b\n",
  169. "w = torch.randn(1, requires_grad=True) # random initialization\n",
  170. "b = torch.zeros(1, requires_grad=True) # initialize to 0"
  171. ]
  172. },
  173. {
  174. "cell_type": "code",
  175. "execution_count": 4,
  176. "metadata": {},
  177. "outputs": [],
  178. "source": [
  179. "# Build the linear regression model\n",
  180. "def linear_model(x):\n",
  181. "    return x * w + b\n",
  182. "\n",
  183. "def logistc_regression(x):\n",
  184. "    return torch.sigmoid(x*w+b) "
  185. ]
  186. },
  187. {
  188. "cell_type": "code",
  189. "execution_count": 5,
  190. "metadata": {},
  191. "outputs": [],
  192. "source": [
  193. "y_ = linear_model(x_train)"
  194. ]
  195. },
  196. {
  197. "cell_type": "markdown",
  198. "metadata": {},
  199. "source": [
  200. "With the steps above, the model is defined. Before updating any parameters, let's first look at what the model outputs."
  201. ]
  202. },
  203. {
  204. "cell_type": "code",
  205. "execution_count": 6,
  206. "metadata": {},
  207. "outputs": [
  208. {
  209. "data": {
  210. "text/plain": [
  211. "<matplotlib.legend.Legend at 0x7f85e9048640>"
  212. ]
  213. },
  214. "execution_count": 6,
  215. "metadata": {},
  216. "output_type": "execute_result"
  217. },
  218. {
  219. "data": {
  220. "image/png": "\n",
  221. "text/plain": [
  222. "<Figure size 432x288 with 1 Axes>"
  223. ]
  224. },
  225. "metadata": {
  226. "needs_background": "light"
  227. },
  228. "output_type": "display_data"
  229. }
  230. ],
  231. "source": [
  232. "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
  233. "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
  234. "plt.legend()"
  235. ]
  236. },
  237. {
  238. "cell_type": "markdown",
  239. "metadata": {},
  240. "source": [
  241. "**Question: the red points are the predicted values and appear to be arranged along a straight line. Do they actually lie on a single line?**"
  242. ]
  243. },
  244. {
  245. "cell_type": "markdown",
  246. "metadata": {},
  247. "source": [
  248. "Now we need to compute our error function, namely\n",
  249. "\n",
  250. "$$\n",
  251. "E = \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
  252. "$$"
  253. ]
  254. },
  255. {
  256. "cell_type": "code",
  257. "execution_count": 7,
  258. "metadata": {},
  259. "outputs": [],
  260. "source": [
  261. "# Compute the error\n",
  262. "def get_loss(y_, y):\n",
  263. "    return torch.sum((y_ - y) ** 2)\n",
  264. "\n",
  265. "loss = get_loss(y_, y_train)"
  266. ]
  267. },
  268. {
  269. "cell_type": "code",
  270. "execution_count": 8,
  271. "metadata": {},
  272. "outputs": [
  273. {
  274. "name": "stdout",
  275. "output_type": "stream",
  276. "text": [
  277. "tensor(733.2964, dtype=torch.float64, grad_fn=<SumBackward0>)\n"
  278. ]
  279. }
  280. ],
  281. "source": [
  282. "# Print the loss to see how large it is\n",
  283. "print(loss)"
  284. ]
  285. },
  286. {
  287. "cell_type": "markdown",
  288. "metadata": {},
  289. "source": [
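  290. "With the error function defined, we next need the gradients with respect to $w$ and $b$. Thanks to PyTorch's automatic differentiation, we get them without computing anything by hand. For reference, the hand-derived gradients of the summed error $E$ with respect to $w$ and $b$ are\n",
  291. "\n",
  292. "$$\n",
  293. "\\frac{\\partial E}{\\partial w} = 2 \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n",
  294. "\\frac{\\partial E}{\\partial b} = 2 \\sum_{i=1}^n (w x_i + b - y_i)\n",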
  295. "$$"
  296. ]
  297. },
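  {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
  "For readers who want the intermediate step, here is a short sketch of where these gradients come from. It assumes the summed error $E = \\sum_{i=1}^n (\\hat{y}_i - y_i)^2$ with $\\hat{y}_i = w x_i + b$, which is what `get_loss` computes at this point, and applies the chain rule term by term:\n",
  "\n",
  "$$\n",
  "\\frac{\\partial E}{\\partial w} = \\sum_{i=1}^n 2(\\hat{y}_i - y_i) \\frac{\\partial \\hat{y}_i}{\\partial w} = 2 \\sum_{i=1}^n x_i (w x_i + b - y_i) \\\\\n",
  "\\frac{\\partial E}{\\partial b} = \\sum_{i=1}^n 2(\\hat{y}_i - y_i) \\frac{\\partial \\hat{y}_i}{\\partial b} = 2 \\sum_{i=1}^n (w x_i + b - y_i)\n",
  "$$\n",
  "\n",
  "If the mean squared error from section 1 is used instead, both expressions are simply divided by $n$."
  ]
  },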
  298. {
  299. "cell_type": "code",
  300. "execution_count": 9,
  301. "metadata": {},
  302. "outputs": [],
  303. "source": [
  304. "# Automatic differentiation\n",
  305. "loss.backward()"
  306. ]
  307. },
  308. {
  309. "cell_type": "code",
  310. "execution_count": 10,
  311. "metadata": {},
  312. "outputs": [
  313. {
  314. "name": "stdout",
  315. "output_type": "stream",
  316. "text": [
  317. "tensor([-135.3880])\n",
  318. "tensor([-239.5816])\n"
  319. ]
  320. }
  321. ],
  322. "source": [
  323. "# Inspect the gradients of w and b\n",
  324. "print(w.grad)\n",
  325. "print(b.grad)"
  326. ]
  327. },
  328. {
  329. "cell_type": "code",
  330. "execution_count": 11,
  331. "metadata": {},
  332. "outputs": [],
  333. "source": [
  334. "# Update the parameters once\n",
  335. "w.data = w.data - 1e-2 * w.grad.data\n",
  336. "b.data = b.data - 1e-2 * b.grad.data"
  337. ]
  338. },
  339. {
  340. "cell_type": "markdown",
  341. "metadata": {},
  342. "source": [
  343. "After updating the parameters, let's look at the model output once more."
  344. ]
  345. },
  346. {
  347. "cell_type": "code",
  348. "execution_count": 12,
  349. "metadata": {},
  350. "outputs": [
  351. {
  352. "data": {
  353. "text/plain": [
  354. "<matplotlib.legend.Legend at 0x7f85e8fcb2e0>"
  355. ]
  356. },
  357. "execution_count": 12,
  358. "metadata": {},
  359. "output_type": "execute_result"
  360. },
  361. {
  362. "data": {
  363. "image/png": "\n",
  364. "text/plain": [
  365. "<Figure size 432x288 with 1 Axes>"
  366. ]
  367. },
  368. "metadata": {
  369. "needs_background": "light"
  370. },
  371. "output_type": "display_data"
  372. }
  373. ],
  374. "source": [
  375. "y_ = linear_model(x_train)\n",
  376. "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
  377. "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
  378. "plt.legend()"
  379. ]
  380. },
  381. {
  382. "cell_type": "markdown",
  383. "metadata": {},
  384. "source": [
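  385. "As the example above shows, after the update the red predictions have moved below the blue points and still do not fit the blue ground-truth values well, so we need to perform several more updates."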
  386. ]
  387. },
  388. {
  389. "cell_type": "code",
  390. "execution_count": 13,
  391. "metadata": {},
  392. "outputs": [
  393. {
  394. "name": "stdout",
  395. "output_type": "stream",
  396. "text": [
  397. "epoch: 19, loss: 17.798984092741378\n",
  398. "epoch: 39, loss: 16.14508120463308\n",
  399. "epoch: 59, loss: 15.55101918276564\n",
  400. "epoch: 79, loss: 15.33763961353287\n",
  401. "epoch: 99, loss: 15.26099545058815\n"
  402. ]
  403. }
  404. ],
  405. "source": [
  406. "for e in range(100): # perform 100 updates\n",
  407. "    y_ = linear_model(x_train)\n",
  408. "    loss = get_loss(y_, y_train)\n",
  409. "    \n",
  410. "    w.grad.zero_() # note: zero the gradients\n",
  411. "    b.grad.zero_() # note: zero the gradients\n",
  412. "    loss.backward()\n",
  413. "    \n",
  414. "    w.data = w.data - 1e-2 * w.grad.data # update w\n",
  415. "    b.data = b.data - 1e-2 * b.grad.data # update b \n",
  416. "    if (e + 1) % 20 == 0:\n",
  417. "        print('epoch: {}, loss: {}'.format(e, loss.item()))"
  418. ]
  419. },
  420. {
  421. "cell_type": "code",
  422. "execution_count": 14,
  423. "metadata": {},
  424. "outputs": [
  425. {
  426. "data": {
  427. "text/plain": [
  428. "<matplotlib.legend.Legend at 0x7f85e8735970>"
  429. ]
  430. },
  431. "execution_count": 14,
  432. "metadata": {},
  433. "output_type": "execute_result"
  434. },
  435. {
  436. "data": {
  437. "image/png": "\n",
  438. "text/plain": [
  439. "<Figure size 432x288 with 1 Axes>"
  440. ]
  441. },
  442. "metadata": {
  443. "needs_background": "light"
  444. },
  445. "output_type": "display_data"
  446. }
  447. ],
  448. "source": [
  449. "y_ = linear_model(x_train)\n",
  450. "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
  451. "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
  452. "plt.legend()"
  453. ]
  454. },
  455. {
  456. "cell_type": "markdown",
  457. "metadata": {},
  458. "source": [
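  459. "After 100 updates, the red predictions fit the blue ground-truth values fairly well.\n",
  460. "\n",
  461. "You have now learned your first machine learning model. Keep it up and complete the small exercise below."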
  462. ]
  463. },
  464. {
  465. "cell_type": "markdown",
  466. "metadata": {},
  467. "source": [
  468. "## 4. Polynomial Regression"
  469. ]
  470. },
  471. {
  472. "cell_type": "markdown",
  473. "metadata": {},
  474. "source": [
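  475. "Going one step further, let's try polynomial regression. Here is a polynomial in $x$:\n",
  476. "\n",
  477. "$$\n",
  478. "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 \n",
  479. "$$\n",
  480. "\n",
  481. "This lets us fit more complex relationships and is called a polynomial model, since it uses higher powers of $x$. Multivariate regression has the same form, except that besides $x$ it uses additional input variables such as $y$, $z$, and so on. The loss function for all of these is the same as for simple linear regression."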
  482. ]
  483. },
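  {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
  "Although the polynomial is nonlinear in $x$, it is linear in the parameters, which is why the same machinery applies. As a sketch (the symbols $X$, $\\mathbf{w}$ and $b$ below are notation introduced here, with $b$ playing the role of $w_0$), stacking the powers of each sample $x_i$ into one row gives\n",
  "\n",
  "$$\n",
  "\\hat{\\mathbf{y}} = X \\mathbf{w} + b, \\quad\n",
  "X = \\begin{pmatrix} x_1 & x_1^2 & x_1^3 \\\\ \\vdots & \\vdots & \\vdots \\\\ x_n & x_n^2 & x_n^3 \\end{pmatrix}, \\quad\n",
  "\\mathbf{w} = \\begin{pmatrix} w_1 \\\\ w_2 \\\\ w_3 \\end{pmatrix}\n",
  "$$\n",
  "\n",
  "so fitting the cubic amounts to linear regression on the three features $x$, $x^2$, $x^3$."
  ]
  },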
  484. {
  485. "cell_type": "markdown",
  486. "metadata": {},
  487. "source": [
  488. "\n"
  489. ]
  490. },
  491. {
  492. "cell_type": "markdown",
  493. "metadata": {},
  494. "source": [
  495. "First, we define the target function to fit, which is a cubic polynomial."
  496. ]
  497. },
  498. {
  499. "cell_type": "code",
  500. "execution_count": 15,
  501. "metadata": {},
  502. "outputs": [
  503. {
  504. "name": "stdout",
  505. "output_type": "stream",
  506. "text": [
  507. "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n"
  508. ]
  509. }
  510. ],
  511. "source": [
  512. "# Define the target polynomial\n",
  513. "\n",
  514. "w_target = np.array([0.5, 3, 2.4]) # target weights\n",
  515. "b_target = np.array([0.9]) # target bias\n",
  516. "\n",
  517. "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n",
  518. "    b_target[0], w_target[0], w_target[1], w_target[2]) # format the function as a string\n",
  519. "\n",
  520. "print(f_des)"
  521. ]
  522. },
  523. {
  524. "cell_type": "markdown",
  525. "metadata": {},
  526. "source": [
  527. "We can first plot this polynomial."
  528. ]
  529. },
  530. {
  531. "cell_type": "code",
  532. "execution_count": 16,
  533. "metadata": {},
  534. "outputs": [
  535. {
  536. "data": {
  537. "text/plain": [
  538. "<matplotlib.legend.Legend at 0x7f85e86d5640>"
  539. ]
  540. },
  541. "execution_count": 16,
  542. "metadata": {},
  543. "output_type": "execute_result"
  544. },
  545. {
  546. "data": {
  547. "image/png": "\n",
  548. "text/plain": [
  549. "<Figure size 432x288 with 1 Axes>"
  550. ]
  551. },
  552. "metadata": {
  553. "needs_background": "light"
  554. },
  555. "output_type": "display_data"
  556. }
  557. ],
  558. "source": [
  559. "# Plot the curve of this function\n",
  560. "x_sample = np.arange(-3, 3.1, 0.1)\n",
  561. "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n",
  562. "\n",
  563. "plt.plot(x_sample, y_sample, label='real curve')\n",
  564. "plt.legend()"
  565. ]
  566. },
  567. {
  568. "cell_type": "markdown",
  569. "metadata": {},
  570. "source": [
  571. "Next we build the dataset. We need x and y, and because the target is a cubic polynomial, the features are $x,\\ x^2,\\ x^3$."
  572. ]
  573. },
  574. {
  575. "cell_type": "code",
  576. "execution_count": 17,
  577. "metadata": {},
  578. "outputs": [],
  579. "source": [
  580. "# Build the data x and y\n",
  581. "# x is a matrix whose columns are [x, x^2, x^3]\n",
  582. "# y is the function value [y]\n",
  583. "\n",
  584. "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n",
  585. "x_train = torch.from_numpy(x_train).float() # convert to a float tensor\n",
  586. "\n",
  587. "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # convert to a float tensor "
  588. ]
  589. },
  590. {
  591. "cell_type": "code",
  592. "execution_count": 18,
  593. "metadata": {},
  594. "outputs": [
  595. {
  596. "name": "stdout",
  597. "output_type": "stream",
  598. "text": [
  599. "torch.Size([61, 3])\n"
  600. ]
  601. }
  602. ],
  603. "source": [
  604. "print(x_train.size())"
  605. ]
  606. },
  607. {
  608. "cell_type": "markdown",
  609. "metadata": {},
  610. "source": [
  611. "Next we define the parameters to optimize, namely the $w_i$ in the function above."
  612. ]
  613. },
  614. {
  615. "cell_type": "code",
  616. "execution_count": 19,
  617. "metadata": {},
  618. "outputs": [],
  619. "source": [
  620. "# Define the parameters\n",
  621. "w = torch.randn((3, 1), dtype=torch.float, requires_grad=True)\n",
  622. "b = torch.zeros((1), dtype=torch.float, requires_grad=True)\n",
  623. "\n",
  624. "# Define the model\n",
  625. "def multi_linear(x):\n",
  626. "    return torch.mm(x, w) + b\n",
  627. "\n",
  628. "def get_loss(y_, y):\n",
  629. "    return torch.mean((y_ - y) ** 2)"
  630. ]
  631. },
  632. {
  633. "cell_type": "markdown",
  634. "metadata": {},
  635. "source": [
  636. "We can plot the model before any updates against the true model for comparison."
  637. ]
  638. },
  639. {
  640. "cell_type": "code",
  641. "execution_count": 20,
  642. "metadata": {},
  643. "outputs": [
  644. {
  645. "data": {
  646. "text/plain": [
  647. "<matplotlib.legend.Legend at 0x7f85e8619220>"
  648. ]
  649. },
  650. "execution_count": 20,
  651. "metadata": {},
  652. "output_type": "execute_result"
  653. },
  654. {
  655. "data": {
  656. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqEklEQVR4nO3deXxU1f3/8deHEAib7CCCLLaKIMpi4Ata0YoCVQRUrLjiilatYFtBxJ/aKopCXWhFi4poRZZiFdoKgiDuQANCBSKCyiZbBFEwBEhyfn+cGRIwgZCZyZ2ZvJ+Px3nMdmfuZ7J85sy5536OOecQEZHkVCHoAEREJHaU5EVEkpiSvIhIElOSFxFJYkryIiJJTEleRCSJlTjJm9l4M9tmZssL3VfHzOaY2erQZe1Cjw0zszVmtsrMekQ7cBERObKj6clPAHoect89wFzn3InA3NBtzKw10B84JfScsWaWEnG0IiJyVCqWdEPn3Ptm1vyQu/sA54SuvwzMB4aG7p/snNsLfG1ma4BOwCeH20e9evVc8+aH7kJERA5n8eLF3zrn6hf1WImTfDEaOuc2AzjnNptZg9D9jYEFhbbbGLrvsJo3b05GRkaEIYmIlC9mtq64x2J14NWKuK/I+glmNtDMMswsIysrK0bhiIiUT5Em+a1m1gggdLktdP9G4PhC2zUBNhX1As65cc65dOdcev36RX7bEBGRUoo0yc8ABoSuDwCmF7q/v5lVNrMWwInAogj3JSIiR6nEY/JmNgl/kLWemW0EHgBGAlPN7EZgPXAZgHNuhZlNBVYCucDtzrm80gS4f/9+Nm7cSE5OTmmeLhFKS0ujSZMmpKamBh2KiJSCxVOp4fT0dHfogdevv/6aGjVqULduXcyKGuqXWHHOsX37dnbt2kWLFi2CDkdEimFmi51z6UU9FvdnvObk5CjBB8TMqFu3rr5FiSSwuE/ygBJ8gPSzF0lsCZHkRUSS2ZgxMGNGbF5bSb4ExowZQ6tWrbjqqquYMWMGI0eOBODNN99k5cqVB7abMGECmzYVzBS96aabDnpcRORQO3fCsGEwffoRNy2VSM94LRfGjh3LzJkzDxx87N27N+CTfK9evWjdujXgk3ybNm047rjjAHjhhReCCbiQvLw8UlJUNkgkXk2YANnZcPvtsXl99eSP4NZbb+Wrr76id+/ePPnkk0yYMIE77riDjz/+mBkzZnD33XfTrl07HnvsMTIyMrjqqqto164de/bs4ZxzzjlQpqF69eoMHz6ctm3b0rlzZ7Zu3QrAl19+SefOnenYsSP3338/1atXLzKOV155hdNOO422bdtyzTXXAHDdddcxbdq0A9uEnzt//nx++ctfcuWVV3LqqacydOhQxo4de2C7Bx98kD//+c8AjBo1io4dO3LaaafxwAMPRP8HKCLFys+HZ56BM86ADh1is4/E6skPHgxLl0b3Ndu1g6eeKvbh5557jlmzZvHuu+9Sr149JkyYAMAZZ5xB79696dWrF/369QNg5syZjB49mvT0n85k+vHHH+ncuTMjRoxgyJAhPP/889x3330MGjSIQYMGccUVV/Dcc88VGcOKFSsYMWIEH330EfXq1WPHjh1HfFuLFi1i+fLltGjRgk8//ZTBgwdz2223ATB16lRmzZrF7NmzWb16NYsWLcI5R+/evXn//ffp2rXrEV9fRCI3Zw6sWQN//GPs9qGefBmpVKkSvXr1AuD0009n7dq1AHzyySdcdtllAFx55ZVFPnfevHn069ePevXqAVCnTp0j7q9Tp04Hhpfat2/Ptm3b2LRpE8uWLaN27do0bdqU2bNnM3v2bNq3b0+HDh34/PPPWb16daRvVURK6K9/hYYNIdRPjInE6skfpscd71JTUw9MR0xJSSE3N7fEz3XOFTmVsWLFiuTn5x/YZt++fQceq1at2kHb9uvXj2nTprFlyxb69+9/4DnDhg3jlltuOer3IyKR+eo
\n",
  657. "text/plain": [
  658. "<Figure size 432x288 with 1 Axes>"
  659. ]
  660. },
  661. "metadata": {
  662. "needs_background": "light"
  663. },
  664. "output_type": "display_data"
  665. }
  666. ],
  667. "source": [
  668. "# Plot the model before any parameter update\n",
  669. "y_pred = multi_linear(x_train)\n",
  670. "\n",
  671. "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
  672. "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
  673. "plt.legend()"
  674. ]
  675. },
  676. {
  677. "cell_type": "markdown",
  678. "metadata": {},
  679. "source": [
  680. "The two curves clearly differ, so let us compute the error between them."
  681. ]
  682. },
  683. {
  684. "cell_type": "code",
  685. "execution_count": 21,
  686. "metadata": {},
  687. "outputs": [
  688. {
  689. "name": "stdout",
  690. "output_type": "stream",
  691. "text": [
  692. "tensor(1144.2654, grad_fn=<MeanBackward0>)\n"
  693. ]
  694. }
  695. ],
  696. "source": [
  697. "# Compute the error; the loss is the same as for the univariate linear model, so reuse get_loss defined earlier\n",
  698. "loss = get_loss(y_pred, y_train)\n",
  699. "print(loss)"
  700. ]
  701. },
  702. {
  703. "cell_type": "code",
  704. "execution_count": 22,
  705. "metadata": {},
  706. "outputs": [],
  707. "source": [
  708. "# Compute the gradients automatically via backpropagation\n",
  709. "loss.backward()"
  710. ]
  711. },
  712. {
  713. "cell_type": "code",
  714. "execution_count": 23,
  715. "metadata": {},
  716. "outputs": [
  717. {
  718. "name": "stdout",
  719. "output_type": "stream",
  720. "text": [
  721. "tensor([[ -94.7455],\n",
  722. " [-139.1247],\n",
  723. " [-629.8584]])\n",
  724. "tensor([-25.7413])\n"
  725. ]
  726. }
  727. ],
  728. "source": [
  729. "# Inspect the gradients of w and b\n",
  730. "print(w.grad)\n",
  731. "print(b.grad)"
  732. ]
  733. },
  734. {
  735. "cell_type": "code",
  736. "execution_count": 24,
  737. "metadata": {},
  738. "outputs": [],
  739. "source": [
  740. "# Update the parameters with one gradient descent step (learning rate 0.001)\n",
  741. "w.data = w.data - 0.001 * w.grad.data\n",
  742. "b.data = b.data - 0.001 * b.grad.data"
  743. ]
  744. },
  745. {
  746. "cell_type": "code",
  747. "execution_count": 25,
  748. "metadata": {},
  749. "outputs": [
  750. {
  751. "data": {
  752. "text/plain": [
  753. "<matplotlib.legend.Legend at 0x7f85e860c9a0>"
  754. ]
  755. },
  756. "execution_count": 25,
  757. "metadata": {},
  758. "output_type": "execute_result"
  759. },
  760. {
  761. "data": {
  762. "image/png": "\n",
  763. "text/plain": [
  764. "<Figure size 432x288 with 1 Axes>"
  765. ]
  766. },
  767. "metadata": {
  768. "needs_background": "light"
  769. },
  770. "output_type": "display_data"
  771. }
  772. ],
  773. "source": [
  774. "# Plot the model after a single update\n",
  775. "y_pred = multi_linear(x_train)\n",
  776. "\n",
  777. "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
  778. "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
  779. "plt.legend()"
  780. ]
  781. },
  782. {
  783. "cell_type": "markdown",
  784. "metadata": {},
  785. "source": [
  786. "After only one update the two curves still differ, so we run 100 iterations."
  787. ]
  788. },
  789. {
  790. "cell_type": "code",
  791. "execution_count": 26,
  792. "metadata": {},
  793. "outputs": [
  794. {
  795. "name": "stdout",
  796. "output_type": "stream",
  797. "text": [
  798. "epoch 20, Loss: 65.56586\n",
  799. "epoch 40, Loss: 15.41177\n",
  800. "epoch 60, Loss: 3.70702\n",
  801. "epoch 80, Loss: 0.97122\n",
  802. "epoch 100, Loss: 0.32874\n"
  803. ]
  804. }
  805. ],
  806. "source": [
  807. "# Run 100 parameter updates\n",
  808. "for e in range(100):\n",
  809. "    y_pred = multi_linear(x_train)\n",
  810. "    loss = get_loss(y_pred, y_train)\n",
  811. "\n",
  812. "    w.grad.data.zero_()\n",
  813. "    b.grad.data.zero_()\n",
  814. "    loss.backward()\n",
  815. "\n",
  816. "    # Update the parameters\n",
  817. "    w.data = w.data - 0.001 * w.grad.data\n",
  818. "    b.data = b.data - 0.001 * b.grad.data\n",
  819. "    if (e + 1) % 20 == 0:\n",
  820. "        print('epoch {}, Loss: {:.5f}'.format(e+1, loss.data.item()))"
  821. ]
  822. },
  823. {
  824. "cell_type": "markdown",
  825. "metadata": {},
  826. "source": [
  827. "After these updates the loss is already very small. Let us plot the updated curve against the real one."
  828. ]
  829. },
  830. {
  831. "cell_type": "code",
  832. "execution_count": 27,
  833. "metadata": {},
  834. "outputs": [
  835. {
  836. "data": {
  837. "text/plain": [
  838. "<matplotlib.legend.Legend at 0x7f85e8584ee0>"
  839. ]
  840. },
  841. "execution_count": 27,
  842. "metadata": {},
  843. "output_type": "execute_result"
  844. },
  845. {
  846. "data": {
  847. "image/png": "\n",
  848. "text/plain": [
  849. "<Figure size 432x288 with 1 Axes>"
  850. ]
  851. },
  852. "metadata": {
  853. "needs_background": "light"
  854. },
  855. "output_type": "display_data"
  856. }
  857. ],
  858. "source": [
  859. "# Plot the result after the updates\n",
  860. "y_pred = multi_linear(x_train)\n",
  861. "\n",
  862. "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
  863. "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
  864. "plt.legend()"
  865. ]
  866. },
  867. {
  868. "cell_type": "markdown",
  869. "metadata": {},
  870. "source": [
  871. "After 100 updates, the fitted curve and the real curve almost completely coincide."
  872. ]
  873. },
  874. {
  875. "cell_type": "markdown",
  876. "metadata": {
  877. "collapsed": true
  878. },
  879. "source": [
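  880. "## 5. Exercise\n",
  881. "\n",
  882. "The example above fits a cubic polynomial. Try fitting the same data with a quadratic polynomial and see how good the final fit can be.\n",
  883. "\n",
  884. "**Hint: use the parameter `w = torch.randn(2, 1)` and rebuild the x dataset accordingly.**"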
  885. ]
  886. }
  887. ],
  888. "metadata": {
  889. "kernelspec": {
  890. "display_name": "Python 3 (ipykernel)",
  891. "language": "python",
  892. "name": "python3"
  893. },
  894. "language_info": {
  895. "codemirror_mode": {
  896. "name": "ipython",
  897. "version": 3
  898. },
  899. "file_extension": ".py",
  900. "mimetype": "text/x-python",
  901. "name": "python",
  902. "nbconvert_exporter": "python",
  903. "pygments_lexer": "ipython3",
  904. "version": "3.9.7"
  905. }
  906. },
  907. "nbformat": 4,
  908. "nbformat_minor": 2
  909. }
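
The notebook relies on `loss.backward()` to obtain the gradients of `w` and `b`. As a cross-check on what autograd computes, here is a minimal pure-Python sketch of the same loop with the gradients of the mean squared error written out by hand: $\partial L / \partial w_j = \frac{2}{n} \sum_i (\hat{y}_i - y_i)\, x_{ij}$ and $\partial L / \partial b = \frac{2}{n} \sum_i (\hat{y}_i - y_i)$. It is not the notebook's code. The target coefficients, the sample grid, the zero initialization, and the helper names (`features`, `mse_and_grads`, `train`) are assumptions made for illustration; only the learning rate of 0.001 and the 100 updates are taken from the cells above.

```python
# Assumed target polynomial y = 0.9 + 0.5*x + 3.0*x^2 + 2.4*x^3.
# These coefficients are illustrative, not taken from the notebook cells above.
W_TARGET = [0.5, 3.0, 2.4]
B_TARGET = 0.9


def features(x):
    """Map a scalar x to the polynomial features [x, x^2, x^3]."""
    return [x, x ** 2, x ** 3]


# Assumed sample grid: 61 evenly spaced points on [-3, 3].
xs = [-3.0 + 0.1 * i for i in range(61)]
X = [features(x) for x in xs]
Y = [B_TARGET + sum(w * f for w, f in zip(W_TARGET, row)) for row in X]


def predict(w, b, row):
    """Linear model on the polynomial features: b + sum_j w_j * x_j."""
    return b + sum(wj * fj for wj, fj in zip(w, row))


def mse_and_grads(w, b):
    """Return the mean squared error and its gradients w.r.t. w and b."""
    n = len(X)
    residuals = [predict(w, b, row) - y for row, y in zip(X, Y)]
    loss = sum(r * r for r in residuals) / n
    # dL/dw_j = (2/n) * sum_i residual_i * x_ij
    grad_w = [
        2.0 / n * sum(r * row[j] for r, row in zip(residuals, X))
        for j in range(len(w))
    ]
    # dL/db = (2/n) * sum_i residual_i
    grad_b = 2.0 / n * sum(residuals)
    return loss, grad_w, grad_b


def train(epochs=100, lr=0.001):
    """Plain gradient descent from a zero initialization (an assumption)."""
    w, b = [0.0, 0.0, 0.0], 0.0
    losses = []
    for _ in range(epochs):
        loss, grad_w, grad_b = mse_and_grads(w, b)
        losses.append(loss)  # loss measured before this step's update
        w = [wj - lr * gj for wj, gj in zip(w, grad_w)]
        b = b - lr * grad_b
    return w, b, losses


w, b, losses = train()
print("first loss: {:.5f}, last loss: {:.5f}".format(losses[0], losses[-1]))
```

Under these assumed settings the loss should drop sharply over the 100 updates, mirroring the trend in the notebook's printed output. The exact values differ because the initialization differs. Writing the gradients out this way also shows why the notebook zeroes `w.grad` and `b.grad` before each `backward()` call: PyTorch accumulates gradients across calls, whereas this sketch recomputes them from scratch at every step.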

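Machine learning is increasingly applied in fields such as aircraft and robotics, with the goal of using computers to achieve human-like intelligence and thereby make equipment intelligent and unmanned. This course aims to guide students in mastering the fundamentals, typical methods, and techniques of machine learning; to spark their interest in the subject through concrete application cases; and to encourage them to analyze and solve the problems and challenges facing aircraft and robots from an artificial-intelligence perspective. The course covers Python programming basics, machine learning models, and the fundamentals and implementation of unsupervised learning, supervised learning, and deep learning. It also teaches how to apply machine learning to practical problems, so that students improve their overall competence.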