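{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Models and Gradient Descent\n", "This is the first lesson on neural networks. We study a very simple model, linear regression, together with an optimization algorithm, gradient descent, which we use to fit it. Linear regression is one of the simplest models in supervised learning, and gradient descent is the most widely used optimization algorithm in deep learning, so this is where our deep learning journey begins." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Univariate Linear Regression\n", "The univariate linear model is very simple. Suppose we have inputs $x_i$ and targets $y_i$, where each $i$ indexes one data point, and we want to build the model\n", "\n", "$$\n", "\\hat{y}_i = w x_i + b\n", "$$\n", "\n", "Here $\\hat{y}_i$ is our prediction, which should fit the target $y_i$. Put simply, we look for the $w$ and $b$ whose predictions have the smallest error, that is, we minimize the mean squared error\n", "\n", "$$\n", "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So how do we minimize this error?\n", "\n", "This is where **gradient descent** comes in. It is the first optimization algorithm we meet: very simple yet very powerful, and used heavily in deep learning. Let us start from a simple example to understand how it works." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient Descent\n", "To understand gradient descent, we first need to be clear about what a gradient is, and then see how to use it to descend." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient\n", "For a function of one variable, the gradient is simply the derivative. For a function of several variables, the gradient is the vector of partial derivatives. For example, for a function $f(x, y)$ the gradient is\n", "\n", "$$\n", "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n", "$$\n", "\n", "It is written grad $f(x, y)$ or $\\nabla f(x, y)$. The gradient at a specific point $(x_0,\\ y_0)$ is $\\nabla f(x_0,\\ y_0)$.\n", "\n", "The figure below shows the gradient of $f(x) = x^2$ at $x = 1$.\n", "\n", "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarbuh2j3j30ba0b80sy.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What does the gradient mean? Geometrically, the gradient at a point gives the direction in which the function changes fastest. Specifically, for a function $f(x, y)$ at the point $(x_0, y_0)$, the function increases fastest along the direction of $\\nabla f(x_0,\\ y_0)$. So moving along the gradient is the quickest way uphill, toward a local maximum, while moving against the gradient is the quickest way downhill, toward a local minimum." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Gradient Descent Method\n", "With the gradient understood, the principle of gradient descent follows. Above we wanted to minimize the error, that is, to find the point where the error is smallest, and moving against the gradient leads us toward that minimum.\n", "\n", "Here is an intuitive picture. Suppose we are somewhere on a large mountain and do not know the way down, so we decide to take it one step at a time. At each position we compute the gradient there and take one step along the negative gradient, which is the steepest downhill direction. Then we compute the gradient at the new position and again step in the steepest downhill direction. We keep going step by step until we believe we have reached the foot of the mountain. Of course, we may not actually reach the foot and may instead end up in a local low point among the peaks.\n", "\n", 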
"Applied to our problem, this means repeatedly changing $w$ and $b$ against the gradient until we find the values of $w$ and $b$ that make the error smallest.\n", "\n", "At each update we must decide how far to move, just as in the mountain example we must choose the length of each downhill step. This step length is called the learning rate, written $\\eta$. The learning rate matters a great deal, and different values lead to different results: if it is too small, descent is very slow, and if it is too large, the iterates jump around noticeably. The example below illustrates this.\n", "\n", "![](https://ws2.sinaimg.cn/large/006tNc79ly1fmgn23lnzjg30980gogso.gif)\n", "\n", "In the upper plot the learning rate is about right, while in the lower plot it is too large, so the iterates keep jumping back and forth.\n", "\n", "Writing $f(w,\\ b)$ for the error, the update rule is\n", "\n", "$$\n", "w := w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n", "b := b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n", "$$\n", "\n", "By iterating these updates we eventually reach values of $w$ and $b$ at which the error is (at least locally) smallest. This is the principle of gradient descent.\n", "\n", "Finally, the figure below gives a visual summary of the method.\n", "\n", "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarxsltfqj30gx091gn4.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That covers the theory. Next we work through an example to study the linear model further." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import numpy as np\n", "# Variable is deprecated since PyTorch 0.4; it is imported for the cells below\n", "from torch.autograd import Variable\n", "\n", "torch.manual_seed(2017)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Load the data x and y\n", "x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],\n", "                    [9.779], [6.182], [7.59], [2.167], [7.042],\n", "                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)\n", "\n", "y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],\n", "                    [3.366], [2.596], [2.53], [1.221], [2.827],\n", "                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPrElEQVR4nO3df4gc933G8ec5SdS+OMRtdSSqrLstNKQkprbSxbVrKMauwU2NXagLLlvXKSkHIW3sYih1BC4JXEmhuD9iiFnsNEq7uAmySV0TtxWJITE0CitV/iUZYqjubFepznYt293UraJP/5gVkq67t7On2ZvZ77xfsMzMd0e7H4a7R9+b/cysI0IAgOk3U3YBAIBiEOgAkAgCHQASQaADQCIIdABIxNay3nj79u3RaDTKensAmEoHDx58LSLmBj1XWqA3Gg11u92y3h4AppLt5WHPccoFABJBoANAIkYGuu2LbH/P9jO2X7D92QH7fNz2qu3D/cfvTqZcAMAwec6hvyvp+oh4x/Y2SU/bfjIivrtmv69GxO8VXyIAII+RgR7ZzV7e6W9u6z+4AQwAVEyuc+i2t9g+LOmEpP0RcWDAbr9u+1nb+2zvGvI6i7a7trurq6sbrxoAplCnIzUa0sxMtux0in39XIEeET+KiCslXSbpKtuXr9nlHyQ1IuLnJO2XtHfI67QjohkRzbm5gW2UAJCkTkdaXJSWl6WIbLm4WGyoj9XlEhFvSnpK0k1rxl+PiHf7mw9J+vlCqgOAROzZI/V654/1etl4UfJ0uczZvrS/frGkGyW9uGafHeds3iLpaHElAsD0W1kZb3wj8nS57JC01/YWZf8BfC0inrD9OUndiHhc0qdt3yLplKQ3JH28uBIBYPrNz2enWQaNFyVPl8uzknYPGL/vnPV7Jd1bXFkAkJalpeyc+bmnXWZns/GicKUoAGyCVktqt6WFBcnOlu12Nl6U0m7OBQB102oVG+BrMUMHgEQQ6ACSNekLeaqGUy4AknTmQp4zH0KeuZBHmuxpjzIxQweQpM24kKdqCHQASdqMC3mqhkAHkKRhF+wUeSFP1RDoAJK0tJRduHOuoi/kqRoCHUCSNuNCnqqhywVAsiZ9IU/VMEMHgEQQ6ACQCAIdABJBoANAIgh0AEgEgQ4AiSDQASARBDowhrrdjhXThQuLgJzqeDtWTBdm6EBOdbwdK6YLgQ7kVMfbsWK6EOhATnW8HSumC4EO5FTH27FiuhDoQE51vB0rpgtdLsAY6nY7VkwXZugAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSCQAeARBDoAJCIkYFu+yLb37P9jO0XbH92wD4/Zvurtl+yfcB2YyLVAgCGyjNDf1fS9RFxhaQrJd1k++o1+3xC0n9GxM9I+nNJf1polQCAkUYGemTe6W9u6z9izW63StrbX98n6QbbLqxKAMBIuc6h295i+7CkE5L2R8SBNbvslPSyJEXEKUknJf3kgNdZtN213V1dXb2gwgEA58sV6BHxo4i4UtJlkq6yfflG3iwi2hHRjIjm3NzcRl4CADDEWF0uEfGmpKck3bTmqVcl7ZIk21slvU/S6wXUBwDIKU+Xy5ztS/vrF0u6UdKLa3Z7XNKd/fXbJH0rItaeZwcATFCeL7jYIWmv7S3K/gP4WkQ8YftzkroR8bikhyX9je2XJL0h6faJVQwAGGhkoEfEs5J2Dxi/75z1/5b0G8WWBgAYB1eKAonrdKRGQ5qZyZadTtkVYVL4TlEgYZ2OtLgo9XrZ9vJyti3x3agpYoYOJGzPnrNhfkavl40jPQQ6kLCVlfHGMd0IdCBh8/PjjWO6EehAwpaWpNnZ88dmZ7NxpIdAByakCt0lrZbUbksLC5KdLdttPhBNFV0uwARUqbuk1SLA64IZOjABdJegDAQ6MAF0l6AMBDowAXSXoAwEOjABdJegDAR6TVSh46JO6C5BGehyqYEqdVzUCd0l2GzM0GuAjgugHgj
0GqDjAqgHAr0G6LgA6oFArwE6LoB6INBrgI4LoB7ocqkJOi6A9DFDB4BEEOgAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ksetg1EXXFiEpHHrYNQJM3QkjVsHo04IdCSNWwejTgh0JI1bB6NOCHQkjVsHo04IdCQtpVsH062DUehyQfJSuHUw3TrIY+QM3fYu20/ZPmL7Bdt3DdjnOtsnbR/uP+6bTLlAPdGtgzzyzNBPSbonIg7Zfq+kg7b3R8SRNft9JyJuLr5EAHTrII+RM/SIOB4Rh/rrb0s6KmnnpAsDcBbdOshjrA9FbTck7ZZ0YMDT19h+xvaTtj8y5N8v2u7a7q6uro5fLVBTdOsgj9yBbvsSSY9Kujsi3lrz9CFJCxFxhaQvSPr6oNeIiHZENCOiOTc3t8GSgfpJqVsHk+OIGL2TvU3SE5L+KSLuz7H/MUnNiHht2D7NZjO63e4YpQIAbB+MiOag5/J0uVjSw5KODgtz2x/o7yfbV/Vf9/WNlwwAGFeeLpdrJd0h6Tnbh/tjn5E0L0kR8aCk2yR90vYpST+UdHvkmfoDAAozMtAj4mlJHrHPA5IeKKooAMD4uPQfABJBoANAIgh0AEgEgQ4AiSDQASARBDoAJIJAB4BEEOgAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSCQAeARBDoAJAIAh0AEkGgA0AiCHQASASBDgCJINABIBEEOgAkgkAHgEQQ6ACQCAIdABJBoANAIgh0AEgEgY7SdTpSoyHNzGTLTqfsioDptLXsAlBvnY60uCj1etn28nK2LUmtVnl1AdOIGTpKtWfP2TA/o9fLxgGMh0BHqVZWxhsHMByBjlLNz483DmA4Ah2lWlqSZmfPH5udzcYBjIdAR6laLandlhYWJDtbttt8IApsBF0uKF2rRYADRRg5Q7e9y/ZTto/YfsH2XQP2se2/sv2S7Wdtf3Qy5QIAhskzQz8l6Z6IOGT7vZIO2t4fEUfO2edXJH2w//gFSV/sLwEAm2TkDD0ijkfEof7625KOStq5ZrdbJX0lMt+VdKntHYVXCwAYaqwPRW03JO2WdGDNUzslvXzO9iv6/6Ev24u2u7a7q6urY5YKAFhP7kC3fYmkRyXdHRFvbeTNIqIdEc2IaM7NzW3kJQAAQ+QKdNvblIV5JyIeG7DLq5J2nbN9WX8MALBJ8nS5WNLDko5GxP1Ddntc0m/3u12ulnQyIo4XWCcAYIQ8XS7XSrpD0nO2D/fHPiNpXpIi4kFJ35D0MUkvSepJ+p3CKwUArGtkoEfE05I8Yp+Q9KmiigIAjI9L/wEgEQQ6ACSCQAeARBDoAJAIAh0AEkGgA0AiCHQASASBDgCJINABIBEEOgAkgkAHgEQQ6ACQCAIdABJBoANAIgh0AEgEgQ4AiSDQASARBDoAJIJAL1CnIzUa0sxMtux0yq4Im42fAZQpz5dEI4dOR1pclHq9bHt5OduWpFarvLqwefgZQNmcfb/z5ms2m9Htdkt570loNLJf4LUWFqRjxza7GpSBnwFsBtsHI6I56DlOuRRkZWW8caSHnwGUjUAvyPz8eON1U4dzy/wMoGwEekGWlqTZ2fPHZmez8bo7c255eVmKOHtuObVQ52cAZSPQC9JqSe12dr7UzpbtNh+GSdKePWc/KDyj18vGU8LPAMrGh6KYuJmZbGa+li2dPr359QDTjA9FUSrOLQObg0DHxHFuGdgcBDomjnPLwOYg0BNR9bbAViu7uOb06WxJmAPF49L/BHDJOQCJGXoS6tIWCGB9BHoCuOQcgESgJ4G2QAASgZ4E2gIBSDkC3faXbJ+w/fyQ56+zfdL24f7jvuLLxHpoCwQg5ety+bKkByR9ZZ19vhMRNxdSETak1SLAgbobOUOPiG9LemMTagEAXICizqFfY/sZ20/a/siwnWwv2u7a7q6urhb01gAAqZhAPyRpISKukPQFSV8ftmNEtCOiGRHNubm5At4aAHD
GBQd6RLwVEe/0178haZvt7RdcGQBgLBcc6LY/YNv99av6r/n6hb4uAGA8I7tcbD8i6TpJ222/IumPJW2TpIh4UNJtkj5p+5SkH0q6Pcr61gwAqLGRgR4Rvzni+QeUtTUCAErElaIAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSCQAeARBDoAJAIAn1MnY7UaEgzM9my0ym7IgDI5PnGIvR1OtLiotTrZdvLy9m2xLcFASgfM/Qx7NlzNszP6PWycQAoG4E+hpWV8cYBYDMR6GOYnx9vHAA2E4E+hqUlaXb2/LHZ2WwcAMpGoI+h1ZLabWlhQbKzZbvNB6IAqmGqAr0KLYOtlnTsmHT6dLYkzAFUxdS0LdIyCADrm5oZOi2DALC+qQl0WgYBYH1TE+i0DALA+qYm0GkZBID1TU2g0zIIAOubmi4XKQtvAhwABpuaGToAYH0EOgAkgkAHgEQQ6ACQCAIdABLhiCjnje1VScs5dt0u6bUJlzONOC7DcWwG47gMN03HZiEi5gY9UVqg52W7GxHNsuuoGo7LcBybwTguw6VybDjlAgCJINABIBHTEOjtsguoKI7LcBybwTguwyVxbCp/Dh0AkM80zNABADkQ6ACQiEoGuu1dtp+yfcT2C7bvKrumKrG9xfa/2n6i7FqqxPaltvfZftH2UdvXlF1TVdj+g/7v0vO2H7F9Udk1lcX2l2yfsP38OWM/YXu/7e/3lz9eZo0bVclAl3RK0j0R8WFJV0v6lO0Pl1xTldwl6WjZRVTQX0r6x4j4WUlXiGMkSbK9U9KnJTUj4nJJWyTdXm5VpfqypJvWjP2RpG9GxAclfbO/PXUqGegRcTwiDvXX31b2i7mz3KqqwfZlkn5V0kNl11Iltt8n6ZckPSxJEfE/EfFmqUVVy1ZJF9veKmlW0r+XXE9pIuLbkt5YM3yrpL399b2Sfm0zaypKJQP9XLYbknZLOlByKVXxF5L+UNLpkuuomp+WtCrpr/unox6y/Z6yi6qCiHhV0p9JWpF0XNLJiPjncquqnPdHxPH++g8kvb/MYjaq0oFu+xJJj0q6OyLeKruestm+WdKJiDhYdi0VtFXSRyV9MSJ2S/ovTemfzUXrnw++Vdl/ej8l6T22f6vcqqorsl7uqeznrmyg296mLMw7EfFY2fVUxLWSbrF9TNLfSbre9t+WW1JlvCLplYg485fcPmUBD+mXJf1bRKxGxP9KekzSL5ZcU9X8h+0dktRfnii5ng2pZKDbtrJzoUcj4v6y66mKiLg3Ii6LiIayD7W+FRHMtCRFxA8kvWz7Q/2hGyQdKbGkKlmRdLXt2f7v1g3iA+O1Hpd0Z3/9Tkl/X2ItG1bJQFc2E71D2Qz0cP/xsbKLQuX9vqSO7WclXSnpT8otpxr6f7Xsk3RI0nPKfu+TuNR9I2w/IulfJH3I9iu2PyHp85JutP19ZX/RfL7MGjeKS/8BIBFVnaEDAMZEoANAIgh0AEgEgQ4AiSDQASARBDoAJIJAB4BE/B/WmKZIJX5BAgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the data\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.plot(x_train, y_train, 'bo')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-1.1286], requires_grad=True)\n" ] } ], "source": [ "# Convert to Tensor\n", "x_train = torch.from_numpy(x_train)\n", "y_train = torch.from_numpy(y_train)\n", "\n", "# Define the parameters w and b\n", "w = Variable(torch.randn(1), requires_grad=True) # random initialization\n", "b = Variable(torch.zeros(1), requires_grad=True) # initialize to 0\n", "print(w)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "# Build the linear regression model\n", "x_train = Variable(x_train)\n", "y_train = Variable(y_train)\n", "\n", "def linear_model(x):\n", "    return x * w + b\n", "\n", "# Logistic variant of the same model (not used in this lesson)\n", "def logistic_regression(x):\n", "    return torch.sigmoid(x * w + b)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "y_ = linear_model(x_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the steps above the model is defined. Before updating any parameters, let us see what the model's output looks like." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Think about it: the red points are the predictions and appear to form a straight line. Do they actually lie on one line?**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we need to compute the error function, namely\n", "\n", "$$\n", "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", "$$" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Compute the mean squared error\n", "def get_loss(y_, y):\n", "    return torch.mean((y_ - y) ** 2)\n", "\n", "loss = get_loss(y_, y_train)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(94.9309, grad_fn=)\n" ] } ], "source": [ "# Print the loss to see how large it is\n", "print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the error function defined, we next need its gradients with respect to $w$ and $b$. Thanks to PyTorch's automatic differentiation we do not have to derive them by hand. If you are interested, you can work them out yourself. Writing $f$ for the error, the gradients with respect to $w$ and $b$ are\n", "\n", "$$\n", "\\frac{\\partial f}{\\partial w} = \\frac{2}{n} \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n", "\\frac{\\partial f}{\\partial b} = \\frac{2}{n} \\sum_{i=1}^n (w x_i + b - y_i)\n", "$$" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Backpropagate to compute the gradients automatically\n", "loss.backward()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-126.6150])\n", "tensor([-18.3376])\n" ] } ], "source": [ "# Inspect the gradients of w and b\n", "print(w.grad)\n", "print(b.grad)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Update the parameters once, with learning rate 1e-2\n", "w.data = w.data - 1e-2 * w.grad.data\n", "b.data = b.data - 1e-2 * b.grad.data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the parameters updated, let us look at the model's output again." ] }, { "cell_type": "code", 
"execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_ = linear_model(x_train)\n", "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the plot shows, after one update the red predictions lie below the blue data and still do not fit the true values well, so we need a few more updates." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1.9596)\n", "epoch: 0, loss: 1.9595526456832886\n", "tensor(0.2388)\n", "epoch: 1, loss: 0.23876741528511047\n", "tensor(0.2067)\n", "epoch: 2, loss: 0.20673297345638275\n", "tensor(0.2060)\n", "epoch: 3, loss: 0.2059527039527893\n", "tensor(0.2058)\n", "epoch: 4, loss: 0.20575186610221863\n", "tensor(0.2056)\n", "epoch: 5, loss: 0.2055628001689911\n", "tensor(0.2054)\n", "epoch: 6, loss: 0.20537473261356354\n", "tensor(0.2052)\n", "epoch: 7, loss: 0.20518775284290314\n", "tensor(0.2050)\n", "epoch: 8, loss: 0.20500165224075317\n", "tensor(0.2048)\n", "epoch: 9, loss: 0.2048165202140808\n" ] } ], "source": [ "for e in range(10): # perform 10 updates\n", "    y_ = linear_model(x_train)\n", "    loss = get_loss(y_, y_train)\n", "\n", "    w.grad.zero_() # remember to zero the gradient, since backward() accumulates\n", "    b.grad.zero_() # remember to zero the gradient, since backward() accumulates\n", "    loss.backward()\n", "    print(loss.data)\n", "    w.data = w.data - 1e-2 * w.grad.data # update w\n", "    b.data = b.data - 1e-2 * b.grad.data # update b\n", "    print('epoch: {}, loss: {}'.format(e, loss.item()))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_ = linear_model(x_train)\n", "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After 10 more updates, the red predictions fit the blue true values reasonably well.\n", "\n", "You have now learned your first machine learning model. Keep it up and try the exercise below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:**\n", "\n", "Restart the notebook and rerun the linear regression model above, changing the number of training iterations and the learning rate to see how the results differ." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Polynomial Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us go one step further and look at polynomial regression. What is it? Recall the linear regression model above\n", "\n", "$$\n", "\\hat{y} = w x + b\n", "$$\n", "\n", "This is a first-degree polynomial in $x$. The model is simple and cannot fit more complex relationships, so we can use a higher-degree model such as\n", "\n", "$$\n", "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 + \\cdots\n", "$$\n", "\n", "which can fit more complex relationships. This is the polynomial model, which uses higher powers of $x$. Similarly there is the multivariate regression model, which has the same form except that, besides $x$, it uses further input variables. The loss function for all of these is the same as for simple linear regression." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we define a target function to fit. It is a cubic polynomial." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n" ] } ], "source": [ "# Define the target cubic polynomial\n", "\n", "w_target = np.array([0.5, 3, 2.4]) # target coefficients of x, x^2, x^3\n", "b_target = np.array([0.9]) # target constant term\n", "\n", "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n", "    b_target[0], w_target[0], w_target[1], w_target[2]) # formula of the function as a string\n", "\n", "print(f_des)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can first plot this polynomial." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { 
"data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlIklEQVR4nO3deXxU9b3/8dcnG2FfQkAkQED2TYGwVcuDFheqXjdai61L1YpetbWtt7Xq7bX9aW+1trbaai1Vi1aLIGq1dSlqVVzKLpssBtmSCCQsCQSyznx+f2TgorIlM8mZmbyfj0ceM3PmzPl+zhDeOfOd7/kec3dERCQ5pQRdgIiINB6FvIhIElPIi4gkMYW8iEgSU8iLiCQxhbyISBI77pA3s8fMrNjMVh2yrJOZvWZm+ZHbjpHlZmYPmNl6M1thZiMbo3gRETm6+hzJzwAmf2bZj4E33L0f8EbkMcBXgH6Rn2nAH6IrU0REGsLqczKUmeUC/3D3oZHH64CJ7r7VzLoBb7n7ADP7Y+T+zM+ud7Ttd+7c2XNzcxu2JyIizdSSJUt2uHv24Z5Li3LbXQ8J7m1A18j97kDBIesVRpYdNeRzc3NZvHhxlCWJiDQvZrb5SM/F7ItXr/tIUO85EsxsmpktNrPFJSUlsSpHRESIPuS3R7ppiNwWR5YXAT0OWS8nsuxz3H26u+e5e1529mE/bYiISANFG/IvAldE7l8BvHDI8ssjo2zGAWXH6o8XEZHYO+4+eTObCUwEOptZIXAHcDcw28yuBjYDF0dWfxk4G1gP7AeubGiBNTU1FBYWUllZ2dBNyGdkZmaSk5NDenp60KWISCM77pB390uO8NSkw6zrwA0NLepQhYWFtG3bltzcXMwsFpts1tydnTt3UlhYSO/evYMuR0QaWdyf8VpZWUlWVpYCPkbMjKysLH0yEmkm4j7kAQV8jOn9FGk+EiLkRUSS2f2v57Ngw85G2bZCvgnk5uayY8eOoMsQkTi0oaSc37z+EQs27mqU7Svk68HdCYfDQZcRN3WISPT+Mn8z6anG1DE9jr1yAyjkj2HTpk0MGDCAyy+/nKFDh1JQUMC9997L6NGjGT58OHfcccfBdS+44AJGjRrFkCFDmD59+jG3/eqrrzJy5EhOPvlkJk2qG6T005/+lF/96lcH1xk6dCibNm36XB133nknP/zhDw+uN2PGDG688UYAnnzyScaMGcMpp5zCtddeSygUitXbISIxtK+qljmLCzl7WDe6tM1slDainbumSf3s7x+y+pM9Md3m4BPbccd/DDnqOvn5+Tz++OOMGzeOuXPnkp+fz8KFC3F3zjvvPObNm8eECRN47LHH6NSpExUVFYwePZopU6aQlZV12G2WlJRwzTXXMG/ePHr37s2uXcf+qHZoHSUlJYwfP557770XgFmzZnH77bezZs0aZs2axXvvvUd6ejrXX389Tz31FJdffnn93xwRaVR/W1bE3qpaLh/fq9HaSKiQD0qvXr0YN24cAHPnzmXu3LmMGDECgPLycvLz85kwYQIPPPAAzz//PAAFBQXk5+cfMeTnz5/PhAkTDo5V79SpU73qyM7Opk+fPsyfP59+/fqxdu1aTj31VB588EGWLFnC6NGjAaioqKBLly7RvQEiEnPuzhPvb2bIie0Y2bNjo7WTUCF/rCPuxtK6deuD992dW2+9lWuvvfZT67z11lu8/vrr/Pvf/6ZVq1ZMnDixQWPR09LSPtXffug2Dq0DYOrUqcyePZuBAwdy4YUXYma4O1dccQW/+MUv6t22iDSdBRt3sW77Xn45ZXijDmtWn3w9nXXWWTz22GOUl5cDUFRURHFxMWVlZXTs2JFWrVqxdu1a5s+ff9TtjBs3jnnz5rFx40aAg901ubm5LF26FIClS5cefP5wLrzwQl544QVmzpzJ1KlTAZg0aRJz5syhuLj44HY3bz7iLKQiEpAn/r2J9i3T+Y+TT2zU
dhLqSD4enHnmmaxZs4bx48cD0KZNG5588kkmT57Mww8/zKBBgxgwYMDBbpUjyc7OZvr06Vx00UWEw2G6dOnCa6+9xpQpU3jiiScYMmQIY8eOpX///kfcRseOHRk0aBCrV69mzJgxAAwePJi77rqLM888k3A4THp6Og8++CC9ejVen5+I1M/Wsgr++eF2rj6tNy0zUhu1rXpdGaqx5eXl+WcvGrJmzRoGDRoUUEXJS++rSHDum7uO3725nrf/60v0zGoV9fbMbIm75x3uOXXXiIg0oaraEH9duIUvD+gSk4A/FoW8iEgTenXVNnaUV3NZIw6bPFRChHw8dSklA72fIsF5/P1N5Ga1YkK/prkSXtyHfGZmJjt37lQwxciB+eQzMxvn7DoRObLlBaUs3VLKZeNzSUlpmtlg4350TU5ODoWFhegi37Fz4MpQItK0Hnl3I21bpHFxXtP9/4v7kE9PT9cVjEQk4RWVVvDyyq1cdWoubTOb7tKbcd9dIyKSDGa8V3di47dObdqDVoW8iEgj21tZw9MLCzh7WDe6d2jZpG0r5EVEGtmsRQXsrarlmi82fddzTELezL5vZh+a2Sozm2lmmWbW28wWmNl6M5tlZhmxaEtEJJHUhsL8+b1NjMntxPCcDk3eftQhb2bdge8Cee4+FEgFpgL3AL9x977AbuDqaNsSEUk0r364jaLSCr4dwFE8xK67Jg1oaWZpQCtgK/BlYE7k+ceBC2LUlohIQnB3/vTORnKzWjFpUNdAaog65N29CPgVsIW6cC8DlgCl7l4bWa0Q6H6415vZNDNbbGaLNRZeRJLJks27WV5QytWn9Sa1iU5++qxYdNd0BM4HegMnAq2Bycf7enef7u557p6Xnd00p/mKiDSFR97ZSPuW6UwZFdzJh7Horjkd2OjuJe5eAzwHnAp0iHTfAOQARTFoS0QkIWzasY9/rt7GN8f2pFVGcOedxiLktwDjzKyV1V3DahKwGngT+GpknSuAF2LQlohIQvjjvI/JSE3hyiY++emzYtEnv4C6L1iXAisj25wO3AL8wMzWA1nAo9G2JSKSCLaVVTJnSSEX5/Ugu22LQGuJyWcId78DuOMzizcAY2KxfRGRRPKndzYQdpg2oU/QpeiMVxGRWNq1r5q/LtjC+aecSI9OjX/lp2NRyIuIxNCM9zZSWRvi+oknBV0KoJAXEYmZvZU1zHh/E2cNPoG+XdoGXQ6gkBcRiZmnFmxhT2Ut138pPo7iQSEvIhITlTUhHnlnI1/s1zmQiciORCEvIhIDzywuYEd5FddP7Bt0KZ+ikBcRiVJNKMzDb29gZM8OjOvTKehyPkUhLyISpeeWFlJUWsENX+pL3Yn/8UMhLyISheraMA+8sZ7hOe358sAuQZfzOQp5EZEoPLOkgKLSCr5/Rv+4O4oHhbyISINV1Yb4/b/WM6JnByb2j8+p0hXyIiINNGtRAVvLKrn5jAFxeRQPCnkRkQaprAnx4JvrGZPbiVP7ZgVdzhEp5EVEGuCvC7awfU9V3PbFH6CQFxGpp4rqEA+99THj+2Qx/qT4PYoHhbyISL09OX8zO8rrjuLjnUJeRKQe9lXV8vDbH/PFfp0Z0zu+zm49HIW8iEg9PPruRnbuq06Io3hQyIuIHLcd5VX88e2POWtIV0b27Bh0OcdFIS8icpx+90Y+lbVhfjR5YNClHDeFvIjIcdi0Yx9PLdjC10f34KTsNkGXc9xiEvJm1sHM5pjZWjNbY2bjzayTmb1mZvmR28T4bCMichj3zl1HemoK35vUL+hS6iVWR/L3A6+6+0DgZGAN8GPgDXfvB7wReSwiknCWF5Ty0oqtXDOhD13aZQZdTr1EHfJm1h6YADwK4O7V7l4KnA88HlntceCCaNsSEWlq7s4vXllD5zYZTJvQJ+hy6i0WR/K9gRLgz2b2gZk9Ymatga7uvjWyzjag6+FebGbTzGyxmS0uKSmJQTkiIrHz1roS5m/YxXcn9aNNi7Sgy6m3
WIR8GjAS+IO7jwD28ZmuGXd3wA/3Ynef7u557p6XnR2fU3WKSPMUCjt3v7KW3KxWXDKmZ9DlNEgsQr4QKHT3BZHHc6gL/e1m1g0gclscg7ZERJrMnCUFrNu+lx+eNZD01MQcjBh11e6+DSgwswGRRZOA1cCLwBWRZVcAL0TblohIU9lTWcO9/1xHXq+OnD3shKDLabBYdTB9B3jKzDKADcCV1P0BmW1mVwObgYtj1JaISKN74PV8du6rZsaVY+J6KuFjiUnIu/syIO8wT02KxfZFRJrS+uJyZry/ia/n9WBo9/ZBlxOVxOxkEhFpJO7O//vHalpmpPJfZw049gvinEJeROQQb6wpZt5HJXzv9P50btMi6HKippAXEYmoqg1x50ur6dulDZeP7xV0OTGhkBcRiXjs3U1s3rmf/zl3cMIOmfys5NgLEZEoFe+p5Pf/yuf0QV2Z0D95TsxUyIuIAHe+tIaakPPf5wwKupSYUsiLSLP31rpi/r78E274Ul9yO7cOupyYUsiLSLNWUR3iJy+sok92a66bmHizTB5L4k2pJiISQ/e/kU/BrgqenjaOFmmpQZcTczqSF5Fma+22PTzyzgYuzsthXJ+soMtpFAp5EWmWwmHn1udW0q5lOrd+Jbm+bD2UQl5EmqW/LtzCB1tK+e9zBtGxdUbQ5TQahbyINDvFeyq559W1nNo3iwtHdA+6nEalkBeRZsXd+e+/raKqNsxdFwxL6GmEj4dCXkSalReWfcLc1dv5rzP70zvJxsQfjkJeRJqN7Xsq+Z8XVjGqV0euPi35xsQfjkJeRJoFd+fHz66gOhTmV187mdSU5O6mOUAhLyLNwjNLCnlzXQm3TB7YLLppDlDIi0jS+6S0gjv/vppxfTpxxfjcoMtpUgp5EUlq7s4tz64g5M69Xz2ZlGbSTXOAQl5EktqTC7bwTv4Objt7ED06tQq6nCYXs5A3s1Qz+8DM/hF53NvMFpjZejObZWbJe0qZiMSlddv2ctc/VjOhfzbfHNsz6HICEcsj+ZuANYc8vgf4jbv3BXYDV8ewLRGRo6qsCfGdmUtpm5nOr792ctKf9HQkMQl5M8sBzgEeiTw24MvAnMgqjwMXxKItEZHjcec/VvPR9nLuu/hkstu2CLqcwMTqSP63wI+AcORxFlDq7rWRx4XAYSeIMLNpZrbYzBaXlJTEqBwRac5eWbmVpxZs4doJfZLqeq0NEXXIm9m5QLG7L2nI6919urvnuXtednbz/scQkegVlVZwy7MrODmnPTefOSDocgIXiytDnQqcZ2ZnA5lAO+B+oIOZpUWO5nOAohi0JSJyRLWhMDfN/ICwwwOXjCAjTQMIo34H3P1Wd89x91xgKvAvd/8m8Cbw1chqVwAvRNuWiMjR/Pb1fBZv3s1dFwylV1bzOav1aBrzz9wtwA/MbD11ffSPNmJbItLMvbZ6O79/cz1fG5XDBUk+R3x9xPRC3u7+FvBW5P4GYEwsty8icjgfl5Tzg1nLGNa9PXdeMDTocuKKOqxEJKGVV9Vy3V+WkJ6WwsOXjSIzPTXokuKKQl5EEpa786M5y/m4pJzfXzKC7h1aBl1S3FHIi0jCmj5vAy+v3MaPvzKQL/TtHHQ5cUkhLyIJ6b31O7jn1bWcM7wb13yxeVzlqSEU8iKScNYXl/OfTy6hb5c2/HLK8GY7L83xUMiLSELZWV7FVTMWkZGWwqNXjKZ1i5gOEkw6endEJGFU1oS45onFbN9TyaxrxzfL+eHrSyEvIgkhHHZufmY5HxSU8tA3RnJKjw5Bl5QQ1F0jIgnhV3PX8dKKrdz6lYF8ZVi3oMtJGAp5EYl7Ty/cwkNvfcw3xvbUSJp6UsiLSFx7ddVWbnt+JRP6Z/Oz84ZoJE09KeRFJG7N+6iE78z8gFN6dODhS0eSnqrIqi+9YyISlxZv2sW0vyymb5e2/PlbY2iVoXEiDaGQF5G48+EnZVw5YxHd2rfkiavG0L5VetAlJSyFvIjElQ0l
5Vz+6ELatkjjyW+PbdYX4Y4FhbyIxI0NJeV8408LAPjLt8dqVskYUCeXiMSF/O17+cYjCwiHnSe/PZaTstsEXVJSUMiLSODWbN3DpY8sICXFeHraOPp1bRt0SUlD3TUiEqhVRWVc8qf5pKemMEsBH3M6kheRwHywZTeXP7aQdpnpzLxmHD2zNOFYrEV9JG9mPczsTTNbbWYfmtlNkeWdzOw1M8uP3HaMvlwRSRZvf1TCpY8soFPrDGZfN14B30hi0V1TC9zs7oOBccANZjYY+DHwhrv3A96IPBYRYfaiAq6asYheWa155trxGkXTiKLurnH3rcDWyP29ZrYG6A6cD0yMrPY48BZwS7TtiUjicnd++3o+97+Rz4T+2Tz0zZG00UU/GlVM310zywVGAAuArpE/AADbgK6xbEtEEktNKMxtz63kmSWFfG1UDv970TDNRdMEYhbyZtYGeBb4nrvvOXSmOHd3M/MjvG4aMA2gZ8+esSpHROJI2f4abpy5lHfyd3DTpH587/R+mk2yicTkz6iZpVMX8E+5+3ORxdvNrFvk+W5A8eFe6+7T3T3P3fOys7NjUY6IxJF12/Zy3oPvMn/DTu6ZMozvn9FfAd+EYjG6xoBHgTXuft8hT70IXBG5fwXwQrRtiUhieWnFVi586D32V4d4eto4vj5an9abWiy6a04FLgNWmtmyyLLbgLuB2WZ2NbAZuDgGbYlIAgiFnXv/uY6H3/6YkT078IdLR9G1XWbQZTVLsRhd8y5wpM9ek6LdvogklpK9Vfxg9jLeyd/BN8f25I7/GEJGmr5gDYrGLolIzLy5tpgfzlnO3spa7pkyTN0zcUAhLyJRq6wJcfcra5nx/iYGntCWv14zjv6agyYuKORFJCrrtu3lpqc/YO22vVx5ai63TB5IZnpq0GVJhEJeRBqkJhRm+rwN3P9GPu0y05hx5WgmDugSdFnyGQp5Eam3ZQWl/PjZFazdtpezh53Az84bqsv0xSmFvIgct31Vtfx67kfMeH8jXdpmMv2yUZw55ISgy5KjUMiLyDG5O6+s2sbPX1pDUWkFl43rxY8mD6BtZnrQpckxKORF5KhWFJZy5z9Ws2jTbgae0JY5140nL7dT0GXJcVLIi8hhbS2r4N5X1/HcB0V0bpPBLy4axsV5PUhN0bwziUQhLyKfsqO8iunzNvDEvzcRdvjPiSdx/cST1DWToBTyIgJA8d5Kpr+9gScXbKa6Nsz5p3TnB2f0p0cnXZYvkSnkRZq5T0orePTdjTwVCfcLRnTnxi/1pU92m6BLkxhQyIs0Q+7Oks27+fP7m3h11TYALjilOzd+uS+9O7cOuDqJJYW8SDNSWRPilVVb+fN7m1hRWEa7zDSuPq03l43rpW6ZJKWQF0ly7s6yglKeXVrIi8s+YU9lLSdlt+bOC4YyZWR3WmUoBpKZ/nVFklTBrv38fcUnPLukkI9L9pGZnsJZQ07gq6NyOPWkzqRoKGSzoJAXSRLuTn5xOa+u2sY/P9zGh5/sAWB0bkemTejD2cO6aRhkM6SQF0lg+6pqWbhxF++u38G/1hazccc+AEb27MBtZw9k8pBu9MxSX3tzppAXSSD7q2tZUVjGgg27eG/9DpZu2U1t2MlIS2Fs705cfVpvzhzclS66nqpEKORF4lRNKMzGHftYUVjGB1t288GWUtZt30so7JjBsO7tuWZCH07r25lRvTrqQh1yWAp5kYBV1oTYsms/m3fuZ31xOeu27WHttr18XFJOTcgBaJuZxik9OnDD4L6M6NmBET060KFVRsCVSyJo9JA3s8nA/UAq8Ii7393YbYrEA3dnX3WI3fuqKSmvonhPJdv3VLFtTyXb91RSuLuCzTv3sX1P1aded2L7TAac0JaJA7ow8IS2DDmxHSdlt9FoGGmQRg15M0sFHgTOAAqBRWb2oruvbsx2pensr66lZG8VO8qrKN1fw57KGvZU1LKnou7+vuoQldUhKmrqfiprQlTVhgmFnZqQ
Uxuqux/yuiPWyM1BqSlGih24NVJTjLQUIy01hdQUIz3VSEtJIT3VSE9NIS01cj8lhbTIsvRUiyxPIS3l/7Zx4DYlxTCrayfFDAMcCLvjXhfWYa/rPqkOhamp9YP391fXsr86xP6qEPsi98sqaijdX0NZRfXBI/FDpaUYXdq2oHvHlpzWN5vcrFb0zGpFr6zW9O7cmvYtNQJGYqexj+THAOvdfQOAmT0NnA8o5BNAKOwU7a6gYPd+Cnfvp3B3ReRnP8V7q9ixt4p91aEjvr5leiqtW6TRMiOFzLRUWmakkpmWSpsWaZGA/b8ATjE4cJxqVnfP3Qk5hMNO2L3uj0HYqY3c1oTCVNWEKQ/V1v3BCIepCdUtrwmFqQ051ZHbA8/FyoE/MK0y0miVkRr5qbvfr0sbOrTKoEOrdDq2SqdDywyy27agS7sWdG2XSadWGToqlybT2CHfHSg45HEhMLaR25R6cneKSitYVVTGR9vLWV9cTn5xORtKyqmqDR9cLzXF6NY+k+4dWnJyTgey27agc5sWdG6TQee2LejUKoN2LdNpl5lG28x0MtJSAtyrzztwRF4bPhD8dX8sDiw/cBt2JyVyZG9mWOQIP/3gJ4MUzakuCSPwL17NbBowDaBnz54BV9M8lO6vZvGm3awoLGV5YRkri8rYta/64PM5HVvSt0sbTuubRd8ubejZqTU5HVvSrX0maanxFdz1YWakGqSmpNIi8N98kabR2L/qRUCPQx7nRJYd5O7TgekAeXl5sfs8LQeV7q9mwcZdzN+wk/kbdrF22x7c647M+3Vpw+mDujAspwPDurenf9c2mstEJIk09v/mRUA/M+tNXbhPBb7RyG02e+7O2m17eWPNdt5YW8yyglLcoUVaCnm5HfnB6f0Z2yeLYd3b0zJDY6tFklmjhry715rZjcA/qRtC+Zi7f9iYbTZX4bCzcNMuXl65lTfWFFNUWgHA8Jz23DSpH6f27czwnPa0SFOoizQnjf653N1fBl5u7HaaqzVb9/C3ZUX8fdknfFJWSWZ6Cqf17cyNX+7Llwd2oatObxdp1tT5moDK9tcwZ2khsxcVsG77XtJSjAn9s7nlKwM5Y3BX9amLyEFKgwSysrCMv8zfxIvLP6GyJsyInh248/whnD2sG1ltWgRdnojEIYV8nKsNhXlp5VYee3cjywvLaJWRykUjc7h0bC8Gn9gu6PJEJM4p5ONUVW2IZ5cU8fDbH7Nl135Oym7Nz84bwoUju9NOF34QkeOkkI8z+6pqmblwC396ZwPb91Rxck57bj9nFGcM6qpT4UWk3hTycaImFObpRQXc/3o+O8qr+MJJWdx38Sl84aSsg3O5iIjUl0I+YO7Oq6u2ce8/17Fhxz7G5Hbij5eNZFSvTkGXJiJJQCEfoCWbd3HXS2v4YEsp/bq04ZHL85g0qIuO3EUkZhTyAdi1r5q7X1nD7MWFdG3XgnumDGPKyJyEnvxLROKTQr4JhcPOM0sKuPuVteytrOXaCX347qR+tNaUiCLSSJQuTWTdtr3c/vxKFm/ezejcjtx1wTAGnNA26LJEJMkp5BtZKOz86Z0N3Df3I1q3SOWXU4bz1VE5Gg4pIk1CId+Ituzcz83PLGPRpt2cNaQr/3vhME0/ICJNSiHfCNydWYsKuPMfq0kx476LT+bCEd01akZEmpxCPsbK9tdw8zPLeX3Ndr5wUhb3fu1kundoGXRZItJMKeRjaFVRGf/51BK2lVXyk3MHc+UXctX3LiKBUsjHgLszc2EBP/37h3RuncHsa8czomfHoMsSEVHIR6uiOsTtf1vJc0uL+GK/ztw/dQSdWmcEXZaICKCQj0pRaQVXz1jEuu17uWlSP747qR+p6p4RkTiikG+gFYWlXP34YiqrQzz2rdF8aUCXoEsSEfkchXwDvLpqK9+btYys1i146vqx9O+qM1dFJD5FNSOWmd1rZmvNbIWZPW9mHQ557lYzW29m68zsrKgrjQPuzsNvf8x1Ty5l
ULd2/O2GUxXwIhLXop328DVgqLsPBz4CbgUws8HAVGAIMBl4yMxSo2wrULWhMLc9v5K7X1nLOcO7MfOacWS31dmrIhLfogp5d5/r7rWRh/OBnMj984Gn3b3K3TcC64Ex0bQVpKraEN+Z+QEzFxZww5dO4ndTR5CZntB/s0SkmYhln/xVwKzI/e7Uhf4BhZFln2Nm04BpAD179oxhObFRUR3i2ieXMO+jEn5y7mCuPq130CWJiBy3Y4a8mb0OnHCYp2539xci69wO1AJP1bcAd58OTAfIy8vz+r6+Me2prOHqGYtYsnk3v5wynItH9wi6JBGRejlmyLv76Ud73sy+BZwLTHL3AyFdBByaiDmRZQljZ3kVV/x5Ieu27eV3l4zknOHdgi5JRKTeoh1dMxn4EXCeu+8/5KkXgalm1sLMegP9gIXRtNWUSvZW8fXp88nfXs70y/MU8CKSsKLtk/890AJ4LTKN7nx3v87dPzSz2cBq6rpxbnD3UJRtNYld+6q59JEFFO2u4PGrxjCuT1bQJYmINFhUIe/ufY/y3M+Bn0ez/aZWVlHDZY8uYNPOffz5W6MV8CKS8KIdJ5809lbWcMVjC8nfXs4fLxvFF/p2DrokEZGoKeSB/dW1XDVjEauKyvj9N0YwUfPQiEiSaPYhX1kT4tuPL2bJ5t3cP3UEZw453GhREZHE1KwnKAuHnZufWc77H+/kvotP1igaEUk6zfpI/n9fXsNLK7Zy29kDuWhkzrFfICKSYJptyD/67kYeeXcj3/pCLtd8sU/Q5YiINIpmGfIvrdjKXS+tZvKQE/jJuYOJjPEXEUk6zS7kF27cxfdnL2NUz478duopulyfiCS1ZhXyG0rK+fbji8jp2JI/XZ6n6YJFJOk1m5DfU1nDNU8sJi01hcevHEPH1hlBlyQi0uiaRciHws73nl7G5p37eeibI+nRqVXQJYmINIlmEfK/nruOf60t5o7zhmg+GhFpVpI+5P++/BMeeutjLhnTk0vHxt+Vp0REGlNSh/yqojJ+OGc5eb068rPzhmiopIg0O0kb8jvLq7j2L0vo2CqDP1w6ioy0pN1VEZEjSsq5a8Jh5/uzl1NSXsWc68aT3bZF0CWJiAQiKQ9vH573MfM+KuEn5w5meE6HoMsREQlM0oX8ok27+PXcjzhneDd90SoizV5ShfyufdV8568fkNOxJXdfNExftIpIs5c0ffLhsPOD2cvYta+a567/Am0z04MuSUQkcElzJP/HeRt4a10JPzl3EEO7tw+6HBGRuBCTkDezm83Mzaxz5LGZ2QNmtt7MVpjZyFi0cySLN+3iV3PXcc6wblw6rldjNiUiklCiDnkz6wGcCWw5ZPFXgH6Rn2nAH6Jt52gy01M5tW9nfjFF/fAiIoeKxZH8b4AfAX7IsvOBJ7zOfKCDmTXaBVSHdm/PE1eNoZ364UVEPiWqkDez84Eid1/+mae6AwWHPC6MLDvcNqaZ2WIzW1xSUhJNOSIi8hnHHF1jZq8DJxzmqduB26jrqmkwd58OTAfIy8vzY6wuIiL1cMyQd/fTD7fczIYBvYHlkX7wHGCpmY0BioAeh6yeE1kmIiJNqMHdNe6+0t27uHuuu+dS1yUz0t23AS8Cl0dG2YwDytx9a2xKFhGR49VYJ0O9DJwNrAf2A1c2UjsiInIUMQv5yNH8gfsO3BCrbYuISMMkzRmvIiLyeQp5EZEkZnU9K/HBzEqAzQ18eWdgRwzLCZL2JT4ly74ky36A9uWAXu6efbgn4irko2Fmi909L+g6YkH7Ep+SZV+SZT9A+3I81F0jIpLEFPIiIkksmUJ+etAFxJD2JT4ly74ky36A9uWYkqZPXkREPi+ZjuRFROQzkirkzezOyJWolpnZXDM7MeiaGsrM7jWztZH9ed7MOgRdU0OZ2dfM7EMzC5tZwo2EMLPJZrYucqWzHwddT0OZ2WNmVmxmq4KuJVpm1sPM3jSz1ZHfrZuCrqkhzCzTzBaa2fLIfvws
5m0kU3eNmbVz9z2R+98FBrv7dQGX1SBmdibwL3evNbN7ANz9loDLahAzGwSEgT8C/+XuiwMu6biZWSrwEXAGdZPwLQIucffVgRbWAGY2ASin7oI+Q4OuJxqRixB1c/elZtYWWAJckGj/LlY3hW9rdy83s3TgXeCmyMWWYiKpjuQPBHxEaz59taqE4u5z3b028nA+ddM1JyR3X+Pu64Kuo4HGAOvdfYO7VwNPU3fls4Tj7vOAXUHXEQvuvtXdl0bu7wXWcIQLE8WzyNXzyiMP0yM/Mc2tpAp5ADP7uZkVAN8E/ifoemLkKuCVoItopo77KmcSDDPLBUYACwIupUHMLNXMlgHFwGvuHtP9SLiQN7PXzWzVYX7OB3D32929B/AUcGOw1R7dsfYlss7tQC11+xO3jmdfRGLNzNoAzwLf+8wn+YTh7iF3P4W6T+tjzCymXWmNNZ98oznSlaoO4ynq5rW/oxHLicqx9sXMvgWcC0zyOP/ypB7/LolGVzmLU5E+7GeBp9z9uaDriZa7l5rZm8BkIGZfjifckfzRmFm/Qx6eD6wNqpZomdlk4EfAee6+P+h6mrFFQD8z621mGcBU6q58JgGKfGH5KLDG3e8Lup6GMrPsAyPnzKwldV/wxzS3km10zbPAAOpGcmwGrnP3hDzqMrP1QAtgZ2TR/AQeKXQh8DsgGygFlrn7WYEWVQ9mdjbwWyAVeMzdfx5sRQ1jZjOBidTNdrgduMPdHw20qAYys9OAd4CV1P1/B7jN3V8Orqr6M7PhwOPU/W6lALPd/f/FtI1kCnkREfm0pOquERGRT1PIi4gkMYW8iEgSU8iLiCQxhbyISBJTyIuIJDGFvIhIElPIi4gksf8P49VH+I9HxDQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the curve of this function\n", "x_sample = np.arange(-3, 3.1, 0.1)\n", "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n", "\n", "plt.plot(x_sample, y_sample, label='real curve')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we build the dataset, which needs x and y. Because the target is a cubic polynomial, the input features are $x,\\ x^2,\\ x^3$." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# Build the data x and y\n", "# x is a matrix with columns [x, x^2, x^3]\n", "# y is the function value [y]\n", "\n", "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n", "x_train = torch.from_numpy(x_train).float() # convert to a float tensor\n", "\n", "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # convert to a float tensor of shape (n, 1)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([61, 3])\n" ] } ], "source": [ "print(x_train.size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we define the parameters to optimize, namely the $w_i$ in the function above." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# Define the parameters and the model\n", "w = Variable(torch.randn(3, 1), requires_grad=True)\n", "b = Variable(torch.zeros(1), requires_grad=True)\n", "\n", "# Wrap x and y as Variables (Variable is deprecated since PyTorch 0.4)\n", "x_train = Variable(x_train)\n", "y_train = Variable(y_train)\n", "\n", "def multi_linear(x):\n", "    return torch.mm(x, w) + b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can plot the model before any update against the true curve." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the model before any update\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two curves clearly differ, so let us compute the error between them." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(509.5237, grad_fn=)\n" ] } ], "source": [ "# Compute the error. It is the same mean squared error as for the univariate model, so we reuse get_loss defined earlier\n", "loss = get_loss(y_pred, y_train)\n", "print(loss)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "# Backpropagate to compute the gradients automatically\n", "loss.backward()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ -64.6688],\n", "        [ -84.8521],\n", "        [-431.2343]])\n", "tensor([-16.0116])\n" ] } ], "source": [ "# Inspect the gradients of w and b\n", "print(w.grad)\n", "print(b.grad)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# Update the parameters once, with learning rate 0.001\n", "w.data = w.data - 0.001 * w.grad.data\n", "b.data = b.data - 0.001 * b.grad.data" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the model after one update\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because we have updated only once, the two curves still differ. Let us run 100 iterations." ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 20, Loss: 1.55844\n", "epoch 40, Loss: 0.53303\n", "epoch 60, Loss: 0.28755\n", "epoch 80, Loss: 0.22432\n", "epoch 100, Loss: 0.20383\n" ] } ], "source": [ "# Perform 100 parameter updates\n", "for e in range(100):\n", "    y_pred = multi_linear(x_train)\n", "    loss = get_loss(y_pred, y_train)\n", "\n", "    # Zero the gradients, since backward() accumulates into them\n", "    w.grad.data.zero_()\n", "    b.grad.data.zero_()\n", "    loss.backward()\n", "\n", "    # Update the parameters\n", "    w.data = w.data - 0.001 * w.grad.data\n", "    b.data = b.data - 0.001 * b.grad.data\n", "    if (e + 1) % 20 == 0:\n", "        print('epoch {}, Loss: {:.5f}'.format(e+1, loss.item()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After these updates the loss has become small (about 0.20). Let us plot the updated curve against the true one." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the result after the updates\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After 100 updates, the fitted curve nearly coincides with the true curve." ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "**Exercise: the example above is a cubic polynomial. Try fitting it with a quadratic polynomial and see how well you can do.**\n", "\n", "**Hint: use the parameter `w = torch.randn(2, 1)` and rebuild the x dataset accordingly.**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }