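{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Models and Gradient Descent\n", "This is the first lesson on neural networks. We will study a very simple model, linear regression, together with an optimization algorithm, gradient descent, and use it to fit that model. Linear regression is one of the simplest models in supervised learning, and gradient descent is the most widely used optimization algorithm in deep learning, so this is where our deep learning journey begins." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Univariate Linear Regression\n", "The univariate linear model is very simple. Suppose we have inputs $x_i$ and targets $y_i$, where each $i$ indexes one data point, and we want to build the model\n", "\n", "$$\n", "\\hat{y}_i = w x_i + b\n", "$$\n", "\n", "$\\hat{y}_i$ is the model's prediction, and we want it to fit the target $y_i$. Put simply, we look for the $w$ and $b$ that make the error as small as possible, that is, that minimize the mean squared error\n", "\n", "$$\n", "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So how do we minimize this error?\n", "\n", "This is where **gradient descent** comes in. It is the first optimization algorithm we meet: very simple, yet very powerful and used heavily in deep learning. Let us start from a simple example to understand how it works." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradient Descent\n", "To understand gradient descent, we first need to be clear about what a gradient is, and then see how to use the gradient to descend." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Gradient\n", "For a function of one variable, the gradient is simply the derivative. For a function of several variables, it is the vector of partial derivatives. For example, for a function $f(x, y)$ the gradient is\n", "\n", "$$\n", "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n", "$$\n", "\n", "It is written grad f(x, y) or $\\nabla f(x, y)$. The gradient at a particular point $(x_0,\\ y_0)$ is $\\nabla f(x_0,\\ y_0)$.\n", "\n", "The figure below shows the gradient of $f(x) = x^2$ at $x = 1$.\n", "\n", "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarbuh2j3j30ba0b80sy.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What does the gradient mean? Geometrically, the gradient at a point gives the direction in which the function changes fastest. More precisely, for a function $f(x, y)$ at the point $(x_0, y_0)$, the function increases fastest along the direction of $\\nabla f(x_0,\\ y_0)$. So moving along the gradient is the quickest way uphill, toward a maximum, and moving against the gradient is the quickest way downhill, toward a minimum." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Gradient Descent Method\n", "With this understanding of the gradient, we can see how gradient descent works. Above we wanted to minimize the error, that is, to find the point where the error is smallest, and moving against the gradient leads us toward that minimum.\n", "\n", "Here is an intuitive picture. Suppose we are standing somewhere on a large mountain and do not know the way down, so we decide to take it one step at a time. At each position we compute the gradient there and take one step in the negative gradient direction, which is the steepest way down from where we stand. Then we compute the gradient at the new position and again step in the steepest downhill direction. We keep walking like this until we believe we have reached the foot of the mountain. Of course, we may not actually reach the foot; we may instead end up at the bottom of some local valley (a local minimum).\n", "\n", 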
"类比我们的问题,就是沿着梯度的反方向,我们不断改变 w 和 b 的值,最终找到一组最好的 w 和 b 使得误差最小。\n", "\n", "在更新的时候,我们需要决定每次更新的幅度,比如在下山的例子中,我们需要每次往下走的那一步的长度,这个长度称为学习率,用 $\\eta$ 表示,这个学习率非常重要,不同的学习率都会导致不同的结果,学习率太小会导致下降非常缓慢,学习率太大又会导致跳动非常明显,可以看看下面的例子\n", "\n", "![](https://ws2.sinaimg.cn/large/006tNc79ly1fmgn23lnzjg30980gogso.gif)\n", "\n", "可以看到上面的学习率较为合适,而下面的学习率太大,就会导致不断跳动\n", "\n", "最后我们的更新公式就是\n", "\n", "$$\n", "w := w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n", "b := b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n", "$$\n", "\n", "通过不断地迭代更新,最终我们能够找到一组最优的 w 和 b,这就是梯度下降法的原理。\n", "\n", "最后可以通过这张图形象地说明一下这个方法\n", "\n", "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarxsltfqj30gx091gn4.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "上面是原理部分,下面通过一个例子来进一步学习线性模型" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import numpy as np\n", "from torch.autograd import Variable\n", "\n", "torch.manual_seed(2017)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# 读入数据 x 和 y\n", "x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],\n", " [9.779], [6.182], [7.59], [2.167], [7.042],\n", " [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)\n", "\n", "y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],\n", " [3.366], [2.596], [2.53], [1.221], [2.827],\n", " [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPrElEQVR4nO3df4gc933G8ec5SdS+OMRtdSSqrLstNKQkprbSxbVrKMauwU2NXagLLlvXKSkHIW3sYih1BC4JXEmhuD9iiFnsNEq7uAmySV0TtxWJITE0CitV/iUZYqjubFepznYt293UraJP/5gVkq67t7On2ZvZ77xfsMzMd0e7H4a7R9+b/cysI0IAgOk3U3YBAIBiEOgAkAgCHQASQaADQCIIdABIxNay3nj79u3RaDTKensAmEoHDx58LSLmBj1XWqA3Gg11u92y3h4AppLt5WHPccoFABJBoANAIkYGuu2LbH/P9jO2X7D92QH7fNz2qu3D/cfvTqZcAMAwec6hvyvp+oh4x/Y2SU/bfjIivrtmv69GxO8VXyIAII+RgR7ZzV7e6W9u6z+4AQwAVEyuc+i2t9g+LOmEpP0RcWDAbr9u+1nb+2zvGvI6i7a7trurq6sXUDYATJ9OR2o0pJmZbNnpFPv6uQI9In4UEVdKukzSVbYvX7PLP0hqRMTPSdovae+Q12lHRDMimnNzA9soASBJnY60uCgtL0sR2XJxsdhQH6vLJSLelPSUpJvWjL8eEe/2Nx+S9PPFlAcAadizR+r1zh/r9bLxouTpcpmzfWl//WJJN0p6cc0+O87ZvEXS0eJKBIDpt7Iy3vhG5Oly2SFpr+0tyv4D+FpEPGH7c5K6EfG4pE/bvkXSKUlvSPp4cSUCwPSbn89OswwaL0qeLpdnJe0eMH7fOev3Srq3uLIAIC1LS9k583NPu8zOZuNF4UpRANgErZbUbksLC5KdLdvtbLwopd2cCwDqptUqNsDXYoYOAIkg0AEka9IX8lQNp1wAJOnMhTxnPoQ8cyGPNNnTHmVihg4gSZtxIU/VEOgAkrQZF/JUDYEOIEnDLtgp8kKeqiHQASRpaSm7cOdcRV/IUzUEOoAkbcaFPFVDlwuAZE36Qp6qYYYOAIkg0AEgEQQ6ACSCQAeARBDoAJAIAh0AEkGgA0AiCHRgDHW7HSumCxcWATnV8XasmC7M0IGc6ng7VkwXAh3IqY63Y8V0IdCBnOp4O1ZMFwIdyKmOt2PFdCHQgZzqeDtWTBe6XIAx1O12rJguzNABIBEEOgAkgkAHgEQQ6ACQCAIdABJBoANAIgh0AEgEgQ4AiSDQASARIwPd9kW2v2f7Gdsv2P7sgH1+zPZXbb9k+4DtxiSKBQAMl2eG/q6k6yPiCklXSrrJ9tVr9vmEpP+MiJ+R9OeS/rTYMgEAo4wM9Mi809/c1n/Emt1ulbS3v75P0g22XViVAICRcp1Dt73F9mFJJyTtj4gDa3bZKellSYqIU5JOSvrJAa+zaLtru7u6unphlQMAzpMr0CPiRxFxpaTLJF1l+/KNvFlEtCOiGRHNubm5jbwEAGCIsbpcIuJNSU9JumnNU69K2iVJtrdKep+k14soEACQT54ulznbl/bXL5Z0o6QX1+z2uKQ7++u3SfpWRKw9zw4AmKA8X3CxQ9Je21uU/QfwtYh4wvbnJHUj4nFJD0v6G9svSXpD0u0TqxgAMNDIQI+IZyXtHjB+3znr/y3pN4otDQAwDq4UBRLX6UiNhjQzky07nbIrwqTwnaJAwjodaXFR6vWy7eXlbFviu1FTxAwdSNiePWfD/IxeLxtHegh0IGErK+ONY7oR6EDC5ufHG8d0I9CBhC0tSbOz54/NzmbjSA+BDkxIFbpLWi2p3ZYWFiQ7W7bbfCCaKrpcgAmoUndJq0WA1wUzdGAC6C5BGQh0YALoLkEZCHRgAuguQRkIdGAC6C5BGQj0mqhCx0Wd0F2CMtDlUgNV6ri
oE7pLsNmYodcAHRdAPRDoNUDHBVAPBHoN0HEB1AOBXgN0XAD1QKDXAB0XQD3Q5VITdFwA6WOGDgCJINABIBEEOgAkgkAHgEQQ6ACQCAIdABJBoANAIgh0JI9bB6MuuLAISePWwagTZuhIGrcORp0Q6Egatw5GnRDoSBq3DkadEOhIGrcORp0Q6EhaSrcOplsHo9DlguSlcOtgunWQx8gZuu1dtp+yfcT2C7bvGrDPdbZP2j7cf9w3mXKBeqJbB3nkmaGfknRPRByy/V5JB23vj4gja/b7TkTcXHyJAOjWQR4jZ+gRcTwiDvXX35Z0VNLOSRcG4Cy6dZDHWB+K2m5I2i3pwICnr7H9jO0nbX9kyL9ftN213V1dXR27WKCu6NZBHrkD3fYlkh6VdHdEvLXm6UOSFiLiCklfkPT1Qa8REe2IaEZEc25ubqM1A7WTUrcOJscRMXone5ukJyT9U0Tcn2P/Y5KaEfHasH2azWZ0u90xSgUA2D4YEc1Bz+XpcrGkhyUdHRbmtj/Q30+2r+q/7usbLxkAMK48XS7XSrpD0nO2D/fHPiNpXpIi4kFJt0n6pO1Tkn4o6fbIM/UHABRmZKBHxNOSPGKfByQ9UFRRAIDxcek/ACSCQAeARBDoAJAIAh0AEkGgA0AiCHQASASBDgCJINABIBEEOgAkgkAHgEQQ6ACQCAIdABJBoANAIgh0AEgEgQ4AiSDQASARBDoAJIJAB4BEEOgAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSCQAeARBDoAJAIAh2l63SkRkOamcmWnU7ZFQHTaWvZBaDeOh1pcVHq9bLt5eVsW5JarfLqAqYRM3SUas+es2F+Rq+XjQMYD4GOUq2sjDcOYDgCHaWanx9vHMBwBDpKtbQkzc6ePzY7m40DGA+BjlK1WlK7LS0sSHa2bLf5QBTYCLpcULpWiwAHijByhm57l+2nbB+x/YLtuwbsY9t/Zfsl28/a/uhkygUADJNnhn5K0j0Rccj2eyUdtL0/Io6cs8+vSPpg//ELkr7YXwIANsnIGXpEHI+IQ/31tyUdlbRzzW63SvpKZL4r6VLbOwqvFgAw1FgfitpuSNot6cCap3ZKevmc7Vf0/0Nfthdtd213V1dXx6sUALCu3IFu+xJJj0q6OyLe2sibRUQ7IpoR0Zybm9vISwAAhsgV6La3KQvzTkQ8NmCXVyXtOmf7sv4YAGCT5OlysaSHJR2NiPuH7Pa4pN/ud7tcLelkRBwvsE4AwAh5ulyulXSHpOdsH+6PfUbSvCRFxIOSviHpY5JektST9DvFlwoAWM/IQI+IpyV5xD4h6VNFFQUAGB+X/gNAIgh0AEgEgQ4AiSDQASARBDoAJIJAB4BEEOgAkAgCHQASQaADQCIIdABIBIEOAIkg0AEgEQQ6ACSCQAeARBDoAJAIAh0AEkGgA0AiCHQASASBXqBOR2o0pJmZbNnplF0RNhs/AyhTni+JRg6djrS4KPV62fbycrYtSa1WeXVh8/AzgLI5+37nzddsNqPb7Zby3pPQaGS/wGstLEjHjm12NSgDPwPYDLYPRkRz0HOccinIysp440gPPwMoG4FekPn58cbrpg7nlvkZQNkI9IIsLUmzs+ePzc5m43V35tzy8rIUcfbccmqhzs8AykagF6TVktrt7HypnS3bbT4Mk6Q9e85+UHhGr5eNp4SfAZSND0UxcTMz2cx8LVs6fXrz6wGmGR+KolScWwY2B4GOiePcMrA5CHRMHOeWgc1BoCei6m2BrVZ2cc3p09mSMAeKx6X/CeCScwASM/Qk1KUtEMD6CPQEcMk5AIlATwJtgQAkAj0JtAUCkHIEuu0v2T5h+/khz19n+6Ttw/3HfcWXifXQFghAytfl8mVJD0j6yjr7fCcibi6kImxIq0WAA3U3coYeEd+W9MYm1AIAuABFnUO/xvYztp+0/ZFhO9letN213V1dXS3orQEAUjGBfkjSQkRcIekLkr4+bMe
IaEdEMyKac3NzBbw1AOCMCw70iHgrIt7pr39D0jbb2y+4MgDAWC440G1/wLb761f1X/P1C31dAMB4Rna52H5E0nWSttt+RdIfS9omSRHxoKTbJH3S9ilJP5R0e5T1rRkAUGMjAz0ifnPE8w8oa2sEAJSIK0UBIBEEOgAkgkAHgEQQ6ACQCAIdABJBoANAIgh0AEgEgQ4AiSDQASARBPqYOh2p0ZBmZrJlp1N2RQCQyfONRejrdKTFRanXy7aXl7NtiW8LAlA+Zuhj2LPnbJif0etl4wBQNgJ9DCsr440DwGYi0McwPz/eOABsJgJ9DEtL0uzs+WOzs9k4AJSNQB9DqyW129LCgmRny3abD0QBVMNUBXoVWgZbLenYMen06WxJmAOoiqlpW6RlEADWNzUzdFoGAWB9UxPotAwCwPqmJtBpGQSA9U1NoNMyCADrm5pAp2UQANY3NV0uUhbeBDgADDY1M3QAwPoIdABIBIEOAIkg0AEgEQQ6ACTCEVHOG9urkpZz7Lpd0msTLmcacVyG49gMxnEZbpqOzUJEzA16orRAz8t2NyKaZddRNRyX4Tg2g3Fchkvl2HDKBQASQaADQCKmIdDbZRdQURyX4Tg2g3Fchkvi2FT+HDoAIJ9pmKEDAHIg0AEgEZUMdNu7bD9l+4jtF2zfVXZNVWJ7i+1/tf1E2bVUie1Lbe+z/aLto7avKbumqrD9B/3fpedtP2L7orJrKovtL9k+Yfv5c8Z+wvZ+29/vL3+8zBo3qpKBLumUpHsi4sOSrpb0KdsfLrmmKrlL0tGyi6igv5T0jxHxs5KuEMdIkmR7p6RPS2pGxOWStki6vdyqSvVlSTetGfsjSd+MiA9K+mZ/e+pUMtAj4nhEHOqvv63sF3NnuVVVg+3LJP2qpIfKrqVKbL9P0i9JeliSIuJ/IuLNcquqlK2SLra9VdKspH8vuZ7SRMS3Jb2xZvhWSXv763sl/dqmFlWQSgb6uWw3JO2WdKDcSirjLyT9oaTTZRdSMT8taVXSX/dPRz1k+z1lF1UFEfGqpD+TtCLpuKSTEfHP5VZVOe+PiOP99R9Ien+ZxWxUpQPd9iWSHpV0d0S8VXY9ZbN9s6QTEXGw7FoqaKukj0r6YkTslvRfmtI/m4vWPx98q7L/9H5K0nts/1a5VVVXZL3cU9nPXdlAt71NWZh3IuKxsuupiGsl3WL7mKS/k3S97b8tt6TKeEXSKxFx5i+5fcoCHtIvS/q3iFiNiP+V9JikXyy5pqr5D9s7JKm/PFFyPRtSyUC3bWXnQo9GxP1l11MVEXFvRFwWEQ1lH2p9KyKYaUmKiB9Ietn2h/pDN0g6UmJJVbIi6Wrbs/3frRvEB8ZrPS7pzv76nZL+vsRaNqySga5sJnqHshno4f7jY2UXhcr7fUkd289KulLSn5RcTyX0/2rZJ+mQpOeU/d4ncan7Rth+RNK/SPqQ7Vdsf0LS5yXdaPv7yv6i+XyZNW4Ul/4DQCKqOkMHAIyJQAeARBDoAJAIAh0AEkGgA0AiCHQASASBDgCJ+D/WmKZIW+19fgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出图像\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.plot(x_train, y_train, 'bo')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([2.2691], requires_grad=True)\n" ] } ], "source": [ "# 转换成 Tensor\n", "x_train = torch.from_numpy(x_train)\n", "y_train = torch.from_numpy(y_train)\n", "\n", "# 定义参数 w 和 b\n", "w = Variable(torch.randn(1), requires_grad=True) # 随机初始化\n", "b = Variable(torch.zeros(1), requires_grad=True) # 使用 0 进行初始化\n", "print(w)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# 构建线性回归模型\n", "x_train = Variable(x_train)\n", "y_train = Variable(y_train)\n", "\n", "def linear_model(x):\n", " return x * w + b" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "y_ = linear_model(x_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "经过上面的步骤我们就定义好了模型,在进行参数更新之前,我们可以先看看模型的输出结果长什么样" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAWJUlEQVR4nO3df4zU9Z3H8dd7cSuuEM/ihqB0d2nTcCI/FlgNnlfKCQJXTYWYNGf2FJI22Fo827Re9PhDE91rc2nlzv6hbpVTy9arxR81PdJikYY2paW7HloKhk28XVxEWdGj8ivA7vv+mNkF1pmd2Zn5zvf7mXk+ksnMfGeYee9X5zWf+Xw/38/H3F0AgPDUxF0AAKAwBDgABIoAB4BAEeAAECgCHAACdUE53+yyyy7zpqamcr4lAASvq6vrfXevH7m9rAHe1NSkzs7Ocr4lAATPzHozbacLBQAClTPAzexTZrbNzPaY2Z/N7O709gfM7ICZ7UpfvhB9uQCAIfl0oZyR9C13f83MJkrqMrNX0o+td/fvRVceACCbnAHu7gclHUzf/sjM9kq6olQFnD59Wn19fTp58mSpXrLqjR8/XlOnTlVtbW3cpQCI0JgOYppZk6S5kv4g6TpJa83sdkmdSrXSP8zwb9ZIWiNJDQ0NH3vNvr4+TZw4UU1NTTKzsdaPEdxdhw8fVl9fn6ZNmxZ3OQAilPdBTDObIOl5Sd9w979IelTSZyQ1K9VC/36mf+fu7e7e4u4t9fUfGwWjkydPatKkSYR3iZiZJk2axC8aICk6OqSmJqmmJnXd0VGyl86rBW5mtUqFd4e7vyBJ7v7eOY//UNLPCy2C8C4t9ieQEB0d0po10vHjqfu9van7ktTaWvTL5zMKxSQ9KWmvuz98zvYp5zxtpaTdRVcDAJVk3bqz4T3k+PHU9hLIpwvlOkm3Sbp+xJDBfzOzP5nZG5L+TtI3S1JRgJqamvT+++/HXQaApNm/f2zbxyhngLv7b93d3H22uzenL5vd/TZ3n5Xe/sX0aJXIRdidJCl1EHBwcLC0LwqgOmUYuDHq9jEK6kzMoe6k3l7J/Wx3UrEh3tPTo+nTp+v222/XzJkz9eCDD+rqq6/W7Nmzdf/99w8/b8WKFZo/f76uuuoqtbe3F/nXAKh4bW1SXd352+rqUttLIKgAj7I7qbu7W3feeafWr1+vAwcOaOfOndq1a5e6urq0fft2SdKGDRvU1dWlzs5OPfLIIzp8+HDxbwygcrW2Su3tUmOjZJa6bm8vyQFMqcyTWRUryu6kxsZGLViwQN/+9re1ZcsWzZ07V5J09OhRdXd3a+HChXrkkUf04osvSpLefvttdXd3a9KkScW/OYDK1dpassAeKagAb2hIdZtk2l6siy++WFKqD/y+++7THXfccd7jv/71r/WrX/1KO3bsUF1dnRYtWsRYawCxCqoLJeLuJEnSsmXLtGHDBh09elSSdODAAR06dEhHjhzRpZdeqrq6Or355pv6/e9/X7o3BYACBNUCH/oVsm5dqtukoSEV3qX8dbJ06VLt3btX1157rSRpwoQJ2rhxo5YvX67HHntMV155paZPn64FCxaU7k0BoADm7mV7s5aWFh+5oMPevXt15ZVXlq2GasF+BSqHmXW5e8vI7UF1oQAAziLAASBQBDgABIoAB4BAEeAAECgCHAACRYCPwVNPPaV33nln+P5XvvIV7dmzp+jX7enp0Y9//OMx/7vVq1dr06ZNRb8/gDCFF+BRzyc7ipEB/sQTT2jGjBlFv26hAQ6guoUV4BHNJ7tx40Zdc801am5u1h133KGBgQGtXr1aM2fO1KxZs7R+/Xpt2rRJnZ2dam1tVXNzs06cOKFFixZp6MSkCRMm6J577tFVV12lJUuWaOfOnVq0aJE+/elP6+WXX5aUCurPfe5zmjdvnubNm6ff/e53kqR7771Xv/nNb9Tc3Kz169drYGBA99xzz/CUto8//ri
k1Dwta9eu1fTp07VkyRIdOnSoqL8bqEgxNvLKzt3Ldpk/f76PtGfPno9ty6qx0T0V3edfGhvzf40M73/TTTf5qVOn3N39a1/7mj/wwAO+ZMmS4ed8+OGH7u7++c9/3v/4xz8Obz/3viTfvHmzu7uvWLHCb7jhBj916pTv2rXL58yZ4+7ux44d8xMnTri7+759+3xof2zbts1vvPHG4dd9/PHH/cEHH3R395MnT/r8+fP9rbfe8ueff96XLFniZ86c8QMHDvgll1ziP/3pT7P+XUDV2bjRva7u/Hyoq0ttD5ikTs+QqUHNhRLFfLJbt25VV1eXrr76aknSiRMntHz5cr311lu66667dOONN2rp0qU5X+cTn/iEli9fLkmaNWuWLrzwQtXW1mrWrFnq6emRJJ0+fVpr167Vrl27NG7cOO3bty/ja23ZskVvvPHGcP/2kSNH1N3dre3bt+vWW2/VuHHjdPnll+v6668v+O8GKtJoiwZENKVrnMIK8Ajmk3V3rVq1St/5znfO297W1qZf/vKXeuyxx/Tcc89pw4YNo75ObW3t8GrwNTU1uvDCC4dvnzlzRpK0fv16TZ48Wa+//roGBwc1fvz4rDX94Ac/0LJly87bvnnz5oL+RqBqRLwGZdKE1QcewXyyixcv1qZNm4b7kz/44AP19vZqcHBQt9xyix566CG99tprkqSJEyfqo48+Kvi9jhw5oilTpqimpkY/+tGPNDAwkPF1ly1bpkcffVSnT5+WJO3bt0/Hjh3TwoUL9ZOf/EQDAwM6ePCgtm3bVnAtQEWKeA3KpAmrBR7BfLIzZszQQw89pKVLl2pwcFC1tbV6+OGHtXLlyuHFjYda56tXr9ZXv/pVXXTRRdqxY8eY3+vOO+/ULbfcomeeeUbLly8fXkRi9uzZGjdunObMmaPVq1fr7rvvVk9Pj+bNmyd3V319vV566SWtXLlSr776qmbMmKGGhobhKW8BpLW1pQY2nNuNUupFAxKE6WQrFPsVVaujI9pFA2KQbTrZsFrgAJBLhGtQJk1YfeAAgGGJCPByduNUA/YnUB1iD/Dx48fr8OHDhE6JuLsOHz6cdYgigMoRex/41KlT1dfXp/7+/rhLqRjjx4/X1KlT4y4DQMRiD/Da2lpNmzYt7jIAIDixd6EAAApDgANAoAhwAAgUAQ4AgSLAASBQOQPczD5lZtvMbI+Z/dnM7k5v/6SZvWJm3enrS6MvFwAwJJ8W+BlJ33L3GZIWSPq6mc2QdK+kre7+WUlb0/cBAGWSM8Dd/aC7v5a+/ZGkvZKukHSzpKfTT3ta0oqoigQAfNyY+sDNrEnSXEl/kDTZ3Q+mH3pX0uQs/2aNmXWaWSdnWwJA6eQd4GY2QdLzkr7h7n8597H0opsZJzNx93Z3b3H3lvr6+qKKBQCclVeAm1mtUuHd4e4vpDe/Z2ZT0o9PkXQomhIBAJnkMwrFJD0paa+7P3zOQy9LWpW+vUrSz0pfHgAgm3wms7pO0m2S/mRmu9Lb/kXSdyU9Z2ZfltQr6UvRlAgAyCRngLv7byVZlocXl7YcAEC+OBMTAAJFgANAoAhwAAgUAQ4AgSLAASBQBDgABIoAB4BAEeBAJerokJqapJqa1HVHR9wVIQL5nIkJICQdHdKaNdLx46n7vb2p+5LU2hpfXSg5WuBApVm37mx4Dzl+PLUdFYUAByrN/v1j245gEeBApWloGNt2BIsABypNW5tUV3f+trq61HZUFAIcqDStrVJ7u9TYKJmlrtvbOYBZgRiFAlSi1lYCuwrQAgeSjPHcGAUtcCCpGM+NHGiBA0nFeG7kQIADScV4buRAgANJxXhu5ECAA0nFeG7kQIADScV4buTAKBQgyRjPjVHQAgeAQBHgABAoAhwAAkWAA6XEqe8oIw5iAqXCqe8oM1rgQDHObXGvWsWp7ygrWuBAoUa2uAcGMj+PU98REVrgQKEyTTaVCae+IyIEOFCofFrWnPqOCBHgQC7ZRpZka1m
PG8ep7ygL+sCB0Yw2sqSt7fzHpFSLm9BGmeRsgZvZBjM7ZGa7z9n2gJkdMLNd6csXoi0TiMloiyow2RRiZu4++hPMFko6KukZd5+Z3vaApKPu/r2xvFlLS4t3dnYWWCoQg5oaKdNnxEwaHCx/PahKZtbl7i0jt+dsgbv7dkkfRFIVkHQsqoAEK+Yg5lozeyPdxXJptieZ2Roz6zSzzv7+/iLeDogBiyogwQoN8EclfUZSs6SDkr6f7Ynu3u7uLe7eUl9fX+DbATGhnxsJVtAoFHd/b+i2mf1Q0s9LVhGQNCyqgIQqqAVuZlPOubtS0u5szwUARCNnC9zMnpW0SNJlZtYn6X5Ji8ysWZJL6pF0R4Q1AgAyyGcUyq3uPsXda919qrs/6e63ufssd5/t7l9094PlKBbIG/NyowpwJiYqD/Nyo0owFwoqz2hnTwIVhABH5ck2SyDzcqPCEOCoPJw9iSpBgKPycPYkqgQBjsrD2ZOoEoxCQWXi7ElUAVrgABAoAhwAAkWAA0CgCHAACBQBDgCBIsARDyabAorGMEKUH5NNASVBCxzlx2RTQEkQ4Cg/JpsCSoIAR3Sy9XMz2RRQEvSBIxqj9XO3tZ3/mMRkU0ABCHBEY7R+7p6es8/Zvz/V8m5r4wAmMEbm7mV7s5aWFu/s7Czb+yFGNTVSpv+3zKTBwfLXAwTMzLrcvWXkdvrAEQ36uYHIEeCIBosqAJEjwBENFlUAIsdBTESHRRWASNECB4BAEeAAECgCHAACRYADQKAIcAAIFAEOAIEiwAEgUAQ4AASKAAeAQOUMcDPbYGaHzGz3Ods+aWavmFl3+vrSaMtEQVg4GKho+bTAn5K0fMS2eyVtdffPStqavo8kGVpQobc3Na3r0IIKhDhQMXIGuLtvl/TBiM03S3o6fftpSStKXBeKxcLBQMUrtA98srsfTN9+V9LkbE80szVm1mlmnf39/QW+HcaMhYOBilf0QUxPLemTdVkfd2939xZ3b6mvry/27ZAvFlQAKl6hAf6emU2RpPT1odKVhLyNdpCSBRWAildogL8saVX69ipJPytNOchbroOULKgAVLycixqb2bOSFkm6TNJ7ku6X9JKk5yQ1SOqV9CV3H3mg82NY1LiEmppSoT1SY+PZVd8BVIRsixrnXJHH3W/N8tDioqtC4ThICVQ9zsQMFQcpgapHgIeKg5RA1SPAQ8VBSqDqsSp9yFj1HahqtMABIFAEOAAEigAHgEAR4AAQKAIcAAJFgANAoAhwAAgUAQ4AgSLAASBQBHixWPkdQEw4lb4YQ4sqDC0ePLSogsQp7gAiRwu8GKz8DiBGBHgxWFQBQIwI8Hxk6+dmUQUAMaIPPJfR+rnb2s5/TGJRBQBlQ4DnMlo/99DiwevWpbpNGhpS4c0BTABlkHNV+lIKclX6mhop0z4ykwYHy18PgKqTbVV6+sBzoZ8bQEIR4LmweDCAhCLAc2HxYAAJxUHMfLB4MIAEogUOAIEiwAEgUAQ4AASKAAeAQBHgABAoAhwAAkWAA0CgCHAACBQBDgCBKupMTDPrkfSRpAFJZzLNlgUAiEYpWuB/5+7NkYU3q74DQEbJnguFVd8BIKtiW+AuaYuZdZnZmkxPMLM1ZtZpZp39/f1je3VWfQeArIoN8L9193mS/l7S181s4cgnuHu7u7e4e0t9ff3YXp1V3wEgq6IC3N0PpK8PSXpR0jWlKGoYq+EAQFYFB7iZXWxmE4duS1oqaXepCpPEajgAMIpiWuCTJf3WzF6XtFPSf7v7L0pTVhqr4QBAVqxKDwAJx6r0AFBhCHAACBQBDgCBIsABIFAEOABEKMrpnAhwABUlSfPfDU3n1NsruZ+dzqlUNRHgQA5JCgSMLurAHKuop3NiHDgwipETYkqpk4E5nyyZmppSoT1SY6PU01PualJf+pki1kwaHMz/dRgHDhSACTHDkrT576KezokAB0Yx1kBISndLUuoot6TNfxf1dE4EeAWr1g9xKY0
lEJLS/5qUOuKQtPnvIp/Oyd3Ldpk/f76jPDZudK+rc099hFOXurrU9mqwcaN7Y6O7Weq60L97LPuxsfH85w1dGhsL/zsKkZQ64lKq//ZJIqnTM2QqAV6hqvlDXOovr3wDwSzzPjcr9C8pLIyiqAPxyhbgjEKpUKU6+h2iuEYilPp9Cx0Bk7SRGCgeo1CqTNIO5pRTXCMRSt3/WugImKT1AyM6BHiFquYPcVxfXqU+YFXoFxHroFQPArxCJf1DHOUImTi/vFpbU90Ug4Op62L2dzFfRKWsA8lFgFewpH6Iox7mlvQvr3xV868o5IeDmCg7DrLlr6Mj1ee9f3+q5d3WFt4XEYrHQcyIcLJMdtn2TdJOdy5G1P/9k/orCslAgBch7jPekvzlMdq+qZQRMnH/9wc4kacIcZ4sk/QzLUfbN0mvPV/VfLIUyktZTuShBZ6HJHYFJH2WvNH2TaUcZKykriCEKfEBHnc3QVK7ApIeHrn2TSX07VZKVxDClegAT0If42gt3TiHeSU9PKphCFw1/I1IuEz9KlFdxtoHnoQ+xlwTA8U181kI/ciVOCvcSNXwNyJ+CnEyqyRMyJTkMcuMEQaqQ5DjwJPQTZDkn8mV0I8MoHCJDvAkhGeljJgAUHkuiLuA0QyFZNzdBK2tBDaA5El0gEuEJwBkk+guFABAdgQ4AASKAAeAQBHgABAoAhwAAlXWMzHNrF9ShvMaz3OZpPfLUE6I2DeZsV+yY99kF9K+aXT3+pEbyxrg+TCzzkynjIJ9kw37JTv2TXaVsG/oQgGAQBHgABCoJAZ4e9wFJBj7JjP2S3bsm+yC3zeJ6wMHAOQniS1wAEAeCHAACFRiAtzMPmVm28xsj5n92czujrumJDGzcWb2P2b287hrSRIz+ysz22Rmb5rZXjO7Nu6aksLMvpn+LO02s2fNbHzcNcXBzDaY2SEz233Otk+a2Stm1p2+vjTOGguVmACXdEbSt9x9hqQFkr5uZjNirilJ7pa0N+4iEug/JP3C3f9a0hyxjyRJZnaFpH+S1OLuMyWNk/QP8VYVm6ckLR+x7V5JW939s5K2pu8HJzEB7u4H3f219O2PlPogXhFvVclgZlMl3SjpibhrSRIzu0TSQklPSpK7n3L3/4u3qkS5QNJFZnaBpDpJ78RcTyzcfbukD0ZsvlnS0+nbT0taUdaiSiQxAX4uM2uSNFfSH+KtJDH+XdI/SyrTUs7BmCapX9J/pruXnjCzi+MuKgnc/YCk70naL+mgpCPuviXeqhJlsrsfTN9+V9LkOIspVOIC3MwmSHpe0jfc/S9x1xM3M7tJ0iF374q7lgS6QNI8SY+6+1xJxxToT+FSS/fp3qzUl9zlki42s3+Mt6pk8tRY6iDHUycqwM2sVqnw7nD3F+KuJyGuk/RFM+uR9F+SrjezjfGWlBh9kvrcfeiX2ialAh3SEkn/6+797n5a0guS/ibmmpLkPTObIknp60Mx11OQxAS4mZlSfZl73f3huOtJCne/z92nunuTUgehXnV3WlKS3P1dSW+b2fT0psWS9sRYUpLsl7TAzOrSn63F4gDvuV6WtCp9e5Wkn8VYS8ESE+BKtTRvU6qFuSt9+ULcRSHx7pLUYWZvSGqW9K8x15MI6V8lmyS9JulPSn3Wgz91vBBm9qykHZKmm1mfmX1Z0ncl3WBm3Ur9WvlunDUWilPpASBQSWqBAwDGgAAHgEAR4AAQKAIcAAJFgANAoAhwAAgUAQ4Agfp/1cKknXhPKTMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**思考:红色的点表示预测值,似乎排列成一条直线,请思考一下这些点是否在一条直线上?**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "这个时候需要计算我们的误差函数,也就是\n", "\n", "$$\n", "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", "$$" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# 计算误差\n", "def get_loss(y_, y):\n", " return torch.mean((y_ - y) ** 2)\n", "\n", "loss = get_loss(y_, y_train)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(153.3520, grad_fn=)\n" ] } ], "source": [ "# 打印一下看看 loss 的大小\n", "print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "定义好了误差函数,接下来我们需要计算 w 和 b 的梯度了,这时得益于 PyTorch 的自动求导,我们不需要手动去算梯度,有兴趣的同学可以手动计算一下,w 和 b 的梯度分别是\n", "\n", "$$\n", "\\frac{\\partial}{\\partial w} = \\frac{2}{n} \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n", "\\frac{\\partial}{\\partial b} = \\frac{2}{n} \\sum_{i=1}^n (w x_i + b - y_i)\n", "$$" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# 自动求导\n", "loss.backward()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([161.0043])\n", "tensor([22.8730])\n" ] } ], "source": [ "# 查看 w 和 b 的梯度\n", "print(w.grad)\n", "print(b.grad)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# 更新一次参数\n", "w.data = w.data - 1e-2 * w.grad.data\n", "b.data = b.data - 1e-2 * b.grad.data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "更新完成参数之后,我们再一次看看模型输出的结果" ] }, { "cell_type": "code", 
"execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAV/klEQVR4nO3df3BV5Z3H8c83IYoB1lrMOFqaXDqzw8oP+RUdXLeURUQqTovj/lEmbaW7TmwtLt22dnT4w+6otX/sSLU7o2YptS2x7YradTpuh6q4tFOU3tBgW1CYxQSDtITYUvm1/Mh3/7g3AeK93Jvknnuee8/7NZPJvedezv3mZPLhOc95nueYuwsAEK6auAsAAJwfQQ0AgSOoASBwBDUABI6gBoDAjYlip5deeqmnUqkodg0AVamjo+Oguzfkei2SoE6lUkqn01HsGgCqkpl153uNrg8ACFzBoDazKWbWedbXX8zsS+UoDgBQRNeHu78paZYkmVmtpH2Snou4LgBA1nD7qK+X9L/unrcvJZ+TJ0+qp6dHx48fH+4/RR5jx47VpEmTVFdXF3cpACI03KD+lKQf5nrBzFoltUpSY2Pj+17v6enRhAkTlEqlZGbDrRNDuLv6+vrU09OjyZMnx10OgAgVfTHRzC6Q9AlJT+d63d3b3L3Z3ZsbGt4/wuT48eOaOHEiIV0iZqaJEydyhgKEoL1dSqWkmprM9/b2ku5+OC3qj0va5u5/HOmHEdKlxfEEAtDeLrW2SkePZp53d2eeS1JLS0k+YjjD85YrT7cHACTW6tVnQnrA0aOZ7SVSVFCb2ThJN0h6tmSfXGFSqZQOHjwYdxkAQrN37/C2j0BRQe3uR9x9orsfKtknFxBll4+7q7+/v3Q7BJBcOQZPnHf7CAQ5M3Ggy6e7W3I/0+UzmrDu6urSlClT9NnPflbTp0/X/fffr6uvvlpXXXWV7rvvvsH3LVu2THPnztW0adPU1tZWgp8GQFV78EGpvv7cbfX1me0lEmRQR9Xls3v3bt15551as2aN9u3bp61bt6qzs1MdHR3avHmzJGndunXq6OhQOp3Wo48+qr6+vtF9KIDq1tIitbVJTU2SWeZ7W1vJLiRKES3KNFpRdfk0NTVp3rx5+upXv6qNGzdq9uzZkqTDhw9r9+7dmj9/vh599FE991xm4uXbb7+t3bt3a+LEiaP7YADVraWlpME8VJBB3diY6e7ItX00xo0bJynTR33vvffqjjvuOOf1V155RS+++KK2bNmi+vp6LViwgHHKAGIXZNdH1F0+N954o9atW6fDhw9Lkvbt26cDBw7o0KFDuuSSS1RfX6833nhDr776amk+EABGIcgW9cAZxOrVme6OxsZMSJfqzGLx4sXauXOnrr32WknS+PHjtX79ei1ZskSPP/64rrzySk2ZMkXz5s0rzQcCwCiYu5d8p83NzT70xgE7d+7UlVdeWfLPSjqOK1AdzKzD3ZtzvRZk1wcA4AyCGgACR1ADQOAIagAIHEENAIEjqAEgcAR1Dk8++aTeeeedwee33367duzYMer9dnV16amnnhr2v1uxYoU2bNgw6s8HUJnCDeqIb21zPkODeu3atZo6deqo9zvSoAaQbGEGdRTrnEpav369rrnmGs2aNUt33HGHTp8+rRUrVmj69OmaMWOG1qxZow0bNiidTqulpUWzZs3SsWPHtGDBAg1M4Bk/frzuvvtuTZs2TYsWLdLWrVu1YMECfeQjH9Hzzz8vKRPIH/3oRzVnzhzNmTNHv/rVryRJ99xzj37xi19o1qxZ
WrNmjU6fPq277757cLnVJ554QlJmLZKVK1dqypQpWrRokQ4cODCqnxtAhXP3kn/NnTvXh9qxY8f7tuXV1OSeiehzv5qait9Hjs+/+eab/cSJE+7u/oUvfMG//vWv+6JFiwbf86c//cnd3T/2sY/5r3/968HtZz+X5C+88IK7uy9btsxvuOEGP3HihHd2dvrMmTPd3f3IkSN+7Ngxd3fftWuXDxyPTZs2+dKlSwf3+8QTT/j999/v7u7Hjx/3uXPn+p49e/yZZ57xRYsW+alTp3zfvn1+8cUX+9NPP5335wJQ+SSlPU+mBrnWRxTrnL700kvq6OjQ1VdfLUk6duyYlixZoj179uiuu+7S0qVLtXjx4oL7ueCCC7RkyRJJ0owZM3ThhReqrq5OM2bMUFdXlyTp5MmTWrlypTo7O1VbW6tdu3bl3NfGjRv1+uuvD/Y/Hzp0SLt379bmzZu1fPly1dbW6oorrtDChQtH/HMDqHxhdn1EcGsbd9dtt92mzs5OdXZ26s0339Qjjzyi7du3a8GCBXr88cd1++23F9xPXV3d4N2/a2pqdOGFFw4+PnXqlCRpzZo1uuyyy7R9+3al02mdOHEib03f/va3B2t66623ivrPAki8GK9hxSHMoI5gndPrr79eGzZsGOzvfffdd9Xd3a3+/n7deuuteuCBB7Rt2zZJ0oQJE/Tee++N+LMOHTqkyy+/XDU1NfrBD36g06dP59zvjTfeqMcee0wnT56UJO3atUtHjhzR/Pnz9eMf/1inT5/W/v37tWnTphHXAlSdiK5hhSzMro8I1jmdOnWqHnjgAS1evFj9/f2qq6vTww8/rFtuuWXwRrcPPfSQpMxwuM9//vO66KKLtGXLlmF/1p133qlbb71V3//+97VkyZLBGxZcddVVqq2t1cyZM7VixQqtWrVKXV1dmjNnjtxdDQ0N+slPfqJbbrlFL7/8sqZOnarGxsbB5VgB6Pz36ovwLitxYpnTCsdxReLU1GRa0kOZSdlGVyVimVMA1SOCa1ihI6gBVJao79UXoLIGdRTdLEnG8UQitbRIbW1SU1Omu6OpKfO8SvunpTJeTBw7dqz6+vo0ceLEweFtGDl3V19fn8aOHRt3KUD5tbRUdTAPVVRQm9kHJK2VNF2SS/pHdx/WcIhJkyapp6dHvb29w68SOY0dO1aTJk2KuwwAESu2Rf2IpJ+5+z+Y2QWS6gv9g6Hq6uo0efLk4f4zAEi8gkFtZhdLmi9phSS5+wlJuafaAQBKrpiLiZMl9Ur6rpn9xszWmtm4oW8ys1YzS5tZmu4NACidYoJ6jKQ5kh5z99mSjki6Z+ib3L3N3ZvdvbmhoaHEZQJAchUT1D2Setz9tezzDcoENwCgDAoGtbv/QdLbZjYlu+l6SaO/LxUAoCjFjvq4S1J7dsTHHkmfi64kAMDZigpqd++UlHOxEABAtFjrAwACR1ADQOAIagAIHEENAIEjqAEgcAQ1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGqhk7e1SKiXV1GS+t7fHXREiMCbuAgCMUHu71NoqHT2aed7dnXkuSS0t8dWFkiuqRW1mXWb2WzPrNLN01EUBKMLq1WdCesDRo5ntqCrDaVH/vbsfjKwSAMOzd+/wtqNi0UcNVKrGxuFtR8UqNqhd0kYz6zCz1lxvMLNWM0ubWbq3t7d0FQLI7cEHpfr6c7fV12e2o6oUG9R/5+5zJH1c0hfNbP7QN7h7m7s3u3tzQ0NDSYsEkENLi9TWJjU1SWaZ721tXEisQkX1Ubv7vuz3A2b2nKRrJG2OsjAARWhpIZgToGCL2szGmdmEgceSFkv6XdSFAQAyimlRXybpOTMbeP9T7v6zSKsCAAwq2KJ29z3uPjP7Nc3duVIBlBKzC1EAMxOBODG7EEVgHDUQ
J2YXoggENRAnZheiCAQ1ECdmF6IIBDUQJ2YXoggENRAnZheiCIz6AOLG7EIUQIsaAAJHUANA4AhqAAgcQQ0AgSOogZFgfQ6UEaM+gOFifQ6UGS1qYLhYnwNlRlADxTi7q6O7O/d7WJ8DEaHrAyhkaFdHPqzPgYjQogYKydXVMRTrcyBCBDVQyPm6NFifA2VA1wdQSGNj7n7ppiapq6vs5SB5aFED0vnHRbMUKWJGUAMDFwu7uyX3M+OiB8KapUgRM3P3ku+0ubnZ0+l0yfcLRCKVomsDsTOzDndvzvUaLWqA+xYicAQ1wH0LETiCGuBiIQJXdFCbWa2Z/cbMfhplQUDZcbEQgRvOOOpVknZK+quIagHiw30LEbCiWtRmNknSUklroy0HADBUsV0f35L0NUn9EdYCAMihYFCb2c2SDrh7R4H3tZpZ2szSvb29JSsQAJKumBb1dZI+YWZdkn4kaaGZrR/6Jndvc/dmd29uaGgocZkAkFwFg9rd73X3Se6ekvQpSS+7+6cjrwwohPsWIiFYPQ+VifsWIkFY6wOVifU5UGVY6wPVh/U5kCAENSoT63MgQQhqVCbW50CCENSoTKzPgQRh1AcqF+tzICFoUQNA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoBRinp9MIbnAcAolGN9MFrUADAKq1efCekBR49mtpcKQY3osF40EqAc64MR1IjGwPlgd7fkfuZ8kLBGCYTUBijH+mAENaJRjvNBJFJobYByrA9GUCMarBeNiITWBijH+mAENUYv13ko60UjIiG2AVpaMjcW6u/PfC/1WmEENUYn33noTTexXjQikcQ2AEGN0cl3HvrCC6wXjUgk8Z4RBDVG53znoVGfD5ZQSKMIcH5JvGcEMxMxOo2Nue8GXkHnoeWYWYbSSto9I2hRY3Sq4Dw0tFEEwFAENUanCs5DQxxFAJyNrg+MXoWfh1ZB7w2qHC1qJF4V9N6gyhUMajMba2ZbzWy7mf3ezP61HIUB5VIFvTeocsV0ffyfpIXuftjM6iT90sz+291fjbg2oGwqvPcGVa5gULu7SzqcfVqX/fIoiwIAnFFUH7WZ1ZpZp6QDkn7u7q/leE+rmaXNLN3b21vqOgEgsYoKanc/7e6zJE2SdI2ZTc/xnjZ3b3b35oaGhlLXCSAHZlQmw7BGfbj7nyVtkrQkmnIAFCu0dZkRnWJGfTSY2Qeyjy+SdIOkN6IuDAhZCC1ZZlQmRzGjPi6X9D0zq1Um2P/T3X8abVlAuEJZG4QZlclhmUEdpdXc3OzpdLrk+wVCkErlnsnY1JRZJDBpdaA0zKzD3ZtzvcbMxCoQwml4koTSkmVGZXIQ1BWOC0rlF8odRphRmRx0fVQ4Tn/Lb2gftZRpyRKSGA26PqrY3r3ScrXrLaV0WjV6SyktVzsXlCJESxblxjKnFW7lB9v1UF+rxinTvEupW/+hVl36QUkiOaLC2iAoJ1rUFe4bWj0Y0gPG6ai+oeofTMtFVCQFLeoKN/7d3H0c+bZXi1DGMgPlQIu60oUyBKHMmJWHJCGoQ1fo/D6hg2lDGctcCnThoBCCOmTFDJJO6BCEajmRYBw8isE46pAxSDqvahnLzK8YAxhHXamq6fy+xKrlRIJfMYrBqI+QNTbmbm5V2vl9RKphLDO/YhSDFnXIEnqhMEn4FaMYBHXIquX8HnnxK0YxuJgIAAHgYiIAVDCCGgACR1ADQOAIagAIHEENAIEjqAEgcAQ1AASOoEZkWL4TKA3W+kAkuAMLUDq0qBEJ7sAClE7BoDazD5vZJjPbYWa/N7NV5SgMlY3lO4HSKaZFfUrSV9x9qqR5kr5oZlOjLQuVrlruwAKEoGBQu/t+d9+WffyepJ2SPhR1YcHhytiwsHwnUDrD6qM2s5Sk2ZJey/Faq5mlzSzd29tbmupCwY3tho3lO4HSKXqZUzMbL+l/JD3o7s+e771Vt8wpN7YDELFRL3NqZnWSnpHUXiikqxJXxgDEqJhR
HybpO5J2uvvD0ZcUIK6MAYhRMS3q6yR9RtJCM+vMft0UcV3xyHfBkCtjAGJUcGaiu/9SkpWhlngVM5Vu9epMd0djYyakuTIGoAy4Z+IALhgCiBH3TCwGFwwBBIqgHlDggiHzXZKN3z/iRFAPOM8FQ+a7nF+1hxi/f8TO3Uv+NXfuXK9I69e7NzW5m2W+r1/v7pmHmT/Rc7+ammKsNRDr17vX1597XOrrBw9dVeD3j3KQlPY8mcrFxCLU1GT+NIcyk/r7y19PSJJwDZbfP8qBi4mjFOd8l9C7FZJwDZb5TogbQV2EuOa7VELfaBJCjPlOiBtBXYS4VoKrhLukJCHEWAkQcaOPOmCV0jfa3s6kTWC0ztdHzc1tA9bYmPtCXWjdCi0tBDMQJbo+ApaEbgUAhRHUAaNvFIBE10fw6FYAQIsaAAJHUJ8l9MklAJKJro+sYu4bAABxoEWdVQmTSwAkUzBBHXe3QxLWrABQmYII6vZ26cXPteuV7pROeY1e6U7pxc+1lzWsk7BmBYDKFERQv7aqXf9+slUpdatGrpS69e8nW/XaqvIlNZNLAIQqiKD+ct9qjdO5HcTjdFRf7itfBzGTSwCEKohFmfqtRjV6fx39MtV4QKsPAUBEgr9xwNGJuTuC820HgCQJIqjHP/KgTl1wbgfxqQvqNf4ROogBIIigVkuLxqw7t4N4zDo6iAFAKmJmopmtk3SzpAPuPj2ySlh9CAByKqZF/aSkJRHXAQDIo2BQu/tmSe+WoRYAQA4l66M2s1YzS5tZure3t1S7BYDEK1lQu3ubuze7e3NDQ0OpdgsAiRfGqA8AQF6RrEfd0dFx0Mxy3D/7HJdKOhjF51c4jkt+HJv8ODa5VdJxacr3QsEp5Gb2Q0kLlPmB/yjpPnf/zmgrMrN0vumSScZxyY9jkx/HJrdqOS4FW9TuvrwchQAAcqOPGgACF2dQt8X42SHjuOTHscmPY5NbVRyXSJY5BQCUDl0fABA4ghoAAlfWoDazD5vZJjPbYWa/N7NV5fz8SmBmtWb2GzP7ady1hMTMPmBmG8zsDTPbaWbXxl1TCMzsX7J/S78zsx+a2di4a4qLma0zswNm9ruztn3QzH5uZruz3y+Js8aRKneL+pSkr7j7VEnzJH3RzKaWuYbQrZK0M+4iAvSIpJ+5+99ImimOkczsQ5L+WVJzdgniWkmfireqWD2p96/0eY+kl9z9ryW9lH1eccoa1O6+3923ZR+/p8wf24fKWUPIzGySpKWS1sZdS0jM7GJJ8yV9R5Lc/YS7/zneqoIxRtJFZjZGUr2kd2KuJzZ5Vvr8pKTvZR9/T9KyshZVIrH1UZtZStJsSa/FVUOAviXpa5K4o++5JkvqlfTdbLfQWjMbF3dRcXP3fZL+TdJeSfslHXL3jfFWFZzL3H1/9vEfJF0WZzEjFUtQm9l4Sc9I+pK7/yWOGkJjZgN30emIu5YAjZE0R9Jj7j5b0hFV6ClsKWX7Wz+pzH9kV0gaZ2afjreqcHlmLHJFjkcue1CbWZ0yId3u7s+W+/MDdp2kT5hZl6QfSVpoZuvjLSkYPZJ63H3g7GuDMsGddIskveXuve5+UtKzkv425ppC80czu1ySst8PxFzPiJR71Icp08+4090fLudnh87d73X3Se6eUuaC0MvuTutIkrv/QdLbZjYlu+l6STtiLCkUeyXNM7P67N/W9eIi61DPS7ot+/g2Sf8VYy0jVu4W9XWSPqNMa7Ez+3VTmWtAZbpLUruZvS5plqRvxFxP7LJnGBskbZP0W2X+nqtiyvRIZFf63CJpipn1mNk/SfqmpBvMbLcyZyDfjLPGkWIKOQAEjpmJABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAE7v8BJOC0kEQrnGYAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_ = linear_model(x_train)\n", "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的例子可以看到,更新之后红色的线跑到了蓝色的线下面,没有特别好的拟合蓝色的真实值,所以我们需要在进行几次更新" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.2934)\n", "epoch: 0, loss: 0.2933782935142517\n", "tensor(0.2927)\n", "epoch: 1, loss: 0.29273974895477295\n", "tensor(0.2921)\n", "epoch: 2, loss: 0.29210466146469116\n", "tensor(0.2915)\n", "epoch: 3, loss: 0.2914726436138153\n", "tensor(0.2908)\n", "epoch: 4, loss: 0.29084399342536926\n", "tensor(0.2902)\n", "epoch: 5, loss: 0.2902185022830963\n", "tensor(0.2896)\n", "epoch: 6, loss: 0.28959622979164124\n", "tensor(0.2890)\n", "epoch: 7, loss: 0.28897714614868164\n", "tensor(0.2884)\n", "epoch: 8, loss: 0.28836125135421753\n", "tensor(0.2877)\n", "epoch: 9, loss: 0.2877485454082489\n" ] } ], "source": [ "for e in range(10): # 进行 10 次更新\n", " y_ = linear_model(x_train)\n", " loss = get_loss(y_, y_train)\n", " \n", " w.grad.zero_() # 记得归零梯度\n", " b.grad.zero_() # 记得归零梯度\n", " loss.backward()\n", " print(loss.data)\n", " w.data = w.data - 1e-2 * w.grad.data # 更新 w\n", " b.data = b.data - 1e-2 * b.grad.data # 更新 b \n", " print('epoch: {}, loss: {}'.format(e, loss.item()))" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_ = linear_model(x_train)\n", "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "经过 10 次更新,我们发现红色的预测结果已经比较好的拟合了蓝色的真实值。\n", "\n", "现在你已经学会了你的第一个机器学习模型了,再接再厉,完成下面的小练习。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**小练习:**\n", "\n", "重启 notebook 运行上面的线性回归模型,但是改变训练次数以及不同的学习率进行尝试得到不同的结果" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 多项式回归模型" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "下面我们更进一步,讲一讲多项式回归。什么是多项式回归呢?非常简单,根据上面的线性回归模型\n", "\n", "$$\n", "\\hat{y} = w x + b\n", "$$\n", "\n", "这里是关于 x 的一个一次多项式,这个模型比较简单,没有办法拟合比较复杂的模型,所以我们可以使用更高次的模型,比如\n", "\n", "$$\n", "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 + \\cdots\n", "$$\n", "\n", "这样就能够拟合更加复杂的模型,这就是多项式模型,这里使用了 x 的更高次,同理还有多元回归模型,形式也是一样的,只是出了使用 x,还是更多的变量,比如 y、z 等等,同时他们的 loss 函数和简单的线性回归模型是一致的。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "首先我们可以先定义一个需要拟合的目标函数,这个函数是个三次的多项式" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n" ] } ], "source": [ "# 定义一个多变量函数\n", "\n", "w_target = np.array([0.5, 3, 2.4]) # 定义参数\n", "b_target = np.array([0.9]) # 定义参数\n", "\n", "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n", " b_target[0], w_target[0], w_target[1], w_target[2]) # 打印出函数的式子\n", "\n", "print(f_des)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们可以先画出这个多项式的图像" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" }, { 
"data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出这个函数的曲线\n", "x_sample = np.arange(-3, 3.1, 0.1)\n", "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n", "\n", "plt.plot(x_sample, y_sample, label='real curve')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "接着我们可以构建数据集,需要 x 和 y,同时是一个三次多项式,所以我们取了 $x,\\ x^2, x^3$" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# 构建数据 x 和 y\n", "# x 是一个如下矩阵 [x, x^2, x^3]\n", "# y 是函数的结果 [y]\n", "\n", "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n", "x_train = torch.from_numpy(x_train).float() # 转换成 float tensor\n", "\n", "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # 转化成 float tensor " ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([61, 3])\n" ] } ], "source": [ "print(x_train.size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "接着我们可以定义需要优化的参数,就是前面这个函数里面的 $w_i$" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# 定义参数和模型\n", "w = Variable(torch.randn(3, 1), requires_grad=True)\n", "b = Variable(torch.zeros(1), requires_grad=True)\n", "\n", "# 将 x 和 y 转换成 Variable\n", "x_train = Variable(x_train)\n", "y_train = Variable(y_train)\n", "\n", "def multi_linear(x):\n", " return torch.mm(x, w) + b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们可以画出没有更新之前的模型和真实的模型之间的对比" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出更新之前的模型\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以发现,这两条曲线之间存在差异,我们计算一下他们之间的误差" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(413.9844, grad_fn=)\n" ] } ], "source": [ "# 计算误差,这里的误差和一元的线性模型的误差是相同的,前面已经定义过了 get_loss\n", "loss = get_loss(y_pred, y_train)\n", "print(loss)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# 自动求导\n", "loss.backward()" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ -34.1391],\n", " [-146.6133],\n", " [-215.9149]])\n", "tensor([-27.0838])\n" ] } ], "source": [ "# 查看一下 w 和 b 的梯度\n", "print(w.grad)\n", "print(b.grad)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# 更新一下参数\n", "w.data = w.data - 0.001 * w.grad.data\n", "b.data = b.data - 0.001 * b.grad.data" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出更新一次之后的模型\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "因为只更新了一次,所以两条曲线之间的差异仍然存在,我们进行 100 次迭代" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 20, Loss: 73.67843\n", "epoch 40, Loss: 17.97095\n", "epoch 60, Loss: 4.94101\n", "epoch 80, Loss: 1.87171\n", "epoch 100, Loss: 1.12812\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:14: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n", " \n" ] } ], "source": [ "# 进行 100 次参数更新\n", "for e in range(100):\n", " y_pred = multi_linear(x_train)\n", " loss = get_loss(y_pred, y_train)\n", " \n", " w.grad.data.zero_()\n", " b.grad.data.zero_()\n", " loss.backward()\n", " \n", " # 更新参数\n", " w.data = w.data - 0.001 * w.grad.data\n", " b.data = b.data - 0.001 * b.grad.data\n", " if (e + 1) % 20 == 0:\n", " print('epoch {}, Loss: {:.5f}'.format(e+1, loss.data[0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到更新完成之后 loss 已经非常小了,我们画出更新之后的曲线对比" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出更新之后的结果\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到,经过 100 次更新之后,可以看到拟合的线和真实的线已经完全重合了" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "**小练习:上面的例子是一个三次的多项式,尝试使用二次的多项式去拟合它,看看最后能做到多好**\n", "\n", "**提示:参数 `w = torch.randn(2, 1)`,同时重新构建 x 数据集**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }