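{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Models and Gradient Descent\n", "This is the first lesson on neural networks. We will study a very simple model, linear regression, together with an optimization algorithm, gradient descent, which we use to fit that model. Linear regression is one of the simplest models in supervised learning, and gradient descent is the most widely used optimization algorithm in deep learning, so this is where our journey into deep learning begins." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple Linear Regression\n", "The one-variable linear model is very simple. Suppose we have inputs $x_i$ and targets $y_i$, where each index i corresponds to one data point. We want to build the model\n", "\n", "$$\n", "\\hat{y}_i = w x_i + b\n", "$$\n", "\n", "where $\\hat{y}_i$ is our prediction. We want $\\hat{y}_i$ to fit the target $y_i$. Put simply, we look for the w and b that make the error as small as possible, that is, we minimize the mean squared error\n", "\n", "$$\n", "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So how do we minimize this error?\n", "\n", "This is where **gradient descent** comes in. It is the first optimization algorithm we will meet. It is very simple, yet very powerful, and it is used extensively in deep learning. Let's start from a simple example to understand how it works." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient Descent\n", "To understand gradient descent, we first need to be clear about what a gradient is. Then we will see how the gradient is used to descend." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient\n", "For a function of one variable, the gradient is simply its derivative. For a function of several variables, the gradient is the vector of its partial derivatives. For example, for a function f(x, y), the gradient of f is\n", "\n", "$$\n", "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n", "$$\n", "\n", "which is written as grad f(x, y) or $\\nabla f(x, y)$. The gradient at a specific point $(x_0,\\ y_0)$ is $\\nabla f(x_0,\\ y_0)$.\n", "\n", "The figure below shows the gradient of $f(x) = x^2$ at x = 1, where $f'(1) = 2$.\n", "\n", "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarbuh2j3j30ba0b80sy.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What does the gradient mean? Geometrically, the gradient at a point points in the direction in which the function increases fastest. Concretely, for a function f(x, y) at the point $(x_0, y_0)$, the function increases fastest along the direction of $\\nabla f(x_0,\\ y_0)$. Moving along the gradient is therefore the locally fastest way up toward a maximum. Conversely, moving in the opposite direction, $-\\nabla f(x_0,\\ y_0)$, makes the function decrease fastest and is the locally fastest way down toward a minimum." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient Descent\n", "With this understanding of the gradient, the idea behind gradient descent follows directly. We need to minimize the error above, which means finding the point where the error is smallest. Moving against the gradient leads us toward that minimum.\n", "\n", "Here is an intuitive picture. Suppose we are standing somewhere on a mountain and do not know the way down, so we decide to take it one step at a time. At each position we compute the gradient there and take one step in the negative gradient direction, which is the steepest way down from where we stand. At the new position we compute the gradient again and take another step in the steepest downhill direction. We keep walking like this until we believe we have reached the foot of the mountain. Of course, walking this way we may never reach the foot and instead end up in a local low point (a local minimum) partway down.\n", "\n", 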
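"The downhill walk described above can be tried out in a few lines of plain Python. The cell below is a minimal sketch and is not part of the original lesson. It runs gradient descent on $f(x) = x^2$, the function from the figure above, whose derivative is $f'(x) = 2x$. It starts from the hypothetical point $x_0 = 1$. The helper name `gradient_descent_1d` and its parameters are illustrative assumptions. `step_size` is the length of each step, which the text below calls the learning rate." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch (not from the original notebook): gradient descent on f(x) = x^2\n", "# gradient_descent_1d is a hypothetical helper written only for illustration.\n", "def gradient_descent_1d(x0, step_size, num_steps):\n", "    x = x0\n", "    history = [x]\n", "    for _ in range(num_steps):\n", "        grad = 2 * x               # derivative of f(x) = x^2 at the current point\n", "        x = x - step_size * grad   # take one step against the gradient\n", "        history.append(x)\n", "    return history\n", "\n", "# A moderate step size shrinks x toward the minimum at x = 0 (x is multiplied by 0.8 each step)\n", "small_steps = gradient_descent_1d(x0=1.0, step_size=0.1, num_steps=20)\n", "# A step size that is too large overshoots: x is multiplied by -1.2 each step and blows up\n", "large_steps = gradient_descent_1d(x0=1.0, step_size=1.1, num_steps=20)\n", "\n", "print('step_size=0.1, final x:', small_steps[-1])\n", "print('step_size=1.1, final x:', large_steps[-1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 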
"类比我们的问题,就是沿着梯度的反方向,我们不断改变 w 和 b 的值,最终找到一组最好的 w 和 b 使得误差最小。\n", "\n", "在更新的时候,我们需要决定每次更新的幅度,比如在下山的例子中,我们需要每次往下走的那一步的长度,这个长度称为学习率,用 $\\eta$ 表示,这个学习率非常重要,不同的学习率都会导致不同的结果,学习率太小会导致下降非常缓慢,学习率太大又会导致跳动非常明显,可以看看下面的例子\n", "\n", "![](https://ws2.sinaimg.cn/large/006tNc79ly1fmgn23lnzjg30980gogso.gif)\n", "\n", "可以看到上面的学习率较为合适,而下面的学习率太大,就会导致不断跳动\n", "\n", "最后我们的更新公式就是\n", "\n", "$$\n", "w := w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n", "b := b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n", "$$\n", "\n", "通过不断地迭代更新,最终我们能够找到一组最优的 w 和 b,这就是梯度下降法的原理。\n", "\n", "最后可以通过这张图形象地说明一下这个方法\n", "\n", "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarxsltfqj30gx091gn4.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "上面是原理部分,下面通过一个例子来进一步学习线性模型" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import numpy as np\n", "from torch.autograd import Variable\n", "\n", "torch.manual_seed(2017)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# 读入数据 x 和 y\n", "x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],\n", " [9.779], [6.182], [7.59], [2.167], [7.042],\n", " [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)\n", "\n", "y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],\n", " [3.366], [2.596], [2.53], [1.221], [2.827],\n", " [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPrElEQVR4nO3df4gc933G8ec5SdS+OMRtdSSqrLstNKQkprbSxbVrKMauwU2NXagLLlvXKSkHIW3sYih1BC4JXEmhuD9iiFnsNEq7uAmySV0TtxWJITE0CitV/iUZYqjubFepznYt293UraJP/5gVkq67t7On2ZvZ77xfsMzMd0e7H4a7R9+b/cysI0IAgOk3U3YBAIBiEOgAkAgCHQASQaADQCIIdABIxNay3nj79u3RaDTKensAmEoHDx58LSLmBj1XWqA3Gg11u92y3h4AppLt5WHPccoFABJBoANAIkYGuu2LbH/P9jO2X7D92QH7fNz2qu3D/cfvTqZcAMAwec6hvyvp+oh4x/Y2SU/bfjIivrtmv69GxO8VXyIAII+RgR7ZzV7e6W9u6z+4AQwAVEyuc+i2t9g+LOmEpP0RcWDAbr9u+1nb+2zvGvI6i7a7trurq6sbrxoAplCnIzUa0sxMtux0in39XIEeET+KiCslXSbpKtuXr9nlHyQ1IuLnJO2XtHfI67QjohkRzbm5gW2UAJCkTkdaXJSWl6WIbLm4WGyoj9XlEhFvSnpK0k1rxl+PiHf7mw9J+vlCqgOAROzZI/V654/1etl4UfJ0uczZvrS/frGkGyW9uGafHeds3iLpaHElAsD0W1kZb3wj8nS57JC01/YWZf8BfC0inrD9OUndiHhc0qdt3yLplKQ3JH28uBIBYPrNz2enWQaNFyVPl8uzknYPGL/vnPV7Jd1bXFkAkJalpeyc+bmnXWZns/GicKUoAGyCVktqt6WFBcnOlu12Nl6U0m7OBQB102oVG+BrMUMHgEQQ6ACSNekLeaqGUy4AknTmQp4zH0KeuZBHmuxpjzIxQweQpM24kKdqCHQASdqMC3mqhkAHkKRhF+wUeSFP1RDoAJK0tJRduHOuoi/kqRoCHUCSNuNCnqqhywVAsiZ9IU/VMEMHgEQQ6ACQCAIdABJBoANAIgh0AEgEgQ4AiSDQASARBDowhrrdjhXThQuLgJzqeDtWTBdm6EBOdbwdK6YLgQ7kVMfbsWK6EOhATnW8HSumC4EO5FTH27FiuhDoQE51vB0rpgtdLsAY6nY7VkwXZugAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSCQAeARBDoAJCIkYFu+yLb37P9jO0XbH92wD4/Zvurtl+yfcB2YyLVAgCGyjNDf1fS9RFxhaQrJd1k++o1+3xC0n9GxM9I+nNJf1polQCAkUYGemTe6W9u6z9izW63StrbX98n6QbbLqxKAMBIuc6h295i+7CkE5L2R8SBNbvslPSyJEXEKUknJf3kgNdZtN213V1dXb2gwgEA58sV6BHxo4i4UtJlkq6yfflG3iwi2hHRjIjm3NzcRl4CADDEWF0uEfGmpKck3bTmqVcl7ZIk21slvU/S6wXUBwDIKU+Xy5ztS/vrF0u6UdKLa3Z7XNKd/fXbJH0rItaeZwcATFCeL7jYIWmv7S3K/gP4WkQ8YftzkroR8bikhyX9je2XJL0h6faJVQwAGGhkoEfEs5J2Dxi/75z1/5b0G8WWBgAYB1eKAonrdKRGQ5qZyZadTtkVYVL4TlEgYZ2OtLgo9XrZ9vJyti3x3agpYoYOJGzPnrNhfkavl40jPQQ6kLCVlfHGMd0IdCBh8/PjjWO6EehAwpaWpNnZ88dmZ7NxpIdAByakCt0lrZbUbksLC5KdLdttPhBNFV0uwARUqbuk1SLA64IZOjABdJegDAQ6MAF0l6AMBDowAXSXoAwEOjABdJegDAR6TVSh46JO6C5BGehyqYEqdVzUCd0l2GzM0GuAjgugHgj
0GqDjAqgHAr0G6LgA6oFArwE6LoB6INBrgI4LoB7ocqkJOi6A9DFDB4BEEOgAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ksetg1EXXFiEpHHrYNQJM3QkjVsHo04IdCSNWwejTgh0JI1bB6NOCHQkjVsHo04IdCQtpVsH062DUehyQfJSuHUw3TrIY+QM3fYu20/ZPmL7Bdt3DdjnOtsnbR/uP+6bTLlAPdGtgzzyzNBPSbonIg7Zfq+kg7b3R8SRNft9JyJuLr5EAHTrII+RM/SIOB4Rh/rrb0s6KmnnpAsDcBbdOshjrA9FbTck7ZZ0YMDT19h+xvaTtj8y5N8v2u7a7q6uro5fLVBTdOsgj9yBbvsSSY9Kujsi3lrz9CFJCxFxhaQvSPr6oNeIiHZENCOiOTc3t8GSgfpJqVsHk+OIGL2TvU3SE5L+KSLuz7H/MUnNiHht2D7NZjO63e4YpQIAbB+MiOag5/J0uVjSw5KODgtz2x/o7yfbV/Vf9/WNlwwAGFeeLpdrJd0h6Tnbh/tjn5E0L0kR8aCk2yR90vYpST+UdHvkmfoDAAozMtAj4mlJHrHPA5IeKKooAMD4uPQfABJBoANAIgh0AEgEgQ4AiSDQASARBDoAJIJAB4BEEOgAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSCQAeARBDoAJAIAh0AEkGgA0AiCHQASASBDgCJINABIBEEOgAkgkAHgEQQ6ACQCAIdABJBoANAIgh0AEgEgY7SdTpSoyHNzGTLTqfsioDptLXsAlBvnY60uCj1etn28nK2LUmtVnl1AdOIGTpKtWfP2TA/o9fLxgGMh0BHqVZWxhsHMByBjlLNz483DmA4Ah2lWlqSZmfPH5udzcYBjIdAR6laLandlhYWJDtbttt8IApsBF0uKF2rRYADRRg5Q7e9y/ZTto/YfsH2XQP2se2/sv2S7Wdtf3Qy5QIAhskzQz8l6Z6IOGT7vZIO2t4fEUfO2edXJH2w//gFSV/sLwEAm2TkDD0ijkfEof7625KOStq5ZrdbJX0lMt+VdKntHYVXCwAYaqwPRW03JO2WdGDNUzslvXzO9iv6/6Ev24u2u7a7q6urY5YKAFhP7kC3fYmkRyXdHRFvbeTNIqIdEc2IaM7NzW3kJQAAQ+QKdNvblIV5JyIeG7DLq5J2nbN9WX8MALBJ8nS5WNLDko5GxP1Ddntc0m/3u12ulnQyIo4XWCcAYIQ8XS7XSrpD0nO2D/fHPiNpXpIi4kFJ35D0MUkvSepJ+p3CKwUArGtkoEfE05I8Yp+Q9KmiigIAjI9L/wEgEQQ6ACSCQAeARBDoAJAIAh0AEkGgA0AiCHQASASBDgCJINABIBEEOgAkgkAHgEQQ6ACQCAIdABJBoANAIgh0AEgEgQ4AiSDQASARBDoAJIJAL1CnIzUa0sxMtux0yq4Im42fAZQpz5dEI4dOR1pclHq9bHt5OduWpFarvLqwefgZQNmcfb/z5ms2m9Htdkt570loNLJf4LUWFqRjxza7GpSBnwFsBtsHI6I56DlOuRRkZWW8caSHnwGUjUAvyPz8eON1U4dzy/wMoGwEekGWlqTZ2fPHZmez8bo7c255eVmKOHtuObVQ52cAZSPQC9JqSe12dr7UzpbtNh+GSdKePWc/KDyj18vGU8LPAMrGh6KYuJmZbGa+li2dPr359QDTjA9FUSrOLQObg0DHxHFuGdgcBDomjnPLwOYg0BNR9bbAViu7uOb06WxJmAPF49L/BHDJOQCJGXoS6tIWCGB9BHoCuOQcgESgJ4G2QAASgZ4E2gIBSDkC3faXbJ+w/fyQ56+zfdL24f7jvuLLxHpoCwQg5ety+bKkByR9ZZ19vhMRNxdSETak1SLAgbobOUOPiG9LemMTagEAXICizqFfY/sZ20/a/siwnWwv2u7a7q6urhb01gAAqZhAPyRpISKukPQFSV8ftmNEtCOiGRHNubm5At4aAHD
GBQd6RLwVEe/0178haZvt7RdcGQBgLBcc6LY/YNv99av6r/n6hb4uAGA8I7tcbD8i6TpJ222/IumPJW2TpIh4UNJtkj5p+5SkH0q6Pcr61gwAqLGRgR4Rvzni+QeUtTUCAErElaIAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSCQAeARBDoAJAIAn1MnY7UaEgzM9my0ym7IgDI5PnGIvR1OtLiotTrZdvLy9m2xLcFASgfM/Qx7NlzNszP6PWycQAoG4E+hpWV8cYBYDMR6GOYnx9vHAA2E4E+hqUlaXb2/LHZ2WwcAMpGoI+h1ZLabWlhQbKzZbvNB6IAqmGqAr0KLYOtlnTsmHT6dLYkzAFUxdS0LdIyCADrm5oZOi2DALC+qQl0WgYBYH1TE+i0DALA+qYm0GkZBID1TU2g0zIIAOubmi4XKQtvAhwABpuaGToAYH0EOgAkgkAHgEQQ6ACQCAIdABLhiCjnje1VScs5dt0u6bUJlzONOC7DcWwG47gMN03HZiEi5gY9UVqg52W7GxHNsuuoGo7LcBybwTguw6VybDjlAgCJINABIBHTEOjtsguoKI7LcBybwTguwyVxbCp/Dh0AkM80zNABADkQ6ACQiEoGuu1dtp+yfcT2C7bvKrumKrG9xfa/2n6i7FqqxPaltvfZftH2UdvXlF1TVdj+g/7v0vO2H7F9Udk1lcX2l2yfsP38OWM/YXu/7e/3lz9eZo0bVclAl3RK0j0R8WFJV0v6lO0Pl1xTldwl6WjZRVTQX0r6x4j4WUlXiGMkSbK9U9KnJTUj4nJJWyTdXm5VpfqypJvWjP2RpG9GxAclfbO/PXUqGegRcTwiDvXX31b2i7mz3KqqwfZlkn5V0kNl11Iltt8n6ZckPSxJEfE/EfFmqUVVy1ZJF9veKmlW0r+XXE9pIuLbkt5YM3yrpL399b2Sfm0zaypKJQP9XLYbknZLOlByKVXxF5L+UNLpkuuomp+WtCrpr/unox6y/Z6yi6qCiHhV0p9JWpF0XNLJiPjncquqnPdHxPH++g8kvb/MYjaq0oFu+xJJj0q6OyLeKruestm+WdKJiDhYdi0VtFXSRyV9MSJ2S/ovTemfzUXrnw++Vdl/ej8l6T22f6vcqqorsl7uqeznrmyg296mLMw7EfFY2fVUxLWSbrF9TNLfSbre9t+WW1JlvCLplYg485fcPmUBD+mXJf1bRKxGxP9KekzSL5ZcU9X8h+0dktRfnii5ng2pZKDbtrJzoUcj4v6y66mKiLg3Ii6LiIayD7W+FRHMtCRFxA8kvWz7Q/2hGyQdKbGkKlmRdLXt2f7v1g3iA+O1Hpd0Z3/9Tkl/X2ItG1bJQFc2E71D2Qz0cP/xsbKLQuX9vqSO7WclXSnpT8otpxr6f7Xsk3RI0nPKfu+TuNR9I2w/IulfJH3I9iu2PyHp85JutP19ZX/RfL7MGjeKS/8BIBFVnaEDAMZEoANAIgh0AEgEgQ4AiSDQASARBDoAJIJAB4BE/B/WmKZIJX5BAgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出图像\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.plot(x_train, y_train, 'bo')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([2.2691], requires_grad=True)\n" ] } ], "source": [ "# 转换成 Tensor\n", "x_train = torch.from_numpy(x_train)\n", "y_train = torch.from_numpy(y_train)\n", "\n", "# 定义参数 w 和 b\n", "w = Variable(torch.randn(1), requires_grad=True) # 随机初始化\n", "b = Variable(torch.zeros(1), requires_grad=True) # 使用 0 进行初始化\n", "print(w)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# 构建线性回归模型\n", "x_train = Variable(x_train)\n", "y_train = Variable(y_train)\n", "\n", "def linear_model(x):\n", " return x * w + b\n", "\n", "def logistc_regression(x):\n", " return torch.sigmoid(x*w+b) " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "y_ = linear_model(x_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "经过上面的步骤我们就定义好了模型,在进行参数更新之前,我们可以先看看模型的输出结果长什么样" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWJUlEQVR4nO3df4zU9Z3H8dd7cSuuEM/ihqB0d2nTcCI/VlgNnlfKCQJXTYWYNGf2FJI22Fo82rRe9PhDE91rc2nlzv6hbpVTy9arxR81PdJikYY2paW7HloLhk28XVxEWdGj/Ayw+74/ZnaBdWZndma+8/1+Zp6PZDKz3xlm3vPVec1nPt/P9/MxdxcAIDw1cRcAACgMAQ4AgSLAASBQBDgABIoAB4BAXVDOF7vsssu8qampnC8JAMHr6ur6wN3rR24va4A3NTWps7OznC8JAMEzs95M2+lCAYBA5QxwM/uUmW0zs91m9mczW5ve/oCZ7TezXenLF6IvFwAwJJ8ulDOSvuXur5nZREldZvZK+r717v696MoDAGSTM8Dd/YCkA+nbR8xsj6QrSlXA6dOn1dfXp5MnT5bqKave+PHjNXXqVNXW1sZdCoAIjekgppk1Sbpa0h8kXS9pjZndIalTqVb6Rxn+zWpJqyWpoaHhY8/Z19eniRMnqqmpSWY25jeA87m7Dh06pL6+Pk2bNi3ucgBEKO+DmGY2QdLzkr7h7n+R9Kikz0hqVqqF/v1M/87d2929xd1b6us/NgpGJ0+e1KRJkwjvEjEzTZo0iV80QFJ0dEhNTVJNTeq6o6NkT51XC9zMapUK7w53f0GS3P39c+7/oaSfF1oE4V1a7E8gITo6pNWrpePHU3/39qb+lqTW1qKfPp9RKCbpSUl73P3hc7ZPOedhKyS9WXQ1AFBJ1q07G95Djh9PbS+BfLpQrpd0u6QbRgwZ/Dcz+5OZvSHp7yR9syQVBaipqUkffPBB3GUASJp9+8a2fYxyBri7/9bdzd1nu3tz+rLZ3W9391np7V9Mj1aJXITdSZJSBwEHBwdL+6QAqlOGgRujbh+joM7EHOpO6u2V3M92JxUb4j09PZo+fbruuOMOzZw5Uw8++KCuueYazZ49W/fff//w45YvX6558+bpqquuUnt7e5HvBkDFa2uT6urO31ZXl9peAkEFeJTdSd3d3brrrru0fv167d+/Xzt37tSuXbvU1dWl7du3S5I2bNigrq4udXZ26pFHHtGhQ4eKf2EAlau1VWpvlxobJbPUdXt7SQ5gSmWezKpYUXYnNTY2av78+fr2t7+tLVu26Oqrr5YkHT16VN3d3VqwYIEeeeQRvfjii5Kkd955R93d3Zo0aVLxLw6gcrW2liywRwoqwBsaUt0mmbYX6+KLL5aU6gO/7777dOedd553/69//Wv96le/0o4dO1RXV6eFCxcy1hpArILqQom4O0mStHTpUm3YsEFHjx6VJO3fv18HDx7U4cOHdemll6qurk5vvfWWfv/735fuRQGgAEG1wId+haxbl+o2aWhIhXcpf50sWbJEe/bs0XXXXSdJmjBhgjZu3Khly5bpscce05VXXqnp06dr/vz5pXtRACiAuXvZXqylpcVHLuiwZ88eXXnllWWroVqwX4HKYWZd7t4ycntQXSgAgLMIcAAIFAEOAIEiwAEgUAQ4AASKAAeAQBHgY/DUU0/p3XffHf77K1/5inbv3l308/b09OjHP/7xmP/dqlWrtGnTpqJfH0CYwgvwqOeTHcXIAH/iiSc0Y8aMop+30AAHUN3CCvCI5pPduHGjrr32WjU3N+vOO+/UwMCAVq1apZkzZ2rWrFlav369Nm3apM7OTrW2tqq5uVknTpzQwoULNXRi0oQJE3TPPffoqquu0uLFi7Vz504tXLhQn/70p/Xyyy9LSgX15z73Oc2dO1dz587V7373O0nSvffeq9/85jdqbm7W+vXrNTAwoHvuuWd4StvHH39cUmqeljVr1mj69OlavHi
xDh48WNT7BipSjI28snP3sl3mzZvnI+3evftj27JqbHRPRff5l8bG/J8jw+vffPPNfurUKXd3/9rXvuYPPPCAL168ePgxH330kbu7f/7zn/c//vGPw9vP/VuSb9682d3dly9f7jfeeKOfOnXKd+3a5XPmzHF392PHjvmJEyfc3X3v3r0+tD+2bdvmN9100/DzPv744/7ggw+6u/vJkyd93rx5/vbbb/vzzz/vixcv9jNnzvj+/fv9kksu8Z/+9KdZ3xdQdTZudK+rOz8f6upS2wMmqdMzZGpQc6FEMZ/s1q1b1dXVpWuuuUaSdOLECS1btkxvv/227r77bt10001asmRJzuf5xCc+oWXLlkmSZs2apQsvvFC1tbWaNWuWenp6JEmnT5/WmjVrtGvXLo0bN0579+7N+FxbtmzRG2+8Mdy/ffjwYXV3d2v79u267bbbNG7cOF1++eW64YYbCn7fQEUabdGAiKZ0jVNYAR7BfLLurpUrV+o73/nOedvb2tr0y1/+Uo899piee+45bdiwYdTnqa2tHV4NvqamRhdeeOHw7TNnzkiS1q9fr8mTJ+v111/X4OCgxo8fn7WmH/zgB1q6dOl52zdv3lzQewSqRsRrUCZNWH3gEcwnu2jRIm3atGm4P/nDDz9Ub2+vBgcHdeutt+qhhx7Sa6+9JkmaOHGijhw5UvBrHT58WFOmTFFNTY1+9KMfaWBgIOPzLl26VI8++qhOnz4tSdq7d6+OHTumBQsW6Cc/+YkGBgZ04MABbdu2reBagIoU8RqUSRNWCzyC+WRnzJihhx56SEuWLNHg4KBqa2v18MMPa8WKFcOLGw+1zletWqWvfvWruuiii7Rjx44xv9Zdd92lW2+9Vc8884yWLVs2vIjE7NmzNW7cOM2ZM0erVq3S2rVr1dPTo7lz58rdVV9fr5deekkrVqzQq6++qhkzZqihoWF4ylsAaW1tqYEN53ajlHrRgARhOtkKxX5F1eroiHbRgBhkm042rBY4AOQS4RqUSRNWHzgAYFgiAryc3TjVgP0JVIfYA3z8+PE6dOgQoVMi7q5Dhw5lHaIIoHLE3gc+depU9fX1qb+/P+5SKsb48eM1derUuMsAELHYA7y2tlbTpk2LuwwACE7sXSgAgMIQ4AAQKAIcAAJFgANAoAhwAAhUzgA3s0+Z2TYz221mfzaztentnzSzV8ysO319afTlAgCG5NMCPyPpW+4+Q9J8SV83sxmS7pW01d0/K2lr+m8AQJnkDHB3P+Dur6VvH5G0R9IVkm6R9HT6YU9LWh5RjQCADMbUB25mTZKulvQHSZPd/UD6rvckTc7yb1abWaeZdXK2JQCUTt4BbmYTJD0v6Rvu/pdz70svuplxMhN3b3f3Fndvqa+vL6pYAMBZeQW4mdUqFd4d7v5CevP7ZjYlff8USQejKREAkEk+o1BM0pOS9rj7w+fc9bKklenbKyX9rPTlAQCyyWcyq+sl3S7pT2a2K73tXyR9V9JzZvZlSb2SvhRJhQCAjHIGuLv/VpJluXtRacsBAOSLMzEBIFAEOAAEigAHgEAR4AAQKAIcAAJFgANAoAhwAAgUAQ5Uoo4OqalJqqlJXXd0xF0RIpDPmZgAQtLRIa1eLR0/nvq7tzf1tyS1tsZXF0qOFjhQadatOxveQ44fT21HRSHAgUqzb9/YtiNYBDhQaRoaxrYdwSLAgUrT1ibV1Z2/ra4utR0VhQAHKk1rq9TeLjU2Smap6/Z2DmBWIEahAJWotZXArgK0wIEkYzw3RkELHEgqxnMjB1rgQFIxnhs5EOBAUjGeGzkQ4EBSMZ4bORDgQFIxnhs5EOBAUjGeGzkwCgVIMsZzYxS0wAEgUAQ4AASKAAeAQBHgQClx6jvKiIOYQKlw6jvKjBY4UIxzW9wrV3LqO8qKFjhQqJEt7oGBzI/j1HdEhBY4UKhMk01lwqnviAgBDhQqn5Y1p74jQgQ4kEu2kSXZWtbjxnHqO8qCPnBgNKONLGl
rO/8+KdXiJrRRJjlb4Ga2wcwOmtmb52x7wMz2m9mu9OUL0ZYJxGS0RRWYbAoxM3cf/QFmCyQdlfSMu89Mb3tA0lF3/95YXqylpcU7OzsLLBWIQU2NlOkzYiYNDpa/HlQlM+ty95aR23O2wN19u6QPI6kKSDoWVUCCFXMQc42ZvZHuYrk024PMbLWZdZpZZ39/fxEvB8SARRWQYIUG+KOSPiOpWdIBSd/P9kB3b3f3Fndvqa+vL/DlgJjQz40EK2gUiru/P3TbzH4o6eclqwhIGhZVQEIV1AI3synn/LlC0pvZHgsAiEbOFriZPStpoaTLzKxP0v2SFppZsySX1CPpzuhKBABkks8olNvcfYq717r7VHd/0t1vd/dZ7j7b3b/o7gfKUSyQN+blRhXgTExUHublRpVgLhRUntHOngQqCAGOypNtlkDm5UaFIcBReTh7ElWCAEfl4exJVAkCHJWHsydRJRiFgsrE2ZOoArTAASBQBDgABIoAB4BAEeAAECgCHAACRYAjHkw2BRSNYYQoPyabAkqCFjjKj8mmgJIgwFF+TDYFlAQBjuhk6+dmsimgJOgDRzRG6+duazv/PonJpoACEOCIxmj93D09Zx+zb1+q5d3WxgFMYIzM3cv2Yi0tLd7Z2Vm210OMamqkTP9vmUmDg+WvBwiYmXW5e8vI7fSBIxr0cwORI8ARDRZVACJHgCMaLKoARI6DmIgOiyoAkaIFDgCBIsABIFAEOAAEigAHgEAR4AAQKAIcAAJFgANAoAhwAAgUAQ4AgcoZ4Ga2wcwOmtmb52z7pJm9Ymbd6etLoy0TBWHhYKCi5dMCf0rSshHb7pW01d0/K2lr+m8kydCCCr29qWldhxZUIMSBipEzwN19u6QPR2y+RdLT6dtPS1pe2rJQNBYOBipeoX3gk939QPr2e5ImZ3ugma02s04z6+zv7y/w5TBmLBwMVLyiD2J6akmfrMv6uHu7u7e4e0t9fX2xL4d8saACUPEKDfD3zWyKJKWvD5auJORttIOULKgAVLxCA/xlSSvTt1dK+llpykHech2kZEEFoOLlXNTYzJ6VtFDSZZLel3S/pJckPSepQVKvpC+5+8gDnR/DosYl1NSUCu2RGhvPrvoOoCJkW9Q454o87n5blrsWFV0VCsdBSqDqcSZmqDhICVQ9AjxUHKQEqh4BHioOUgJVj1XpQ8aq70BVowUOAIEiwAEgUAQ4AASKAAeAQBHgABAoAhwAAkWAA0CgCHAACBQBDgCBIsCLxcrvAGLCqfTFGFpUYWjx4KFFFSROcQcQOVrgxWDldwAxIsCLwaIKAGJEgOcjWz83iyoAiBF94LmM1s/d1nb+fRKLKgAoGwI8l9H6uYcWD163LtVt0tCQCm8OYAIog5yr0pdSkKvS19RImfaRmTQ4WP56AFSdbKvS0weeC/3cABKKAM+FxYMBJBQBnguLBwNIKA5i5oPFgwEkEC1wAAgUAQ4AgSLAASBQBDgABIoAB4BAEeAAECgCHAACRYADQKAIcAAIVFFnYppZj6QjkgYknck0WxYAIBqlaIH/nbs3RxberPoOABkley4UVn0HgKyKbYG7pC1m1mVmqzM9wMxWm1mnmXX29/eP7dlZ9R0Asio2wP/W3edK+ntJXzezBSMf4O7t7t7i7i319fVje3ZWfQeArIoKcHffn74+KOlFSdeWoqhhrIYDAFkVHOBmdrGZTRy6LWmJpDdLVZgkVsMBgFEU0wKfLOm3Zva6pJ2S/tvdf1GastJYDQcAsmJVegBIOFalB4AKQ4ADQKAIcAAIFAEOAIEiwAEgQlFO50SAA6goSZr/bmg6p95eyf3sdE6lqokAB3JIUiBgdFEH5lhFPZ0T48CBUYycEFNKnQzM+WTJ1NSUCu2RGhulnp5yV5P60s8UsWbS4GD+z8M4cKAATIgZlqTNfxf1dE4EODCKsQZCUrpbklJHuSVt/ruop3MiwCtYtX6IS2ksgZCU/tek1BGHpM1/F/l
0Tu5etsu8efMc5bFxo3tdnXvqI5y61NWltleDjRvdGxvdzVLXhb7vsezHxsbzHzd0aWws/H0UIil1xKVU/+2TRFKnZ8hUArxCVfOHuNRfXvkGglnmfW5W6DspLIyiqAPxyhbgjEKpUKU6+h2iuEYilPp1Cx0Bk7SRGCgeo1CqTNIO5pRTXCMRSt3/WugImKT1AyM6BHiFquYPcVxfXqU+YFXoFxHroFQPArxCJf1DHOUImTi/vFpbU90Ug4Op62L2dzFfRKWsA8lFgFewpH6Iox7mlvQvr3xV868o5IeDmCg7DrLlr6Mj1ee9b1+q5d3WFt4XEYrHQcyIcLJMdtn2TdJOdy5G1P/9k/orCslAgBch7jPekvzlMdq+qZQRMnH/9wc4kacIcZ4sk/QzLUfbN0mvPV/VfLIUyktZTuShBZ6HJHYFJH2WvNH2TaUcZKykriCEKfEBHnc3QVK7ApIeHrn2TSX07VZKVxDClegAT0If42gt3TiHeSU9PKphCFw1vEckXKZ+laguY+0DT0IfY66JgeKa+SyEfuRKnBVupGp4j4ifQpzMKgkTMiV5zDJjhIHqEOQ48CR0EyT5Z3Il9CMDKFyiAzwJ4VkpIyYAVJ4L4i5gNEMhGXc3QWsrgQ0geRId4BLhCQDZJLoLBQCQHQEOAIEiwAEgUAQ4AASKAAeAQJX1TEwz65eU4bzG81wm6YMylBMi9k1m7Jfs2DfZhbRvGt29fuTGsgZ4PsysM9Mpo2DfZMN+yY59k10l7Bu6UAAgUAQ4AAQqiQHeHncBCca+yYz9kh37Jrvg903i+sABAPlJYgscAJAHAhwAApWYADezT5nZNjPbbWZ/NrO1cdeUJGY2zsz+x8x+HnctSWJmf2Vmm8zsLTPbY2bXxV1TUpjZN9OfpTfN7FkzGx93TXEwsw1mdtDM3jxn2yfN7BUz605fXxpnjYVKTIBLOiPpW+4+Q9J8SV83sxkx15QkayXtibuIBPoPSb9w97+WNEfsI0mSmV0h6Z8ktbj7TEnjJP1DvFXF5ilJy0Zsu1fSVnf/rKSt6b+Dk5gAd/cD7v5a+vYRpT6IV8RbVTKY2VRJN0l6Iu5aksTMLpG0QNKTkuTup9z9/2ItKlkukHSRmV0gqU7SuzHXEwt33y7pwxGbb5H0dPr205KWl7OmUklMgJ/LzJokXS3pDzGXkhT/LumfJZVpKedgTJPUL+k/091LT5jZxXEXlQTuvl/S9yTtk3RA0mF33xJvVYky2d0PpG+/J2lynMUUKnEBbmYTJD0v6Rvu/pe464mbmd0s6aC7d8VdSwJdIGmupEfd/WpJxxToT+FSS/fp3qLUl9zlki42s3+Mt6pk8tRY6iDHUycqwM2sVqnw7nD3F+KuJyGul/RFM+uR9F+SbjCzjfGWlBh9kvrcfeiX2ialAh3SYkn/6+797n5a0guS/ibmmpLkfTObIknp64Mx11OQxAS4mZlSfZl73P3huOtJCne/z92nunuTUgehXnV3WlKS3P09Se+Y2fT0pkWSdsdYUpLskzTfzOrSn61F4gDvuV6WtDJ9e6Wkn8VYS8ESE+BKtTRvV6qFuSt9+ULcRSHx7pbUYWZvSGqW9K/xlpMM6V8lmyS9JulPSn3Wgz91vBBm9qykHZKmm1mfmX1Z0ncl3Whm3Ur9WvlunDUWilPpASBQSWqBAwDGgAAHgEAR4AAQKAIcAAJFgANAoAhwAAgUAQ4Agfp/1cKknX7Ge+oAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**思考:红色的点表示预测值,似乎排列成一条直线,请思考一下这些点是否在一条直线上?**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这个时候需要计算我们的误差函数,也就是\n", "\n", "$$\n", "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", "$$" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# 计算误差\n", "def get_loss(y_, y):\n", " return torch.mean((y_ - y) ** 2)\n", "\n", "loss = get_loss(y_, y_train)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(153.3520, grad_fn=)\n" ] } ], "source": [ "# 打印一下看看 loss 的大小\n", "print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "定义好了误差函数,接下来我们需要计算 w 和 b 的梯度了,这时得益于 PyTorch 的自动求导,我们不需要手动去算梯度,有兴趣的同学可以手动计算一下,w 和 b 的梯度分别是\n", "\n", "$$\n", "\\frac{\\partial}{\\partial w} = \\frac{2}{n} \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n", "\\frac{\\partial}{\\partial b} = \\frac{2}{n} \\sum_{i=1}^n (w x_i + b - y_i)\n", "$$" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# 自动求导\n", "loss.backward()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([161.0043])\n", "tensor([22.8730])\n" ] } ], "source": [ "# 查看 w 和 b 的梯度\n", "print(w.grad)\n", "print(b.grad)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# 更新一次参数\n", "w.data = w.data - 1e-2 * w.grad.data\n", "b.data = b.data - 1e-2 * b.grad.data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "更新完成参数之后,我们再一次看看模型输出的结果" ] }, { "cell_type": "code", 
"execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAV/klEQVR4nO3df3BV5Z3H8c83IYoB1lrMOFqaXDqzw8oP+RUdXLeURUQqTovj/lEmbaW7TmwtLt22dnT4w+6otX/sSLU7o2YptS2x7YradTpuh6q4tFOU3tBgW1CYxQSDtITYUvm1/Mh3/7g3AeK93Jvknnuee8/7NZPJvedezv3mZPLhOc95nueYuwsAEK6auAsAAJwfQQ0AgSOoASBwBDUABI6gBoDAjYlip5deeqmnUqkodg0AVamjo+Oguzfkei2SoE6lUkqn01HsGgCqkpl153uNrg8ACFzBoDazKWbWedbXX8zsS2WoDQCgIro+3P1NSbMkycxqJe2T9Fy0ZQEABgy3j/p6Sf/r7nn7UvI5efKkenp6dPz48eH+U+QxduxYTZo0SXV1dXGXAiBCww3qT0n6Ya4XzKxVUqskNTY2vu/1np4eTZgwQalUSmY23DoxhLurr69PPT09mjx5ctzlAIhQ0RcTzewCSZ+Q9HSu1929zd2b3b25oeH9I0yOHz+uiRMnEtIlYmaaOHEiZyhACNrbpVRKqqnJfG9vL+nuh9Oi/rikbe7+x5F+GCFdWhxPIADt7VJrq3T0aOZ5d3fmuSS1tJTkI4YzPG+58nR7AEBirV59JqQHHD2a2V4iRQW1mY2TdIOkZ0v2yRUmlUrp4MGDcZcBIDR79w5v+wgUFdTufsTdJ7r7oZJ9cgFRdvm4u/r7+0u3QwDJlWPwxHm3j0CQMxMHuny6uyX3M10+ownrrq4uTZkyRZ/97Gc1ffp03X///br66qt11VVX6b777ht837JlyzR37lxNmzZNbW1tJfhpAFS1Bx+U6uvP3VZfn9leIkEGdVRdPrt379add96pNWvWaN++fdq6das6OzvV0dGhzZs3S5LWrVunjo4OpdNpPfroo+rr6xvdhwKobi0tUlub1NQkmWW+t7WV7EKiFNGiTKMVVZdPU1OT5s2bp69+9avauHGjZs+eLUk6fPiwdu/erfnz5+vRRx/Vc89lJl6+/fbb2r17tyZOnDi6DwZQ3VpaShrMQwUZ1I2Nme6OXNtHY9y4cZIyfdT33nuv7rjjjnNef+WVV/Tiiy9qy5Ytqq+v14IFCxinDCB2QXZ9RN3lc+ONN2rdunU6fPiwJGnfvn06cOCADh06pEsuuUT19fV644039Oqrr5bmAwFgFIJsUQ+cQaxenenuaGzMhHSpziwWL16snTt36tprr5UkjR8/XuvXr9eSJUv0+OOP68orr9SUKVM0b9680nwgAIyCuXvJd9rc3OxDbxywc+dOXXnllSX/rKTjuALVwcw63L0512tBdn0AAM4gqAEgcAQ1AASOoAaAwBHUABA4ghoAAkdQ5/Dkk0/qnXfeGXx+++23a8eOHaPeb1dXl5566qlh/7sVK1Zow4YNo/58AJUp3KCO+NY25zM0qNeuXaupU6eOer8jDWoAyRZmUEexzqmk9evX65prrtGsWbN0xx136PTp01qxYoWmT5+uGTNmaM2aNdqwYYPS6bRaWlo0a9YsHTt2TAsWLNDABJ7x48fr7rvv1rRp07Ro0SJt3bpVCxYs0Ec+8hE9//zzkjKB/NGPflRz5szRnDlz9Ktf/UqSdM899+gXv/iFZs2apTVr1uj06dO6++67B5db
feKJJyRl1iJZuXKlpkyZokWLFunAgQOj+rkBVDh3L/nX3LlzfagdO3a8b1teTU3umYg+96upqfh95Pj8m2++2U+cOOHu7l/4whf861//ui9atGjwPX/605/c3f1jH/uY//rXvx7cfvZzSf7CCy+4u/uyZcv8hhtu8BMnTnhnZ6fPnDnT3d2PHDnix44dc3f3Xbt2+cDx2LRpky9dunRwv0888YTff//97u5+/Phxnzt3ru/Zs8efeeYZX7RokZ86dcr37dvnF198sT/99NN5fy4AlU9S2vNkapBrfUSxzulLL72kjo4OXX311ZKkY8eOacmSJdqzZ4/uuusuLV26VIsXLy64nwsuuEBLliyRJM2YMUMXXnih6urqNGPGDHV1dUmSTp48qZUrV6qzs1O1tbXatWtXzn1t3LhRr7/++mD/86FDh7R7925t3rxZy5cvV21tra644gotXLhwxD83gMoXZtdHBLe2cXfddttt6uzsVGdnp95880098sgj2r59uxYsWKDHH39ct99+e8H91NXVDd79u6amRhdeeOHg41OnTkmS1qxZo8suu0zbt29XOp3WiRMn8tb07W9/e7Cmt956q6j/LIDEi/EaVhzCDOoI1jm9/vrrtWHDhsH+3nfffVfd3d3q7+/XrbfeqgceeEDbtm2TJE2YMEHvvffeiD/r0KFDuvzyy1VTU6Mf/OAHOn36dM793njjjXrsscd08uRJSdKuXbt05MgRzZ8/Xz/+8Y91+vRp7d+/X5s2bRpxLUDViegaVsjC7PqIYJ3TqVOn6oEHHtDixYvV39+vuro6Pfzww7rlllsGb3T70EMPScoMh/v85z+viy66SFu2bBn2Z91555269dZb9f3vf19LliwZvGHBVVddpdraWs2cOVMrVqzQqlWr1NXVpTlz5sjd1dDQoJ/85Ce65ZZb9PLLL2vq1KlqbGwcXI4VgM5/r74I77ISJ5Y5rXAcVyROTU2mJT2UmZRtdFUiljkFUD0iuIYVOoIaQGWJ+l59ASprUEfRzZJkHE8kUkuL1NYmNTVlujuamjLPq7R/WirjxcSxY8eqr69PEydOHBzehpFzd/X19Wns2LFxlwKUX0tLVQfzUEUFtZl9QNJaSdMluaR/dPdhDYeYNGmSenp61NvbO+wikdvYsWM1adKkuMsAELFiW9SPSPqZu/+DmV0gqb7QPxiqrq5OkydPHu4/A4DEKxjUZnaxpPmSVkiSu5+QlHuqHQCg5Iq5mDhZUq+k75rZb8xsrZmNG/omM2s1s7SZpeneAIDSKSaox0iaI+kxd58t6Yike4a+yd3b3L3Z3ZsbGhpKXCYAJFcxQd0jqcfdX8s+36BMcAMAyqBgULv7HyS9bWZTspuulzT6+1IBAIpS7KiPuyS1Z0d87JH0uehKAgCcraigdvdOSTkXCwEARIu1PgAgcAQ1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIEjqIFK1t4upVJSTU3me3t73BUhAmPiLgDACLW3S62t0tGjmefd3ZnnktTSEl9dKLmiWtRm1mVmvzWzTjNLR10UgCKsXn0mpAccPZrZjqoynBb137v7wcgqATA8e/cObzsqFn3UQKVqbBzedlSsYoPaJW00sw4za831BjNrNbO0maV7e3tLVyGA3B58UKqvP3dbfX1mO6pKsUH9d+4+R9LHJX3RzOYPfYO7t7l7s7s3NzQ0lLRIADm0tEhtbVJTk2SW+d7WxoXEKlRUH7W778t+P2Bmz0m6RtLmKAsDUISWFoI5AQq2qM1snJlNGHgsabGk30VdGAAgo5gW9WWSnjOzgfc/5e4/i7QqAMCggi1qd9/j7jOzX9PcnSsVQCkxuxAFMDMRiBOzC1EExlEDcWJ2IYpAUANxYnYhikBQ
A3FidiGKQFADcWJ2IYpAUANxYnYhisCoDyBuzC5EAbSoASBwBDUABI6gBoDAEdQAEDiCGhgJ1udAGTHqAxgu1udAmdGiBoaL9TlQZgQ1UIyzuzq6u3O/h/U5EBG6PoBChnZ15MP6HIgILWqgkFxdHUOxPgciRFADhZyvS4P1OVAGdH0AhTQ25u6XbmqSurrKXg6ShxY1IJ1/XDRLkSJmBDUwcLGwu1tyPzMueiCsWYoUMTN3L/lOm5ubPZ1Ol3y/QCRSKbo2EDsz63D35lyv0aIGuG8hAkdQA9y3EIEjqAEuFiJwRQe1mdWa2W/M7KdRFgSUHRcLEbjhjKNeJWmnpL+KqBYgPty3EAErqkVtZpMkLZW0NtpyAABDFdv18S1JX5PUH10pAIBcCga1md0s6YC7dxR4X6uZpc0s3dvbW7ICASDpimlRXyfpE2bWJelHkhaa2fqhb3L3NndvdvfmhoaGEpcJAMlVMKjd/V53n+TuKUmfkvSyu3868sqAQrhvIRKC1fNQmbhvIRKEtT5QmVifA1WGtT5QfVifAwlCUKMysT4HEoSgRmVifQ4kCEGNysT6HEgQRn2gcrE+BxKCFjUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAYpajXB2N4HgCMQjnWB6NFDQCjsHr1mZAecPRoZnupENSIDutFIwHKsT4YQY1oDJwPdndL7mfOBwlrlEBIbYByrA9GUCMa5TgfRCKF1gYox/pgBDWiwXrRiEhobYByrA9GUGP0cp2Hsl40IhJiG6ClJXNjof7+zPdSrxVGUGN08p2H3nQT60UjEklsAxDUGJ1856EvvMB60YhEEu8ZQVBjdM53Hhr1+WAJhTSKAOeXxHtGMDMRo9PYmPtu4BV0HlqOmWUoraTdM4IWNUanCs5DQxtFAAxFUGN0quA8NMRRBMDZ6PrA6FX4eWgV9N6gytGiRuJVQe8NqlzBoDazsWa21cy2m9nvzexfy1EYUC5V0HuDKldM18f/SVro7ofNrE7SL83sv9391YhrA8qmwntvUOUKBrW7u6TD2ad12S+PsigAwBlF9VGbWa2ZdUo6IOnn7v5ajve0mlnazNK9vb0lLhMAkquooHb30+4+S9IkSdeY2fQc72lz92Z3b25oaChxmQByYUZlMgxr1Ie7/1nSJklLIqkGQNFCW5cZ0Slm1EeDmX0g+/giSTdIeiPiuoCghdCSZUZlchQz6uNySd8zs1plgv0/3f2n0ZYFhCuUtUGYUZkclhnUUVrNzc2eTqdLvl8gBKlU7pmMTU2ZRQKTVgdKw8w63L0512vMTKwCIZyGJ0koLVlmVCYHQV3huKBUfqHcYYQZlclB10eF4/S3/Ib2UUuZliwhidGg66OK7d0rLVe73lJKp1Wjt5TScrVzQSlCtGRRbixzWuFWfrBdD/W1apwyzbuUuvUfatWlH5QkkiMqrA2CcqJFXeG+odWDIT1gnI7qG6r+wbRcREVS0KKucOPfzd3HkW97tQhlLDNQDrSoK10oQxDKjFl5SBKCOnSFzu8TOpg2lLHMpUAXDgohqENWzCDphA5BqJYTCcbBoxiMow4Zg6TzqpaxzPyKMYBx1JWqms7vS6xaTiT4FaMYjPoIWWNj7uZWpZ3fR6QaxjLzK0YxaFGHLKEXCpOEXzGKQVCHrFrO75EXv2IUg4uJABAALiYCQAUjqAEgcAQ1AASOoAaAwBHUABA4ghoAAkdQA0DgCGpEhuU7gdJgrQ9EgjuwAKVDixqR4A4sQOkUDGoz+7CZbTKzHWb2ezNbVY7CUNlYvhMonWJa1KckfcXdp0qaJ+mLZjY12rJQ6arlDixACAoGtbvvd/dt2cfvSdop6UNRFxYcrowNC8t3AqUzrD5qM0tJmi3ptRyvtZpZ2szSvb29JSovENzYbthYvhMonaKXOTWz8ZL+R9KD7v7s+d5bdcuccmM7ABEb9TKnZlYn6RlJ7YVCuipxZQxAjIoZ9WGSviNpp7s/HH1JAeLK
GIAYFdOivk7SZyQtNLPO7NdNEdcVj3wXDLkyBiBGBWcmuvsvJVkZaolXMVPpVq/OdHc0NmZCmitjAMqAeyYO4IIhgBhxz8RicMEQQKAI6gEFLhgy3yXZ+P0jTgT1gPNcMGS+y/lVe4jx+0fs3L3kX3PnzvWKtH69e1OTu1nm+/r17p55mPkTPferqSnGWgOxfr17ff25x6W+fvDQVQV+/ygHSWnPk6lcTCxCTU3mT3MoM6m/v/z1hCQJ12D5/aMcuJg4SnHOdwm9WyEJ12CZ74S4EdRFiGu+SyX0jSYhxJjvhLgR1EWIayW4SrhLShJCjJUAETf6qANWKX2j7e1M2gRG63x91NzcNmCNjbkv1IXWrdDSQjADUaLrI2BJ6FYAUBhBHTD6RgFIdH0Ej24FALSoASBwBPVZQp9cAiCZ6PrIKua+AQAQB1rUWZUwuQRAMgUT1HF3OyRhzQoAlSmIoG5vl178XLte6U7plNfole6UXvxce1nDOglrVgCoTEEE9Wur2vXvJ1uVUrdq5EqpW/9+slWvrSpfUjO5BECoggjqL/et1jid20E8Tkf15b7ydRAzuQRAqIJYlKnfalSj99fRL1ONB7T6EABEJPgbBxydmLsjON92AEiSIIJ6/CMP6tQF53YQn7qgXuMfoYMYAIIIarW0aMy6czuIx6yjgxgApCJmJprZOkk3Szrg7tMjq4TVhwAgp2Ja1E9KWhJxHQCAPAoGtbtvlvRuGWoBAORQsj5qM2s1s7SZpXt7e0u1WwBIvJIFtbu3uXuzuzc3NDSUarcAkHhhjPoAAOQVyXrUHR0dB80sx/2zz3GppINRfH6F47jkx7HJj2OTWyUdl6Z8LxScQm5mP5S0QJkf+I+S7nP374y2IjNL55sumWQcl/w4NvlxbHKrluNSsEXt7svLUQgAIDf6qAEgcHEGdVuMnx0yjkt+HJv8ODa5VcVxiWSZUwBA6dD1AQCBI6gBIHBlDWoz+7CZbTKzHWb2ezNbVc7PrwRmVmtmvzGzn8ZdS0jM7ANmtsHM3jCznWZ2bdw1hcDM/iX7t/Q7M/uhmY2Nu6a4mNk6MztgZr87a9sHzeznZrY7+/2SOGscqXK3qE9J+oq7T5U0T9IXzWxqmWsI3SpJO+MuIkCPSPqZu/+NpJniGMnMPiTpnyU1Z5cgrpX0qXiritWTev9Kn/dIesnd/1rSS9nnFaesQe3u+919W/bxe8r8sX2onDWEzMwmSVoqaW3ctYTEzC6WNF/SdyTJ3U+4+59jLSocYyRdZGZjJNVLeifmemKTZ6XPT0r6Xvbx9yQtK2dNpRJbH7WZpSTNlvRaXDUE6FuSviaJO/qea7KkXknfzXYLrTWzcXEXFTd33yfp3yTtlbRf0iF33xhvVcG5zN33Zx//QdJlcRYzUrEEtZmNl/SMpC+5+1/iqCE0ZjZwF52OuGsJ0BhJcyQ95u6zJR1RhZ7CllK2v/WTyvxHdoWkcWb26XirCpdnxiJX5Hjksge1mdUpE9Lt7v5suT8/YNdJ+oSZdUn6kaSFZrY+3pKC0SOpx90Hzr42KBPcSbdI0lvu3uvuJyU9K+lvY64pNH80s8slKfv9QMz1jEi5R32YMv2MO9394XJ+dujc/V53n+TuKWUuCL3s7rSOJLn7HyS9bWZTspuul7QjxpJCsVfSPDOrz/5tXS8usg71vKTbso9vk/RfMdYyYuVuUV8n6TPKtBY7s183lbkGVKa7JLWb2euSZkn6RrzlxC97hrFB0jZJv1Xm77kqpkyPRHalzy2SpphZj5n9k6RvSrrBzHYrcwbyzThrHCmmkANA4JiZCACBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4P4fJOC0kP28eAoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_ = linear_model(x_train)\n", "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的例子可以看到,更新之后红色的线跑到了蓝色的线下面,没有特别好的拟合蓝色的真实值,所以我们需要在进行几次更新" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(3.1358)\n", "epoch: 0, loss: 3.135772228240967\n", "tensor(0.3551)\n", "epoch: 1, loss: 0.355089008808136\n", "tensor(0.3030)\n", "epoch: 2, loss: 0.30295446515083313\n", "tensor(0.3013)\n", "epoch: 3, loss: 0.30131959915161133\n", "tensor(0.3006)\n", "epoch: 4, loss: 0.3006228804588318\n", "tensor(0.2999)\n", "epoch: 5, loss: 0.2999469041824341\n", "tensor(0.2993)\n", "epoch: 6, loss: 0.299274742603302\n", "tensor(0.2986)\n", "epoch: 7, loss: 0.2986060082912445\n", "tensor(0.2979)\n", "epoch: 8, loss: 0.2979407012462616\n", "tensor(0.2973)\n", "epoch: 9, loss: 0.29727888107299805\n" ] } ], "source": [ "for e in range(10): # 进行 10 次更新\n", " y_ = linear_model(x_train)\n", " loss = get_loss(y_, y_train)\n", " \n", " w.grad.zero_() # 记得归零梯度\n", " b.grad.zero_() # 记得归零梯度\n", " loss.backward()\n", " print(loss.data)\n", " w.data = w.data - 1e-2 * w.grad.data # 更新 w\n", " b.data = b.data - 1e-2 * b.grad.data # 更新 b \n", " print('epoch: {}, loss: {}'.format(e, loss.item()))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_ = linear_model(x_train)\n", "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "经过 10 次更新,我们发现红色的预测结果已经比较好的拟合了蓝色的真实值。\n", "\n", "现在你已经学会了你的第一个机器学习模型了,再接再厉,完成下面的小练习。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**小练习:**\n", "\n", "重启 notebook 运行上面的线性回归模型,但是改变训练次数以及不同的学习率进行尝试得到不同的结果" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 多项式回归模型" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "下面我们更进一步,讲一讲多项式回归。什么是多项式回归呢?非常简单,根据上面的线性回归模型\n", "\n", "$$\n", "\\hat{y} = w x + b\n", "$$\n", "\n", "这里是关于 x 的一个一次多项式,这个模型比较简单,没有办法拟合比较复杂的模型,所以我们可以使用更高次的模型,比如\n", "\n", "$$\n", "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 + \\cdots\n", "$$\n", "\n", "这样就能够拟合更加复杂的模型,这就是多项式模型,这里使用了 x 的更高次,同理还有多元回归模型,形式也是一样的,只是出了使用 x,还是更多的变量,比如 y、z 等等,同时他们的 loss 函数和简单的线性回归模型是一致的。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "首先我们可以先定义一个需要拟合的目标函数,这个函数是个三次的多项式" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n" ] } ], "source": [ "# 定义一个多变量函数\n", "\n", "w_target = np.array([0.5, 3, 2.4]) # 定义参数\n", "b_target = np.array([0.9]) # 定义参数\n", "\n", "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n", " b_target[0], w_target[0], w_target[1], w_target[2]) # 打印出函数的式子\n", "\n", "print(f_des)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们可以先画出这个多项式的图像" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { 
"data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出这个函数的曲线\n", "x_sample = np.arange(-3, 3.1, 0.1)\n", "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n", "\n", "plt.plot(x_sample, y_sample, label='real curve')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "接着我们可以构建数据集,需要 x 和 y,同时是一个三次多项式,所以我们取了 $x,\\ x^2, x^3$" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# 构建数据 x 和 y\n", "# x 是一个如下矩阵 [x, x^2, x^3]\n", "# y 是函数的结果 [y]\n", "\n", "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n", "x_train = torch.from_numpy(x_train).float() # 转换成 float tensor\n", "\n", "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # 转化成 float tensor " ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([61, 3])\n" ] } ], "source": [ "print(x_train.size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "接着我们可以定义需要优化的参数,就是前面这个函数里面的 $w_i$" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# 定义参数和模型\n", "w = Variable(torch.randn(3, 1), requires_grad=True)\n", "b = Variable(torch.zeros(1), requires_grad=True)\n", "\n", "# 将 x 和 y 转换成 Variable\n", "x_train = Variable(x_train)\n", "y_train = Variable(y_train)\n", "\n", "def multi_linear(x):\n", " return torch.mm(x, w) + b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们可以画出没有更新之前的模型和真实的模型之间的对比" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出更新之前的模型\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以发现,这两条曲线之间存在差异,我们计算一下他们之间的误差" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(447.3372, grad_fn=)\n" ] } ], "source": [ "# 计算误差,这里的误差和一元的线性模型的误差是相同的,前面已经定义过了 get_loss\n", "loss = get_loss(y_pred, y_train)\n", "print(loss)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# 自动求导\n", "loss.backward()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ -60.7756],\n", " [ -81.7448],\n", " [-401.0452]])\n", "tensor([-15.4545])\n" ] } ], "source": [ "# 查看一下 w 和 b 的梯度\n", "print(w.grad)\n", "print(b.grad)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# 更新一下参数\n", "w.data = w.data - 0.001 * w.grad.data\n", "b.data = b.data - 0.001 * b.grad.data" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出更新一次之后的模型\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "因为只更新了一次,所以两条曲线之间的差异仍然存在,我们进行 100 次迭代" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 20, Loss: 22.71861\n", "epoch 40, Loss: 5.37627\n", "epoch 60, Loss: 1.32816\n", "epoch 80, Loss: 0.38091\n", "epoch 100, Loss: 0.15742\n" ] } ], "source": [ "# 进行 100 次参数更新\n", "for e in range(100):\n", " y_pred = multi_linear(x_train)\n", " loss = get_loss(y_pred, y_train)\n", " \n", " w.grad.data.zero_()\n", " b.grad.data.zero_()\n", " loss.backward()\n", " \n", " # 更新参数\n", " w.data = w.data - 0.001 * w.grad.data\n", " b.data = b.data - 0.001 * b.grad.data\n", " if (e + 1) % 20 == 0:\n", " print('epoch {}, Loss: {:.5f}'.format(e+1, loss.data.item()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到更新完成之后 loss 已经非常小了,我们画出更新之后的曲线对比" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出更新之后的结果\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到,经过 100 次更新之后,可以看到拟合的线和真实的线已经完全重合了" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "**小练习:上面的例子是一个三次的多项式,尝试使用二次的多项式去拟合它,看看最后能做到多好**\n", "\n", "**提示:参数 `w = torch.randn(2, 1)`,同时重新构建 x 数据集**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }