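{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Models and Gradient Descent\n", "\n", "In this section we briefly review the linear regression model and show how to use PyTorch to build the model and compute its parameters." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Univariate Linear Regression\n", "The univariate linear model is very simple. Suppose we have inputs $x_i$ and targets $y_i$, where each $i$ indexes one data point. We want to build the model\n", "\n", "$$\n", "\\hat{y}_i = w x_i + b\n", "$$\n", "\n", "$\\hat{y}_i$ is our prediction, and we want $\\hat{y}_i$ to fit the target $y_i$. Put simply, we look for the $w$ and $b$ that make the error as small as possible, i.e. we minimize the mean squared error\n", "\n", "$$\n", "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So how do we minimize this error?\n", "\n", "We use **gradient descent**, the first optimization algorithm we meet. It is very simple yet very powerful and is used extensively in deep learning, so let's start from a simple example to understand how it works." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Gradient Descent\n", "To use gradient descent we first need to be clear about what a gradient is; then we look at how to use the gradient to descend." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Gradient\n", "For a function of one variable the gradient is simply the derivative; for a function of several variables the gradient is the vector of its partial derivatives. For example, for a function f(x, y), the gradient of f is \n", "\n", "$$\n", "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n", "$$\n", "\n", "written grad f(x, y) or $\\nabla f(x, y)$. The gradient at a specific point $(x_0,\\ y_0)$ is $\\nabla f(x_0,\\ y_0)$.\n", "\n", "The figure below shows the gradient of $f(x) = x^2$ at x=1, which equals $f'(1) = 2$\n", "\n", "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarbuh2j3j30ba0b80sy.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What does the gradient mean? Geometrically, the gradient at a point points in the direction in which the function increases fastest. Concretely, for a function f(x, y) at the point $(x_0, y_0)$, the function increases fastest along the direction of $\\nabla f(x_0,\\ y_0)$. So moving along the gradient takes us toward a maximum more quickly, and moving against it (along the negative gradient) takes us toward a minimum more quickly." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Gradient Descent\n", "With this understanding of the gradient we can see how gradient descent works. Above we need to minimize the error, i.e. find the point where the error is smallest, and moving against the gradient leads us toward that point.\n", "\n", "Here is an intuitive picture. Suppose we are somewhere on a mountain and do not know the way down, so we take it one step at a time. At each position we compute the gradient there and take one step along the negative gradient, i.e. in the steepest downhill direction. Then we compute the gradient at the new position and again step in the steepest downhill direction. We keep going step by step until we believe we have reached the foot of the mountain. Of course, walking this way we may not reach the foot of the mountain at all, and may instead end up in a local low point (a local minimum).\n", "\n", "In our problem this means repeatedly changing w and b in the direction opposite to the gradient until we find the w and b that make the error smallest.\n", "\n", "When updating, we must decide how large each update is; in the mountain example, this is the length of each step downhill. This step size is called the learning rate and is written $\\eta$. The learning rate matters a great deal, and different learning rates lead to different results: a learning rate that is too small makes the descent very slow, while one that is too large makes the iterates jump around noticeably. See the example below\n", "\n", "![](https://ws2.sinaimg.cn/large/006tNc79ly1fmgn23lnzjg30980gogso.gif)\n", "\n", "The learning rate in the upper panel is about right, while the one in the lower panel is too large, so the iterates keep jumping back and forth\n", "\n", "Finally, our update rules are\n", "\n", "$$\n", "w := w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n", "b := b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n", "$$\n", "\n", "where f(w, b) is the error viewed as a function of w and b. By repeating these updates we can eventually find a good w and b. For linear regression with squared error the loss is convex, so with a suitable learning rate the iterates approach the optimal w and b. This is the principle of gradient descent.\n", "\n", "The following figure gives an intuitive picture of the method\n", "\n", "![](https://ws3.sinaimg.cn/large/006tNc79ly1fmarxsltfqj30gx091gn4.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3 PyTorch Implementation\n", "\n", "That was the theory. Next we study the linear model further through an example" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import numpy as np\n", "from torch.autograd import Variable\n", "\n", "torch.manual_seed(2021)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 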
"iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPB0lEQVR4nO3df4xsZ13H8fd3uam6TQXSe2tM6e5CBKS5Biibpv5BlVRJbUybKGrJIoKVDWCq6F8k+4dGc/8gURNNiLrB366ILWJugjb1B9hIaHEuLbSUQNp699JS6aK0GjfQln7948z23m5m75y5O+fMc+a8X8lmZs6cO/t9ZrafPnPO8zwnMhNJUrkWZl2AJOn8DGpJKpxBLUmFM6glqXAGtSQV7kgTL3r06NFcWVlp4qUlaS6dOnXq65l5bNRzjQT1ysoKg8GgiZeWpLkUEdsHPeehD0kqnEEtSYUzqCWpcAa1JBXOoJakwhnUknRIW1uwsgILC9Xt1tZ0X7+R4XmS1BdbW7C+Dru71ePt7eoxwNradH6HPWpJOoSNjbMhvWd3t9o+LQa1JB3CmTOTbb8QBrUkHcLS0mTbL4RBLUmHcOIELC6+cNviYrV9WgxqSTqEtTXY3ITlZYiobjc3p3ciERz1IUmHtrY23WDezx61JBXOoJakwhnUklQ4g1qSCmdQS1LhDGpJKpxBLUmFM6glqXAGtSQVzqCWpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBrUkFc6glqTCGdSSVDiDWpIKZ1BLUuEMakkqnEEtSYWrFdQR8csR8UBEfCEi3tdwTZKkc4wN6og4DrwLuBp4LfDjEfF9TRcmSarU6VG/BrgnM3cz81ngX4GfaLYsSdKeOkH9APDGiLg0IhaBG4Ar9u8UEesRMYiIwc7OzrTrlKTeGhvUmflF4APAncAdwH3At0fst5mZq5m5euzYsWnXKUm9VetkYmb+UWa+ITOvBb4BfLnZsiRJe47U2SkiLsvMJyJiier49DXNliVJ2lMrqIGPRsSlwDPAL2bmk82VJEk6V62gzsw3Nl2IJGk0ZyZKUuEMakkqnEEtSYUzqCWpcAa1JBXOoJakwhnUklQ4g1qaI1tbsLICCwvV7dbWrCvSNNSdmSipcFtbsL4Ou7vV4+3t6jHA2trs6tLh2aOW5sTGxtmQ3rO7W21XtxnU0pw4c2ay7eoOg1qaE0tLk21XdxjU0pw4cQIWF1+4bXGx2q5uM6ilObG2BpubsLwMEdXt5qYnEueBoz6kObK2ZjDPI3vUklQ4g1qSCmdQS1LhDGpJKpxBLUmFM6glqXAGtSQVzqCWpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBrVUoJKuJl5SLX1lUKu3Sg2gvauJb29D5tmric+ivmnWUur73QWRmVN/0dXV1RwMBlN/XWla9gLo3Kt2Ly6WcUWUlZUqEPdbXobTp7tZS8nvdyki4lRmro58rk5QR8SvAL8AJHA/8M7M/OZB+xvUKl1JYbjfwkLVe90vAp57rpu1lPx+l+J8QT320EdEXA78ErCamceBFwE3T7dEqV1nzky2vU0lXU18WrWU/H53Qd1j1EeA74qII8Ai8NXmSpKaV1IY7lfS1cSnVUvJ73cXjA3qzHwM+C3gDPA48FRm3rl/v4hYj4hBRAx2dnamX6k0RSWF4X4lXU18WrWU/H53wdhj1BHxUuCjwM8ATwK3Abdn5l8e9G88Rq0u2NqCjY3q6/fSUhUanthqju/3+R3qZGJE/BRwfWbeMnz8duCazHzvQf/GoJakyRzqZCLVIY9rImIxIgK4DvjiNAuUJB2szjHqe4Dbgc9SDc1bADYbrkuSNFRr1Edm/lpmfn9mHs/Mn83MbzVdmKRuceZhc47MugBJ3bd/5uHeVHPwhOE0uNaHWmWvaz5tbLxwejhUjzc2ZlPPvLFHrdbY65pfzjxslj1qtabvva55/jbhzMNmGdRqTZ9
7XSUtXdoEZx42y6BWa/rc65r3bxMlTXufRwa1WtPnXlcfvk2srVVLlj73XHVrSE+PQa3W9LnX1edvEzo8g1qt6muvq8/fJnR4BrXUgj5/m9DhOY5aasnamsGsC2OPWpIKZ1BLOtA8T9LpEg99SBrJKf/lsEctaaR5n6TTJQa1pJH6MEmnKwxqSSM5SaccBrWkkZykUw6DWtJITtIph0EtdVBbw+b6OuW/NA7PkzrGYXP9Y49a6hiHzfWPQS11jMPm+segljrGYXP9Y1BLHeOwuf4xqKWOcdhc/zjqQ+og17buF3vUklQ4g1qSCmdQS1LhDGpJKpxBLUmFGxvUEfHqiLjvnJ//iYj3tVCbJIkaQZ2ZX8rM12Xm64A3ALvAx5ouTNLseFHbskw6jvo64OHM3G6iGEmz5+p85Zn0GPXNwIebKERSGVydrzy1gzoiLgJuBG474Pn1iBhExGBnZ2da9UlqmavzlWeSHvWPAZ/NzK+NejIzNzNzNTNXjx07Np3qJLXO1fnKM0lQv5UGD3t48kIqg6vzladWUEfExcCPAn/bRBF7Jy+2tyHz7MkLw1pqn6vzlScyc+ovurq6moPBoPb+KytVOO+3vFxdUFOS5l1EnMrM1VHPFTEz0ZMXknSwIoLakxeSdLAigtqTF/V50lXqnyKC2pMX9XjSVeqnIk4mqh5Pukrzq/iTiarHk65SPxnUHeJJV6mfDOoO8aSr1E8GdYd40lXqp0nXo9aMra0ZzFLf2KOWpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBnUPuVSq1C1OeOmZvaVSd3erx3tLpYITaaRS2aPumY2NsyG9Z3e32i6pTAZ1z0xjqVQPnUjtMqh75rBLpXqVGal9BnXPHHapVA+dSO0zqHvmsEulepUZqX2O+uihwyyVurQ0+rqNXmVGao49ak3Eq8xI7TOoO2wWoy+8yozUPg99dNQsJ654lRmpXfaoO8rRF1J/GNQd5egLqT8M6o467MQVSd1hUHeUoy+k/jCoO8rRF1J/1Br1EREvAT4EHAcS+PnM/HSDdakGR19I/VB3eN7vAndk5lsi4iJgcdw/kCRNx9igjogXA9cC7wDIzKeBp5stS5K0p84x6pcDO8CfRMS9EfGhiLh4/04RsR4Rg4gY7OzsTL1QSeqrOkF9BLgK+P3MfD3wf8D79++UmZuZuZqZq8eOHZtymZLUX3WC+lHg0cy8Z/j4dqrgliS1YGxQZ+Z/Al+JiFcPN10HPNhoVZKk59Ud9XErsDUc8fEI8M7mSpIknatWUGfmfcBqs6VIkkbp1cxEr54tqYt6sx71LNdvlqTD6E2P2vWbJXVVb4La9ZsldVVvgtr1myV1VW+C2vWbJXVVb4La9ZsldVVvRn2A6zdL6qbe9KglqasM6gI4EUfS+fTq0EeJnIgjaRx71DPmRBxJ4xjUM+ZEHEnjGNQz5kQcSeMY1DPmRBxJ4xjUM+ZEHEnjOOqjAE7EkXQ+9qglqXAGtSQVzqCWpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalwBrUkFc6glqTCGdSSVDiDWpIKZ1BLUuEMakkqXK31qCPiNPC/wLeBZzNztcmiJElnTXLhgDdl5tcbq0SSNNLcHPrY2oKVFVhYqG63tmZdkSRNR92gTuDOiDgVEeujdoiI9YgYRMRgZ2dnehXWsLUF6+uwvQ2Z1e36umEtaT5EZo7fKeLyzHwsIi4D/hG4NTPvOmj/1dXVHAwGUyzz/FZWqnDeb3kZTp9urQxJumARceqg83+1etSZ+djw9gngY8DV0yvv8M6cmWy7JHXJ2KCOiIsj4pK9+8CbgQeaLmwSS0uTbZekLqnTo/4e4N8i4nPAZ4CPZ+YdzZY1mRMnYHHxhdsWF6vtktR1Y4fnZeYjwGtbqOWCra1Vtxsb1eGOpaUqpPe2S1KXTTKOumhrawazpPk0N+OoJWleGdSSVDiDWpIKZ1BLUuEMakkqXK0
p5BO/aMQOMGJS9/OOAn1dic+291Nf297XdsPkbV/OzGOjnmgkqMeJiEFf17S27ba9T/rabphu2z30IUmFM6glqXCzCurNGf3eEtj2fupr2/vabphi22dyjFqSVJ+HPiSpcAa1JBWu0aCOiOsj4ksR8VBEvH/E898RER8ZPn9PRKw0WU+barT9VyPiwYj4fET8c0Qsz6LOJoxr+zn7/WREZETMxfCtOu2OiJ8efu5fiIi/arvGptT4e1+KiE9ExL3Dv/kbZlHntEXEH0fEExEx8mIqUfm94fvy+Yi46oJ+UWY28gO8CHgYeAVwEfA54Mp9+7wX+IPh/ZuBjzRVT5s/Ndv+JmBxeP89fWr7cL9LgLuAu4HVWdfd0mf+SuBe4KXDx5fNuu4W274JvGd4/0rg9KzrnlLbrwWuAh444PkbgH8AArgGuOdCfk+TPeqrgYcy85HMfBr4a+CmffvcBPzZ8P7twHUREQ3W1Jaxbc/MT2Tm7vDh3cDLWq6xKXU+d4DfBD4AfLPN4hpUp93vAj6Ymd+A569BOg/qtD2B7x7efzHw1Rbra0xWF/n+7/PschPw51m5G3hJRHzvpL+nyaC+HPjKOY8fHW4buU9mPgs8BVzaYE1tqdP2c91C9X/deTC27cOvf1dk5sfbLKxhdT7zVwGviohPRcTdEXF9a9U1q07bfx14W0Q8Cvw9cGs7pc3cpFkw0txc4aWrIuJtwCrwQ7OupQ0RsQD8DvCOGZcyC0eoDn/8MNU3qLsi4gcy88lZFtWStwJ/mpm/HRE/CPxFRBzPzOdmXVgXNNmjfgy44pzHLxtuG7lPRByh+kr0Xw3W1JY6bScifgTYAG7MzG+1VFvTxrX9EuA48MmIOE113O7kHJxQrPOZPwqczMxnMvM/gC9TBXfX1Wn7LcDfAGTmp4HvpFq0aN7VyoJxmgzqfwdeGREvj4iLqE4Wnty3z0ng54b33wL8Sw6PwHfc2LZHxOuBP6QK6Xk5Vglj2p6ZT2Xm0cxcycwVquPzN2bmYDblTk2dv/e/o+pNExFHqQ6FPNJijU2p0/YzwHUAEfEaqqDeabXK2TgJvH04+uMa4KnMfHziV2n4jOgNVL2Gh4GN4bbfoPoPE6oP6zbgIeAzwCtmfRa3xbb/E/A14L7hz8lZ19xW2/ft+0nmYNRHzc88qA77PAjcD9w865pbbPuVwKeoRoTcB7x51jVPqd0fBh4HnqH6xnQL8G7g3ed85h8cvi/3X+jfulPIJalwzkyUpMIZ1JJUOINakgpnUEtS4QxqSSqcQS1JhTOoJalw/w9HECtz8n/B+wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Generate synthetic data: y = 3x + 4 plus uniform noise in [0, 3)\n", "x_train = np.random.rand(20, 1)\n", "y_train = x_train * 3 + 4 + 3 * np.random.rand(20, 1)\n", "\n", "# Plot the data\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.plot(x_train, y_train, 'bo')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Convert to Tensors (float64, because np.random.rand returns float64 arrays)\n", "x_train = torch.from_numpy(x_train)\n", "y_train = torch.from_numpy(y_train)\n", "\n", "# Define the parameters w and b; requires_grad=True tells autograd to track them\n", "# (wrapping in Variable is no longer needed: since PyTorch 0.4 tensors support autograd directly)\n", "w = torch.randn(1, requires_grad=True)  # random initialization\n", "b = torch.zeros(1, requires_grad=True)  # initialize with 0" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Build the linear regression model\n", "def linear_model(x):\n", "    return x * w + b\n", "\n", "# Logistic regression applies a sigmoid to the same linear output (not used in this section)\n", "def logistic_regression(x):\n", "    return torch.sigmoid(x * w + b)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "y_ = linear_model(x_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the steps above the model is defined. Before updating the parameters, let's first look at what the model's output looks like" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXy0lEQVR4nO3df3Ac5X3H8c9XRmAELqGyygCOJDOTeBA2GFsQM5kYF4xxgQkwpC1UBJzBtYHCkLahA6PpQAs07UwbF5gkoCEOPyQSgmmop6UtBcw4P2xADoIkNtiJkY2MEyuGOOAf9Q99+8eehHzWSafbvb3dvfdr5ubuVqvd5zmZD889++zzmLsLAJA+NZUuAACgNAQ4AKQUAQ4AKUWAA0BKEeAAkFJHxXmyyZMne3Nzc5ynBIDUW7du3W/cvSF/e6wB3tzcrO7u7jhPCQCpZ2ZbRto+ZheKmS03sx1m9rNh237fzP7XzDblnk+MsrAAgLEV0wf+qKSFedvukPSiu39K0ou59wCAGI0Z4O6+WtL7eZsvl/RY7vVjkq6ItlgAgLGU2gd+krtvz73+laSTCu1oZkskLZGkxsbGI35+4MAB9fX1ad++fSUWBfkmTpyoKVOmqLa2ttJFAVBGoS9iurubWcEJVdy9Q1KHJLW2th6xX19fnyZNmqTm5maZWdjiVD13186dO9XX16epU6dWujgAyqjUceC/NrOTJSn3vKPUAuzbt0/19fWEd0TMTPX19XyjARKgq0tqbpZqaoLnrq5oj19qgK+UdH3u9fWS/j1MIQjvaPF5ApXX1SUtWSJt2SK5B89LlkQb4sUMI/yOpDWSpplZn5ndIOkfJV1kZpskzc+9BwDktLdLe/Ycvm3PnmB7VMbsA3f3awr86MLoipFugzcoTZ48udJFAZAQW7eOb3spUjcXSrn7lNxdAwMD0R4UQNUZYdDdqNtLkaoAL1efUm9vr6ZNm6brrrtO06dP1z333KNzzjlHZ555pu66666h/a644grNnj1bZ5xxhjo6OkLWBkCW3XefVFd3+La6umB7VFIV4OXsU9q0aZNuvvlmLVu2TNu2bdOrr76qnp4erVu3TqtXr5YkLV++XOvWrVN3d7ceeOAB7dy5M/yJAWRSW5vU0SE1NUlmwXNHR7A9KrFOZhVWOfuUmpqaNGfOHH3lK1/R888/r7PPPluS9NFHH2nTpk2aO3euHnjgAX3/+9+XJL377rvatGmT6uvrw58cQCa1tUUb2PlSFeCNjUG3yUjbwzruuOMkBX3gd955p5YuXXrYz19++WW98MILWrNmjerq6jRv3jzGWgOoqFR1ocTRp3TxxRdr+fLl+uijjyRJ27Zt044dO7Rr1y6deOKJqqur01tvvaW1a9dGd1IAKEGqWuCDX0Xa24Nuk8bGILyj/IqyYMECbdiwQeedd54k6fjjj1dnZ6cWLlyohx56SKeffrqmTZumOXPmRHdSACiBuRecxiRyra2tnr+gw4YNG3T66afHVoZqwecKZIeZrXP31vztqepCAQB8jAAHgJQiwAEgpQhwAEgpAhwAUooAB4CUIsDH4dFHH9V777039H7x4sVav3596OP29vbqySefHPfvLVq0SCtWrAh9fgDplL4AL/d8sqPID/BHHnlELS0toY9baoADqG7pCvAyzSfb2dmpc889VzNnztTSpUt16NAhLVq0SNOnT9eMGTO0bNkyrVixQt3d3Wpra9PMmTO1d+9ezZs3T4M3Jh1//PG6/fbbdcYZZ2j+/Pl69dVXNW/ePJ122mlauXKlpCCoP/e5z2nWrFmaNWuWfvzjH0uS7rjjDv3gBz/QzJkztWzZMh06dEi333770JS2Dz/8sKRgnpZbbrlF06ZN0/z587VjR8lLkQLIAneP7TF79mzPt379+iO2FdTU5B5E9+GPpqbijzHC+S+77DLfv3+/u7vfdNNNfvfdd/v8+fOH9vnggw/c3f3888/31157bWj78PeS/LnnnnN39yu
uuMIvuugi379/v/f09PhZZ53l7u67d+/2vXv3urv7xo0bffDzWLVqlV966aVDx3344Yf9nnvucXf3ffv2+ezZs33z5s3+zDPP+Pz58/3gwYO+bds2P+GEE/zpp58uWC8A2SCp20fI1FTNhVKO+WRffPFFrVu3Tuecc44kae/evVq4cKE2b96sW2+9VZdeeqkWLFgw5nGOPvpoLVy4UJI0Y8YMHXPMMaqtrdWMGTPU29srSTpw4IBuueUW9fT0aMKECdq4ceOIx3r++ef15ptvDvVv79q1S5s2bdLq1at1zTXXaMKECTrllFN0wQUXlFxvAOmXri6UMqxR5O66/vrr1dPTo56eHr399tu6//779cYbb2jevHl66KGHtHjx4jGPU1tbO7QafE1NjY455pih1wcPHpQkLVu2TCeddJLeeOMNdXd3a//+/QXL9OCDDw6V6Z133inqfyLAaCp4+Qhlkq4AL8N8shdeeKFWrFgx1J/8/vvva8uWLRoYGNBVV12le++9Vz/5yU8kSZMmTdKHH35Y8rl27dqlk08+WTU1NXriiSd06NChEY978cUX65vf/KYOHDggSdq4caN2796tuXPn6qmnntKhQ4e0fft2rVq1quSyoLqUazlCVFa6ulDKMJ9sS0uL7r33Xi1YsEADAwOqra3V1772NV155ZVDixt/9atflRQM27vxxht17LHHas2aNeM+180336yrrrpKjz/+uBYuXDi0iMSZZ56pCRMm6KyzztKiRYt02223qbe3V7NmzZK7q6GhQc8++6yuvPJKvfTSS2ppaVFjY+PQlLfAWEZbjrCcK8agvJhONqP4XDFcTU3Q8s5nJuXaKUgwppMFqlgZLh8hAQhwoArEsRwh4peIAI+zG6ca8HkiX1ub1NEhNTUF3SZNTcF7+r/TreIXMSdOnKidO3eqvr5+aBgeSufu2rlzpyZOnFjpoiBh2toI7KypeIBPmTJFfX196u/vr3RRMmPixImaMmVKpYsBoMwqHuC1tbWaOnVqpYsBAKmTiD5wAMD4EeBAyiTllviklKOaEeBAniQHU1JuiY+yHEn+vJMu1J2YZvaXkhZLckk/lfQld99XaP+R7sQEkmQwmIbfdl5Xl5whd83NQVjma2qScpNepqocSf+8k6LQnZglB7iZnSrph5Ja3H2vmX1P0nPu/mih3yHAkXRJCchCknJLfFTlSPrnnRTlupX+KEnHmtlRkuokvTfG/kCilWHK+Ugl5Zb4qMqR9M876UoOcHffJumfJW2VtF3SLnd/Pn8/M1tiZt1m1s1YbyRdUgKykKTcEh9VOZL+eSddyQFuZidKulzSVEmnSDrOzK7N38/dO9y91d1bGxoaSi8pEIOkBGQhSbklfjzlGO0iZdI/78QbaZ21Yh6S/ljSt4a9v07SN0b7nZHWxATc3Ts7g6VNzYLnzs7qKkuS6h+lzk73urrDl7Ctqzu8flmte5RUYE3MMBcxPyNpuaRzJO2V9GjuJA8W+h0uYmIk1T4SIcv15yJlNCIfhZI76N9J+lNJByW9Lmmxu/9fof0JcIyk2v8jz3L9kzJqJu0KBXiouVDc/S5Jd4U5BlDtIxGyXP/GxpH/58RFymhwJyYqrtpHImS5/lykLC8CHBVX7f+Rp7H+xd7+npRRM5k10pXNcj0YhYJCqn0kQprqX8zIEkRLUY9CKQUXMYH0y/JF16RiVXoAkcjyRde0IcCBDIljatYsX3RNGwIcyIi45gpP40XXrCLAgYxobz/8bk4peN/eHu15GFmSHFzEBDKCux6zi4uYQMbRN119CHAgI+ibrj4EOJARcfVNswhxcoSazApAsrS1lfdiYv7Ut4MjXQbPjXjRAgdQtLhGuqA4iQ9wvq4BycFdmMmS6ACP68YEAMVhpEuyJDrA+bpWPL6pIA6MdEmWRAc4X9eKwzcVxIW7MJMl0XdiMm1lcficgGxL5Z2YfF0rDt9UgOqU6ADn61pxuLAEVKdEB7gUhHVvbzA
ZT28v4T0SvqkA1SnxAY6x8U0FqE4EeEZE8U2FoYhAujAXCiQxxwWQRrTAISmam6ZowQPxogUOSeGHItKCB+JHCxySwg9FZNoDIH4EeAaV0pURdigiNxMB8SPAM6bUeVHCDkXkZiIgfomeCwXjV6l5UfL7wKWgBc94dCC8ssyFYmafMLMVZvaWmW0ws/PCHA/hVaorg5uJgPiFHYVyv6T/dvcvmNnRkurG+gWUV2PjyC3wOLoyyr0eI4DDldwCN7MTJM2V9C1Jcvf97v7biMqFEjEvClA9wnShTJXUL+nbZva6mT1iZsfl72RmS8ys28y6+/v7Q5wOxaArA6geJV/ENLNWSWslfdbdXzGz+yX9zt3/ttDvcBETAMavHBcx+yT1ufsrufcrJM0KcbyK4jZwAGlTcoC7+68kvWtm03KbLpS0PpJSxYw1JQGkUdgbeW6V1GVmb0qaKekfQpeoArgNHEAahRpG6O49ko7ol0kbbgMHkEbcSq9k3wZO3zyAQghwJXfsNH3zAEZDgCu5Y6fpmwcwGiazSrCamqDlnc8sWPsSQHUoy2RWKK8k980DqDwCPMGS2jcPIBkI8ARLat88gGRgUeOEY4pWAIXQAgeAlCLAASClCHAASCkCHABSigAHgJTKfIAzGRSArMr0MMLByaAG5xMZnAxKYmgegPTLdAucyaAAZFmmA5yFGgBkWaYDnMmgAGRZpgOcyaAAZFmmA5zJoABkWaZHoUhMBgUguzLdAgeALCPAASClCHAASCkCHABSigAHgJQiwAEgpQhwAEgpAhwAUooAB4CUIsABIKUIcABIqdABbmYTzOx1M/uPKAoEAChOFC3w2yRtiOA4AIBxCBXgZjZF0qWSHommOACAYoVtgf+rpL+RNFBoBzNbYmbdZtbd398f8nQAgEElB7iZXSZph7uvG20/d+9w91Z3b21oaCj1dACAPGFa4J+V9Hkz65X0XUkXmFlnJKUCAIyp5AB39zvdfYq7N0u6WtJL7n5tZCUDAIyKceAAkFKRrInp7i9LejmKYwEAikMLHABSigAHgJQiwAEgpQhwAEgpAhwAUooAB4CUIsABIKUIcABIKQIcAFKKAAeAlCLAASClCHAASCkCHABSigAHgJQiwAEgpQhwAEgpAhwAUooAB4CUIsABIKUIcABIKQIcAFKKAAeAlCLAASClCHAASCkCHABSigAHgJQiwAEgpQhwACiXri6puVmqqQmeu7oiPfxRkR4NABDo6pKWLJH27Aneb9kSvJektrZITkELHADKob394/AetGdPsD0iBDgAlMPWrePbXgICHADKobFxfNtLUHKAm9knzWyVma03s5+b2W2RlQoA0u6++6S6usO31dUF2yMSpgV+UNJfu3uLpDmS/sLMWqIpFgCkXFub1NEhNTVJZsFzR0dkFzClEKNQ3H27pO251x+a2QZJp0paH1HZACDd2toiDex8kfSBm1mzpLMlvTLCz5aYWbeZdff390dxOgBZUuax0lkWOsDN7HhJz0j6srv/Lv/n7t7h7q3u3trQ0BD2dADSZrSAHhwrvWWL5P7xWGlCvCihAtzMahWEd5e7/1s0RQKQCsW0nMcK6BjGSmeZuXtpv2hmkh6T9L67f7mY32ltbfXu7u6SzgcgIbq6pKVLpd27D99eV3fkRbrm5iC08zU1Sb29QfiPlEFm0sBAlKVONTNb5+6t+dvDtMA/K+mLki4ws57c45IQxwOQdF1d0pe+dGR4SyO3nMe6mSWGsdJZVnKAu/sP3d3c/Ux3n5l7PBdl4QAkTHu7dOBA4Z/nB/ZYAR3DWOks405MoFpEMdpjrNvA8wN7rICOYax0lhHgQJYUCumoRnuM1rVhdmTLuZiAbmsL+sMHBoJnwrt47h7bY/bs2Q4gpM5O96Ymd7PgubPz4+11de5BRAePurqP9x++ffDR1DT+c9fWjnysm26Ktp4YIqnbR8hUWuBAEhTbvTFaS3q0IXlRzYzX1iZ9+9tSff3
H2+rrpc5O6RvfGN+xEFrJwwhLwTBCYAT5E/9LIw/Jk0Yflrd1a+EheY2Now/nQ6KVYxghgCiM52aW0VrSo434YLRHJhHgQKWNp3uj1JBmtEcmEeBApY3nZpYwIc1oj8whwIF8cc+ON57uDUIaw7AqPTBcDCuJH2HwuIOjRQa7Qwqdr8xzTCM9GIUCDDfW5EtABTAKBShGDCuJA1EhwIHhmB0PKUKAIxmSsqwW46WRIgQ4Ki9Jy2oxXhopQoAjOqW2opO2rBZD8ZASDCNENMIMv+PCIVASWuD4WJh+6DCtaC4cAiUhwBEI2w8dphXNhUOgJAR4FpXSkg7bDx2mFc2FQ6AkBHhWDIa2mfTFL46/JR22HzpsK5oLh8C4EeBJV0xrenj3h3TkpP7FtKTD9kPTigZix1woSVbsSi2F5u8Yzixo3YY9F4DYMRdKGhXbL11MN8dYLWla0EDqMA48yYrtly603uGgYvuimaYUSBVa4ElWbL/0SBcQzYJnWtJAZhHgSVbsyI6Ruj+eeCK4mMmIDiCzCPBBSZkNb7jx9EszDA+oOtkP8PEOw6v0bHj5CGYABWQ7wIsN5qTNhgcARch2gIcdhsdseAASLPkBHqZvejzD8EbCbHgAEixUgJvZQjN728x+YWZ3RFWoIWH7psMMw2M2PAAJV3KAm9kESV+X9EeSWiRdY2YtURVMUvi+6TDD8Bg7DSDhwtyJea6kX7j7Zkkys+9KulzS+igKJil83/RgALe3B7/T2BiEd6FheAQ2gBQJE+CnSnp32Ps+SZ/J38nMlkhaIkmN4+1TLnSL+HiOQzADyKiyX8R09w53b3X31oaGhvH9Mn3TAFBQmADfJumTw95PyW2LDn3TAFBQmC6U1yR9ysymKgjuqyX9WSSlGo4uEAAYUckB7u4HzewWSf8jaYKk5e7+88hKBgAYVaj5wN39OUnPRVQWAMA4JP9OTADAiAhwAEgpAhwAUirWVenNrF/SGMuna7Kk38RQnKSp1npL1L0a616t9ZZKq3uTux9xI02sAV4MM+t299ZKlyNu1VpvibpXY92rtd5StHWnCwUAUooAB4CUSmKAd1S6ABVSrfWWqHs1qtZ6SxHWPXF94ACA4iSxBQ4AKAIBDgApVZEAH2stTTM7xsyeyv38FTNrrkAxy6KIuv+Vma03szfN7EUza6pEOcuh2DVUzewqM3Mzy8Qws2LqbWZ/kvu7/9zMnoy7jOVSxL/3RjNbZWav5/7NX1KJckbNzJab2Q4z+1mBn5uZPZD7XN40s1klncjdY30omLnwl5JOk3S0pDckteTtc7Okh3Kvr5b0VNzlrGDd/1BSXe71TdVU99x+kyStlrRWUmulyx3T3/xTkl6XdGLu/R9Uutwx1r1D0k251y2Seitd7ojqPlfSLEk/K/DzSyT9lySTNEfSK6WcpxIt8KG1NN19v6TBtTSHu1zSY7nXKyRdaGYWYxnLZcy6u/sqdx9cyXmtgoUysqCYv7sk3SPpnyTti7NwZVRMvf9c0tfd/QNJcvcdMZexXIqpu0v6vdzrEyS9F2P5ysbdV0t6f5RdLpf0uAfWSvqEmZ083vNUIsBHWkvz1EL7uPtBSbsk1cdSuvIqpu7D3aDg/9JZMGbdc18jP+nu/xlnwcqsmL/5pyV92sx+ZGZrzWxhbKUrr2Lqfreka82sT8HU1LfGU7SKG28WjCjUfOAoHzO7VlKrpPMrXZY4mFmNpK9JWlTholTCUQq6UeYp+Ma12sxmuPtvK1momFwj6VF3/xczO0/SE2Y23d0HKl2wNKhEC7yYtTSH9jGzoxR8tdoZS+nKq6h1RM1svqR2SZ939/+LqWzlNlbdJ0maLullM+tV0C+4MgMXMov5m/dJWunuB9z9HUkbFQR62hVT9xskfU+S3H2NpIkKJnvKukjWFK5EgA+tpWlmRyu4SLkyb5+Vkq7Pvf6CpJc81/OfcmPW3czOlvSwgvDOSl+oNEbd3X2Xu09292Z3b1bQ//9
5d++uTHEjU8y/92cVtL5lZpMVdKlsjrGM5VJM3bdKulCSzOx0BQHeH2spK2OlpOtyo1HmSNrl7tvHfZQKXaG9REEr45eS2nPb/l7Bf7BS8Ed8WtIvJL0q6bRKX1WOse4vSPq1pJ7cY2WlyxxX3fP2fVkZGIVS5N/cFHQfrZf0U0lXV7rMMda9RdKPFIxQ6ZG0oNJljqje35G0XdIBBd+wbpB0o6Qbh/3Nv577XH5a6r91bqUHgJTiTkwASCkCHABSigAHgJQiwAEgpQhwAEgpAhwAUooAB4CU+n81PmNJdk5fugAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question: the red points are the predicted values and appear to line up along a straight line. Do these points really lie on one straight line?**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we need to compute our error function, here the sum of squared errors\n", "\n", "$$\n", "E = \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n", "$$\n", "\n", "It differs from the mean squared error defined earlier only by the constant factor $\\frac{1}{n}$, which scales the gradients but does not change where the minimum lies." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Compute the error (sum of squared errors)\n", "def get_loss(y_, y):\n", "    return torch.sum((y_ - y) ** 2)\n", "\n", "loss = get_loss(y_, y_train)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(748.8935, dtype=torch.float64, grad_fn=)\n" ] } ], "source": [ "# Print the value of the loss\n", "print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the error function defined, we next need the gradients with respect to w and b. Thanks to PyTorch's automatic differentiation we do not have to compute them by hand, but if you are interested you can derive them yourself. For the error $E$ above, the gradients with respect to w and b are\n", "\n", "$$\n", "\\frac{\\partial E}{\\partial w} = 2 \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n", "\\frac{\\partial E}{\\partial b} = 2 \\sum_{i=1}^n (w x_i + b - y_i)\n", "$$\n", "\n", "(For the mean squared error, each would carry an extra factor of $\\frac{1}{n}$, giving the $\\frac{2}{n}\\sum$ form.)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Automatic differentiation: computes the gradients of loss and stores them in w.grad and b.grad\n", "loss.backward()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-125.1102])\n", "tensor([-243.2102])\n" ] } ], "source": [ "# Inspect the gradients of w and b\n", "print(w.grad)\n", "print(b.grad)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Perform one gradient-descent update with learning rate 1e-2\n", "w.data = w.data - 1e-2 * w.grad.data\n", "b.data = b.data - 1e-2 * b.grad.data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After updating the parameters, let's look at the model's output again" ] }, { "cell_type": 
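"markdown", "metadata": {}, "source": [ "As a quick check of the update rule, we can work out the single step above by hand. The learning rate is $\\eta = 10^{-2}$ and $b$ was initialized to 0. The gradients printed above are rounded to four decimals, so the step gives approximately\n", "\n", "$$\n", "b \\leftarrow 0 - 10^{-2} \\times (-243.2102) \\approx 2.4321, \\qquad w \\leftarrow w - 10^{-2} \\times (-125.1102) \\approx w + 1.2511\n", "$$\n", "\n", "Both gradients are negative, so both parameters increase and the predictions move up toward the data points." ] }, { "cell_type": 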
"code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD6CAYAAAC4RRw1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZQ0lEQVR4nO3dfZBV9X3H8fd3cRXXUKPLxpqS3cWZ1IogCIvFaSVUEWh0fChpJmYTJU2CmujYdmpLhuloq9s0mVTaZNrEHUN8YLUqaVMmtQl5wJImol4MGoMGErroqg0rGnwAwsN++8e5i8v17t6755x7z8P9vGbu3HvPPXvO73cWvvu739/DMXdHRESypynpAoiISDgK4CIiGaUALiKSUQrgIiIZpQAuIpJRCuAiIhlVMYCb2Woz22VmT4/YdrKZfcfMthefT6ptMUVEpJRVGgduZvOBN4C73X16cdvngVfc/e/NbAVwkrv/VaWTTZ482Ts7O6OXWkSkgWzevPlld28r3X5MpR90941m1lmy+VJgQfH1XcDDQMUA3tnZSaFQqLSbiIiMYGY7y20PmwM/xd1fKr7+P+CUkMcREZGQIndiepCDGTUPY2bLzaxgZoXBwcGopxMRkaKwAfyXZnYqQPF512g7unuvu3e5e1db29tSOCIiElLFHPgo1gFXAX9ffP6PsAU4ePAgAwMD7N+/P+whpMTEiROZMmUKzc3NSRdFRGqoYgA3s/sIOiwnm9kAcBNB4H7AzD4O7AQ+GLYAAwMDTJo0ic7OTsws7GGkyN3ZvXs3AwMDTJ06NeniiEgNVUyhuPsV7n6quze7+xR3/6q773b3C9z9ve6+0N1fCVuA/fv309raquAdEzOjtbVV32hEUqCvDzo7oakpeO7ri/f4YVMosVLwjpeup0jy+vpg+XLYuzd4v3Nn8B6guzuec2gqfQw6Ozt5+eWXky6GiKTIypVvBe9he/cG2+OiAF7C3RkaGkq6GCKScc89N77tYWQugNcip9Tf38/pp5/OlVdeyfTp07nllluYO3cuZ511FjfddNOR/S677DLmzJnDmWeeSW9vb/QTi0hutbePb3sYqciBV6uWOaXt27dz11138dprr7F27Voee+wx3J1LLrmEjRs3Mn/+fFavXs3JJ5/Mvn37mDt3LkuXLqW1tTXaiUUkl3p6jo5XAC0twfa4ZKoFXsucUkdHB/PmzWP9+vWsX7+es88+m9mzZ/Pss8+yfft2AL74xS8yc+ZM5s2bx/PPP39ku4hIqe5u6O2Fjg4wC557e+PrwISMtcBrmVM64YQTgCAH/pnPfIarr776qM8ffvhhvvvd7/LII4/Q0tLCggULNFRPRMbU3R1vwC6VqRZ4PXJKixcvZvXq1bzxxhsAvPDCC+zatYs9e/Zw0kkn0dLSwrPPPsumTZviO6mISAiZCuA9PUEOaaS4c0qLFi3iwx/+MOeeey4zZszgAx/4AK+//jpLlizh0KFDnHHGGaxYsYJ58+bFd1IRkRAq3tAhTl1dXV66HvgzzzzDGWecUfUx+vqCnPdzzwUt756e2n5FyarxXlcRSS8z2+zuXaXbM5UDh9rnlEREsiJTKRQREXmLAriISEYpgIuIZJQCuIhIRimAi4hklAL4ONx55528+OKLR95/4hOfYOvWrZGP29/fz7333jvun1u2bBlr166NfH4RySYF8HEoDeB33HEH06ZNi3zcsAFcRBpb9gJ4DdaTXbNmDeeccw6zZs3i6quv5vDhwyxbtozp06czY8YMVq1axdq1aykUCnR3dzNr1iz27dvHggUL
GJ6Y9I53vIMbb7yRM888k4ULF/LYY4+xYMECTjvtNNatWwcEgfq8885j9uzZzJ49mx/96EcArFixgh/84AfMmjWLVatWcfjwYW688cYjS9refvvtQLBOy3XXXcfpp5/OwoUL2bVrV+S6i0iGuXvoB3AD8DTwU+BPK+0/Z84cL7V169a3bRvVmjXuLS3u8NajpSXYHtLWrVv94osv9gMHDri7+7XXXus333yzL1y48Mg+r776qru7v+997/PHH3/8yPaR7wF/6KGH3N39sssu8wsvvNAPHDjgW7Zs8ZkzZ7q7+5tvvun79u1zd/dt27b58PXYsGGDX3TRRUeOe/vtt/stt9zi7u779+/3OXPm+I4dO/zrX/+6L1y40A8dOuQvvPCCn3jiif7ggw+OWi8RyQeg4GViauiZmGY2HfgkcA5wAPiWmX3T3X8e+a/KaMZaTzbk9Mzvfe97bN68mblz5wKwb98+lixZwo4dO7j++uu56KKLWLRoUcXjHHvssSxZsgSAGTNmcNxxx9Hc3MyMGTPo7+8H4ODBg1x33XVs2bKFCRMmsG3btrLHWr9+PU899dSR/PaePXvYvn07Gzdu5IorrmDChAm8+93v5vzzzw9VZxHJhyhT6c8AHnX3vQBm9t/AHwGfj6NgZdVgPVl356qrruKzn/3sUdt7enr49re/zVe+8hUeeOABVq9ePeZxmpubj9xMuKmpieOOO+7I60OHDgGwatUqTjnlFJ588kmGhoaYOHHiqGX60pe+xOLFi4/a/tBDD4Wqo4jkU5Qc+NPAeWbWamYtwPuB95TuZGbLzaxgZoXBwcEIp6Mm68lecMEFrF279kg++ZVXXmHnzp0MDQ2xdOlSbr31Vp544gkAJk2axOuvvx76XHv27OHUU0+lqamJe+65h8OHD5c97uLFi/nyl7/MwYMHAdi2bRtvvvkm8+fP5/777+fw4cO89NJLbNiwIXRZRCT7QrfA3f0ZM/scsB54E9gCHC6zXy/QC8FqhGHPB9TkHkXTpk3j1ltvZdGiRQwNDdHc3Mxtt93G5ZdffuTmxsOt82XLlnHNNddw/PHH88gjj4z7XJ/61KdYunQpd999N0uWLDlyE4mzzjqLCRMmMHPmTJYtW8YNN9xAf38/s2fPxt1pa2vjG9/4Bpdffjnf//73mTZtGu3t7Zx77rmh6y0i2RfbcrJm9nfAgLv/y2j7xLGcrNaTrY6WkxXJj5osJ2tm73L3XWbWTpD/rv1dDrSerIgIEH098K+bWStwEPi0u/8qepFERKQakQK4u58XV0FERGR8UjETM648vAR0PUUaQ+IBfOLEiezevVtBJybuzu7du0cdYy4i+ZH4PTGnTJnCwMAAkceIyxETJ05kypQpSRdDUkYDuPIn8QDe3NzM1KlTky6GSK719R09hWLnzuA9KIhnWeIpFBGpvbGWEZLsUgAXaQA1WEZIUkABXKQB1GAZIUkBBXCRBtDTEywbNFLEZYQkBRTARRpAdzf09kJHB5gFz7296sDMusRHoYhIfWgZofxRC1xEJKMUwEVEMkoBXEQkoxTARUQySgFcRCSjFMBFRDJKAVwkY/r6oLMTmpqC576+pEskSVEAFymR5gA5vKrgzp3g/taqgkmUMa7rlObrnXruHvoB/BnwU+Bp4D5g4lj7z5kzx0XSbM0a95YW9yA8Bo+WlmB7GnR0HF224UdHR33LEdd1Svv1Tgug4GViqnnIO+GY2W8B/wNMc/d9ZvYA8JC73znaz3R1dXmhUAh1PpF66OwMWrWlOjqgv7/epXm7pqYgzJUyg6Gh+pUjruuU9uudFma22d27SrdHTaEcAxxvZscALcCLEY8nkqi0L7uallUF47pOab/eaRc6gLv7C8AXgOeAl4A97r4+roKJJCEtAXI0aVlVMK7rlPbrnXahA7iZnQRcCkwF3g2cYGYfKbPfcjMrmFlB972UtEtLgBxNWlYVjOs6pf16p165xHg1D+CPga+OeH8l8C9j
/Yw6MSUL1qwJOgXNgmd1qJUX13XS9a6MGnRi/i6wGpgL7APuLJ7kS6P9jDoxRRpPX19w783nngtSIz09WtZ2vGLvxHT3R4G1wBPAT4rH6g1dQmloGgucT2kat55HkUahuPtN7v477j7d3T/q7r+Oq2DSOPSfPL9/wFauhL17j962d2+wXaILnUIJQykUKafRxwIP/wEbGehaWvJxy7O0jFvPulqNAxeJrNHHAue5laphgrWlAC6Ja/T/5Hn+A6ZhgrWlAC6Ja/T/5Hn+A5aWcet5pQAuiWv0/+R5/wPW3R30ZQwNBc+N8nutBwVwSYVG/k+exT9geR01kzXHJF0AEQmCdZoD9kilo2aGh31CduqQF2qBi8i45HnUTNYogIvIuOR51EzWKICLyLjkedRM1iiAi+RIPToX8z5qJksUwEVyol5rymRx1ExeaS0UkZxo9DVl8kxroYjknDoXG48CuEhOqHOx8SiAi+SEOhcbjwK4SE6oc7HxaCq9SI5kaUq+RBe6BW5mp5vZlhGP18zsT2Msm4iIjCHKTY1/5u6z3H0WMAfYC/x7XAUTkXTSSoTpEVcK5QLgF+5eZhSqiOSFViJMl7g6MT8E3BfTsUQkpbQSYbpEDuBmdixwCfDgKJ8vN7OCmRUGBwejnk5EEqTJQukSRwv8D4En3P2X5T50915373L3rra2tnEfXPk2kfTQZKF0iSOAX0GN0if1WpxHRKqjyULpEimAm9kJwIXAv8VTnKMp3yaSLposlC6RAri7v+nure6+J64CjaR8W/WUapJ6aeQbUKdNqqfSK99WHaWaRBpTqgO48m3VUapJpDGlOoAr31YdpZpEGlPqF7PS4jyVtbeXvxOLUk0i+ZbqFrhUR6kmkcakAJ4DSjWJNKbUp1CkOko1iTQetcBFRDJKAVyO0GQgkWxRCkUArfMskkVqgQugyUAiWaQALkA8k4GUghGpLwVwAaKvO6P1WETqTwFcgOiTgZSCEak/BXABok8G0nosIvWnAJ5DYXPRUdZ51tK/IvWnAJ4zSeWitR6LSP0pgOdMUrlorcciUn/m7uF/2OydwB3AdMCBP3H3R0bbv6urywuFQujzSWVNTUHLu5RZkBoRkewxs83u3lW6PWoL/J+Ab7n77wAzgWciHk8iUi5apHGEDuBmdiIwH/gqgLsfcPdfxVQuCUm5aJHGEaUFPhUYBL5mZj82szvM7ISYyiUhKRctkiI1np4cOgduZl3AJuD33P1RM/sn4DV3/+uS/ZYDywHa29vn7Cx37y8RkbwpXSEOgq/DIVpUo+XAowTw3wQ2uXtn8f15wAp3v2i0n1Enpog0jM7O8jer7egIJlqMQ+ydmO7+f8DzZnZ6cdMFwNawxxMRyZU6TE+OOgrleqDPzJ4CZgF/F7lECdFKeiISqzoMCYsUwN19i7t3uftZ7n6Zu78aV8HqSSvpiUjs6jAkTDMx0Up6IlIDdRgSFmkm5niltRNTsxdFJM1qNRMzFzR7UUSySAEczV4UkWxSACfdsxc1OkZERnNM0gVIi+7udATskUoncg2PjoH0lVVE6k8t8BTT6BgRGYsCeIrpPpMiMhYF8BTT6BgRGYsCeIppdIyIjEUBPMXSPDpGRJKnAJ5y3d3BypNDQ8GzgrfkjsbKhqZhhCKSHI2VjUQtcBFJjsbKRqIALiK1NVaKRGNlI1EAF5HaqbTYvsbKRqIALiK1UylForGykSiAi8j49PXB5MnB2Nbhx+TJ5UePVEqRaKxsJBqFIiLV6+uDj30MDh48evvu3cF2ODr4treXvzP7yBRJGleSy4hILXAz6zezn5jZFjNL3612RCReK1e+PXgPO3jw7aNHlCKpqTha4H/g7i/HcBwRSbtKo0NKPx9uWa9cGXzW3h4Eb7W4Y5H7HLgmeYnEqNLokHKfazpxzUQN4A6sN7PNZrY8jgLFqdIIJhEZp54eaG4u/1lzs1IjdRY1gP++u88G/hD4tJnNL93BzJabWcHMCoOD
gxFPNz6a5CUyQhxfR7u74Wtfg9bWo7e3tgbb1bquK3P3eA5kdjPwhrt/YbR9urq6vFCoX19nU1PQ8i5lFnybE2kYpWuOQNCZqCF7mWBmm929q3R76Ba4mZ1gZpOGXwOLgKfDFzF+muQlDWe0Vra+juZSlFEopwD/bmbDx7nX3b8VS6li0tNTvtGhNJ3k0lgr+2nNkVwK3QJ39x3uPrP4ONPdUxcWNclLcilMK1tfR3Mp9zMxNclLciVsK/uee/R1NIdyPw5cJFfCtrL1dTSXFMBFsmSsVnalaeuaUJM7CuAiaVDtGG21smUEBXCRpI1nyrBa2TKCArhI0sYzRlutbBkhtpmY1aj3TEyRTNCUYakg9pmYIhITjdGWkBTARZKmmx5ISArgIklTXltCyv1MTJFM0JRhCUEtcBGRjFIAFyml+/BJRiiFIjLSWItFKcUhKaMWuMhIuvGBZIgCuMhIuvGBZIgCuMhImlQjGaIALjKSJtVIhkQO4GY2wcx+bGbfjKNAIonSpBrJkDhGodwAPAP8RgzHEkmeJtVIRkRqgZvZFOAi4I54iiMiItWKmkL5R+AvAa15KSJSZ6EDuJldDOxy980V9ltuZgUzKwwODoY9neSdZj+KjFuUFvjvAZeYWT/wr8D5ZramdCd373X3Lnfvamtri3A6ya3x3FJMRI6I5Y48ZrYA+At3v3is/XRHHimrszMI2qU6OoL7Ooo0ON2RR9JLsx9FQoklgLv7w5Va3yKj0uxHkVDUApf4hO2I1OxHkVAUwCUeUToiNftRJJRYOjGrpU7MHFNHpEjNqBNTaksdkSJ1pwAu8VBHpEjdKYDLW6LMhlRHpEjdKYBLIOpsSHVEitSdOjEloE5IkdRSJ6aMTZ2QIpmjAJ5HYXLZ6oQUyRwF8LwYDtpm8NGPjj+XrU5IkcxRAM+DkR2QEATukfbuhZUrxz6GOiFFMkedmHkwWgfkSGYwpBsniWSROjGzqpp8djUdjcpli+SOAniaVTs2u1JwVi5bJJcUwNNs5cogfz1SuXx2uQ5Is+BZuWyR3FIAT7Nqx2aX64C8556g1d7fr+AtklPHJF0AGUN7e/nOyXIpk+5uBWqRBhO6BW5mE83sMTN70sx+amZ/E2fBBI3NFpExRUmh/Bo4391nArOAJWY2L5ZSSUBjs0VkDKFTKB4MIH+j+La5+KjfoPJGodSIiIwiUiemmU0wsy3ALuA77v5oLKUSEZGKIgVwdz/s7rOAKcA5Zja9dB8zW25mBTMrDA4ORjmdiIiMEMswQnf/FbABWFLms15373L3rra2tjhOJyIiRBuF0mZm7yy+Ph64EHg2pnKJiEgFUVrgpwIbzOwp4HGCHPg34ylWAqLcD1JEJAFRRqE8BZwdY1mSM7zmyPC09eE1R0AjQEQktTSVHqpfc0REJEXyH8CjLMeq+0GKSIrlO4BHXY5Va2iLSIrlO4BHWY5Va46ISMrlO4BHWY5Va46ISMrlezlZLccqIjmW/hZ4lPHZSo2ISI6lO4BX2wk5GqVGRCTHLFgVtj66urq8UChU/wOdneVTIB0dwa3CREQagJltdveu0u3pboFrfLaIyKjSHcA1PltEZFTpDuDqhBQRGVW6A7g6IUVERpX+ceAany0iUla6W+AiIjIqBXARkYxSABcRySgFcBGRjFIAFxHJqLpOpTezQaDM3PijTAZerkNx0qZR6w2qeyPWvVHrDeHq3uHubaUb6xrAq2FmhXJz/vOuUesNqnsj1r1R6w3x1l0pFBGRjFIAFxHJqDQG8N6kC5CQRq03qO6NqFHrDTHWPXU5cBERqU4aW+AiIlKFRAK4mS0xs5+Z2c/NbEWZz48zs/uLnz9qZp0JFLMmqqj7n5vZVjN7ysy+Z2YdSZSzFirVfcR+S83MzSwXoxSqqbeZfbD4e/+pmd1b7zLWShX/3tvNbIOZ/bj4b/79SZQzbma22sx2mdnTo3xuZvbF4nV5ysxmhzqRu9f1AUwAfgGcBhwL
PAlMK9nnU8BXiq8/BNxf73ImWPc/AFqKr69tpLoX95sEbAQ2AV1Jl7tOv/P3Aj8GTiq+f1fS5a5j3XuBa4uvpwH9SZc7prrPB2YDT4/y+fuB/wIMmAc8GuY8SbTAzwF+7u473P0A8K/ApSX7XArcVXy9FrjAzKyOZayVinV39w3uvrf4dhMwpc5lrJVqfu8AtwCfA/bXs3A1VE29Pwn8s7u/CuDuu+pcxlqppu4O/Ebx9YnAi3UsX824+0bglTF2uRS42wObgHea2anjPU8SAfy3gOdHvB8obiu7j7sfAvYArXUpXW1VU/eRPk7wVzoPKta9+DXyPe7+n/UsWI1V8zv/beC3zeyHZrbJzJbUrXS1VU3dbwY+YmYDwEPA9fUpWuLGGwvKSv8NHRqUmX0E6ALel3RZ6sHMmoDbgGUJFyUJxxCkURYQfOPaaGYz3P1XSRaqTq4A7nT3fzCzc4F7zGy6uw8lXbAsSKIF/gLwnhHvpxS3ld3HzI4h+Gq1uy6lq61q6o6ZLQRWApe4+6/rVLZaq1T3ScB04GEz6yfIC67LQUdmNb/zAWCdux909/8FthEE9Kyrpu4fBx4AcPdHgIkEa4XkXVWxoJIkAvjjwHvNbKqZHUvQSbmuZJ91wFXF1x8Avu/FzH/GVay7mZ0N3E4QvPOSC4UKdXf3Pe4+2d073b2TIP9/ibsXkilubKr59/4NgtY3ZjaZIKWyo45lrJVq6v4ccAGAmZ1BEMAH61rKZKwDriyORpkH7HH3l8Z9lIR6aN9P0Mr4BbCyuO1vCf7DQvBLfBD4OfAYcFrSvcp1rPt3gV8CW4qPdUmXuV51L9n3YXIwCqXK37kRpI+2Aj8BPpR0metY92nADwlGqGwBFiVd5pjqfR/wEnCQ4BvWx4FrgGtG/M7/uXhdfhL237pmYoqIZJRmYoqIZJQCuIhIRimAi4hklAK4iEhGKYCLiGSUAriISEYpgIuIZJQCuIhIRv0/2aSI16sim/wAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_ = linear_model(x_train)\n", "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the plot shows, after one update the red points still lie below the blue points and do not fit the true values well, so we need to run several more updates" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 19, loss: 9.138844332292493\n", "epoch: 39, loss: 8.31670591484358\n", "epoch: 59, loss: 8.010376750480548\n", "epoch: 79, loss: 7.896237967760094\n", "epoch: 99, loss: 7.853709612500179\n" ] } ], "source": [ "for e in range(100):  # run 100 updates\n", "    y_ = linear_model(x_train)\n", "    loss = get_loss(y_, y_train)\n", "\n", "    w.grad.zero_()  # zero the gradients first, because backward() accumulates into .grad\n", "    b.grad.zero_()\n", "    loss.backward()\n", "\n", "    w.data = w.data - 1e-2 * w.grad.data  # update w\n", "    b.data = b.data - 1e-2 * b.grad.data  # update b\n", "    if (e + 1) % 20 == 0:\n", "        # e is the 0-based iteration index; the loss shown was computed before this iteration's update\n", "        print('epoch: {}, loss: {}'.format(e, loss.item()))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXXklEQVR4nO3df5BV5X3H8fd3cRVXqTFAHBPCLs4kVARFdrXYGZFEBKqZRGum0VmjpFH8Ua1tJ86Y8IdOlclkppXWJKNuLLG6a6pidJjGJtRfJU1QXAyaFBMwuOCiDetqqPwKsPvtH2cXYb3Lnrv3nnOec8/nNXNn95579t7vc5f98NznPOc55u6IiEi46rIuQEREjkxBLSISOAW1iEjgFNQiIoFTUIuIBO6oJJ50woQJ3tTUlMRTi4jUpHXr1r3j7hNLPZZIUDc1NdHZ2ZnEU4uI1CQz2zLcYxr6EBEJnIJaRCRwCmoRkcAlMkZdyv79++nu7mbv3r1pvWTNGzt2LJMmTaK+vj7rUkQkQakFdXd3N+PGjaOpqQkzS+tla5a709vbS3d3N1OmTMm6HBFJUGpDH3v37mX8+PEK6SoxM8aPH69PKCIB6OiApiaoq4u+dnRU9/lT61EDCukq0/spkr2ODli8GHbvju5v2RLdB2htrc5r6GCiiEgFliz5IKQH7d4dba8WBXVMTU1NvPPOO1mXISKB2bq1vO2jEWxQJznm4+709/dX7wlFpLAmTy5v+2gEGdSDYz5btoD7B2M+lYR1V1cXU6dO5corr2T69OnccccdnHXWWZx++uncdtttB/e7+OKLaW5u5rTTTqOtra0KrRGRWrZ0KTQ0HL6toSHaXi1BBnVSYz6bNm3ihhtuYNmyZWzbto21a9eyfv161q1bx+rVqwFYvnw569ato7Ozk7vvvpve3t7KXlREalprK7S1QWMjmEVf29qqdyARUp71EVdSYz6NjY3Mnj2br33ta6xatYozzzwTgJ07d7Jp0ybmzJnD3XffzRNPPAHAm2++yaZNmxg/fnxlLywiNa21tbrBPFSQQT15cjTcUWp7JY477jggGqP++te/zrXXXnvY488//zxPP/00a9asoaGhgblz52qesohkLsihj6THfBYsWMDy5cvZuXMnANu2bWP79u3s2LGDE088kYaGBn7961/zwgsvVOcFRUQqEGSPevAjxJIl0XDH5MlRSFfro8X8+fN57bXXOOeccwA4/vjjaW9vZ+HChdx7772ceuqpTJ06ldmzZ1fnBUVEKmDuXvUnbWlp8aEXDnjttdc49dRTq/5aRaf3VaQ2mNk6d28p9ViQQx8iIvIBBbWISOAU1CIigVNQi4gETkEtIhI4BbWISOAU1CU88MADvPXWWwfvX3311WzYsKHi5+3q6uLhhx8u++cWLVrEihUrKn59EcmncIM66WvbHMHQoL7//vuZNm1axc872qAWkWILM6iTWOcUaG9v5+yzz2bmzJlce+219PX1sWjRIqZPn86MGTNYtmwZK1asoLOzk9bWVmbOnMmePXuYO3cugyfwHH/88dxyyy2cdtppzJs3j7Vr1zJ37lxOOeUUVq5cCUSBfO655zJr1ixmzZrFz3/+cwBuvfVWfvrTnzJz5kyWLVtGX18ft9xyy8HlVu+77z4gWovkxhtvZOrUqcybN4/t27dX1G4RyTl3r/qtubnZh9qwYcOHtg2rsdE9iujDb42N8Z+jxOt/7nOf83379rm7+/XXX++33367z5s37+A+7733nru7n3feef7SSy8d3H7ofcCfeuopd3e/+OKL/YILLvB9+/b5+vXr/YwzznB39127dvmePXvc3X3jxo0++H4899xzftFFFx183vvuu8/vuOMOd3ffu3evNzc3++bNm/3xxx/3efPm+YEDB3zbtm1+wgkn+GOPPTZsu0Qk/4BOHyZTg1zrI4l1Tp955hnWrVvHWWedBcCePXtYuHAhmzdv5qabbuKiiy5i/vz5Iz7P0Uc
fzcKFCwGYMWMGxxxzDPX19cyYMYOuri4A9u/fz4033sj69esZM2YMGzduLPlcq1at4tVXXz04/rxjxw42bdrE6tWrufzyyxkzZgwf//jH+exnPzvqdotI/oUZ1Amsc+ruXHXVVXzzm988bPvSpUv5yU9+wr333sujjz7K8uXLj/g89fX1B6/+XVdXxzHHHHPw+wMHDgCwbNkyTjrpJF555RX6+/sZO3bssDV9+9vfZsGCBYdtf+qpp0bVRhGpTWGOUSewzun555/PihUrDo73vvvuu2zZsoX+/n4uvfRS7rzzTl5++WUAxo0bx/vvvz/q19qxYwcnn3wydXV1PPTQQ/T19ZV83gULFnDPPfewf/9+ADZu3MiuXbuYM2cOjzzyCH19fbz99ts899xzo65FRPIvzB51AuucTps2jTvvvJP58+fT399PfX09d911F5dccsnBC90O9rYXLVrEddddx7HHHsuaNWvKfq0bbriBSy+9lAcffJCFCxcevGDB6aefzpgxYzjjjDNYtGgRN998M11dXcyaNQt3Z+LEiTz55JNccsklPPvss0ybNo3JkycfXI5VRIpJy5zmnN5XkdqgZU5FRHJMQS0iErhUgzqJYZYi0/spUgypBfXYsWPp7e1VuFSJu9Pb2zvs1D8RqR2xZn2Y2c3ANYAB33P3fyr3hSZNmkR3dzc9PT3l/qgMY+zYsUyaNCnrMkQkYSMGtZlNJwrps4F9wI/N7N/d/fVyXqi+vp4pU6aMrkoRkQKLM/RxKvCiu+929wPAfwF/nmxZIiIyKE5Q/wo418zGm1kDcCHwyaE7mdliM+s0s04Nb4hIoSS8LPOIQx/u/pqZfQtYBewC1gN9JfZrA9ogOuGlqlWKiIRqcFnm3buj+4PLMkNFZ1MfKtasD3f/F3dvdvc5wHtA6eXgRESKZsmSD0J60O7d0fYqiTvr42Puvt3MJhONT8+uWgUiInmWwLLMQ8VdlOlxMxsP7Af+yt1/X7UKRETyLIFlmYeKO/RxrrtPc/cz3P2Zqr26iEjeJbAs81Ba60NEpBKtrdDWBo2NYBZ9bWur2oFECHU9ahGRPGltrWowD6UetYhI4BTUIiKBU1CLiAROQS0iEjgFtYhI4BTUIjUk4bWBJCOanidSI1JYG0gyoh61SI1IYW0gyYiCWqRGpLA2kGREQS1SI4ZbA6iKawNJRhTUIjUihbWBJCMKapEakcLaQJIRzfoQqSEJrw0kGVGPWkQkcApqEZHAKahFRAKnoBYRCZyCWkTSoYVIRk2zPkQkeVqIpCLqUYtI8rQQSUUU1CKSPC1EUhEFtYhUx5HGoLUQSUUU1CJSucEx6C1bwP2DMejBsNZCJBVRUItI5UYag9ZCJBUxd6/6k7a0tHhnZ2fVn1dEAlVXF/WkhzKD/v7068khM1vn7i2lHlOPWkQqpzHoRCmoRWR4HR0wYULUMzaLvi91oorGoBOloBYJUOYn8Q0G9BVXQG/vB9t7e+m76isfLkhj0IlSUEthZR6GwxhpAkVqBRwa0IcY07efnTeXOFGltRW6uqIx6a6uD4V0qO93HuhgohTS0DOaIfqkHkInsKkpCuehGhuj/MusgEP0Y9R5/IOEIb/foTjSwcRYQW1mfwtcDTjwS+Ar7r53uP0V1BK6zMPwCDKfQDFcAYfoopEm74r9lCG/36GoaNaHmX0C+Gugxd2nA2OAy6pboki6Qj6jOfMJFCO80F7quWt8eQcJQ36/8yDuGPVRwLFmdhTQALyVXEkiycs8DI8g8wkUJQrwgVsP47m+/vv8yT+XN14R8vudByMGtbtvA/4B2Aq8Dexw91VD9zOzxWbWaWadPT091a9UpIoyD8MjyHwCxZACdo5v5Obx7Ywx56zGd5j3/dayawn5/c6DEceozexE4HHgS8DvgceAFe7ePtzPaIxa8qCjIzrDeevWqGe3dKkObCVJ7/eRVXpm4jzgDXfvcff9wA+BP61mgSJZGGE2WT4FPAeuJt/
vlMS5wstWYLaZNQB7gPMBdZdFQqOrqNSsOGPULwIrgJeJpubVAW0J1yUi5dJVVGpWrFkf7n6bu/+xu0939y+7+x+SLkxEjqDUEEfGc+ACHnXJPV3cViRvhhvi+OhHS5/2ncIcOI26JEtrfUiq1OuqguGGOCCzOXAadUmWglpSk/liQ7ViuKGMd9/NbAK2zjxMloJaUlP0XldZnyZGe6HYjObA6czDZCmoJTVF7nWV9WkihxeKDbCk2uLuVb81Nze7yFCNje5R8hx+a2zMurLkldX2ODu3t0f3zaKv7e0ptOLIAiwpV4BOHyZTtR61pKbIaxKXtXRp5uucShZ0cVsJQuaLDWWorDFcDfjKEApqSVVR13tYuhQW1XfwBk30UccbNLGovqP0GK4GfGUInfAikoJWOviSLeYoonGfJrbwPVs88Ac45H+rwf+9tNScDNAYtUgadC0qGYHGqEWyVuS5iVIxBbVIGnJ6gFCn/IdBQS2ShhweINQp/+FQUIukIYdzE4t+yn9IdDBRRErSeTfp0sFEESlbTofVa5KCWkRKyuGwes1SUItISTkcVq9ZCmqRHEpr2lxRT/kPjU4hF8kZXZ+weNSjluLK6dkcmjZXPOpRSzHluFuqs9GLRz1qKaYcd0s1ba54FNRSTDnulmraXPEoqKWYctwt1bS54lFQSzHlvFuqaXPFoqCWYlK3VHJEsz6kuFpbFcySC+pRi4gETkEtIhI4BbWISOBGDGozm2pm6w+5/Z+Z/U0KtYmICDGC2t1/4+4z3X0m0AzsBp5IujARyU5Ol0GpWeXO+jgf+K27b0miGBHJXo6XQalZ5Y5RXwb8IIlCRCQMOV4GpWbFDmozOxr4PPDYMI8vNrNOM+vs6empVn0ikrIcL4NSs8rpUf8Z8LK7/67Ug+7e5u4t7t4yceLE6lQntUeDn8HL8TIoNaucoL6cBIc99PdbAIODn1u2gPsHg5/6ZQcl58ug1KRYQW1mxwEXAD9Mogj9/RaEBj9zQcughMfcvepP2tLS4p2dnbH3b2qKwnmoxsZoZTCpEXV10f/EQ5lFy8CJFJiZrXP3llKPBXFmog5eFIQGP0VGJYig1t9vQWjwU2RUgghq/f3GF8RB19EWocFPkdFx96rfmpubvVzt7e6Nje5m0df29rKfoua1t7s3NLhHA73RraEh5fcqiCJEag/Q6cNkahAHEyWeIA66BlGESO0J/mCixBPEQdcgihApFgV1jgRx0DWIIkSKRUGdI0EcdA2iCJFiUVDnSBCTJoIoQqRYdDBRRCQAOpgohwtiMraIxFXuFV4k73T5DpHcUY+6aLSCnUjuKKiLRvOgRXJHQV00mgctkjsK6qLRPGiR3FFQ59loZm9oHrRI7iio82YwnM3gy18e1fXLOmiliS7q6KeJLjpQSIuETEGdJ4deXBI+fFmrGLM3dH1KkfzRmYl5MtwSo4ca4fqDWqVUJEw6M7FWxJlCN8LsjWrMztOJjSLpUlDnyUhT6GLM3qh0dp6GTkTSp6AOQdwuaqmpdWbR15izNyqdnacTG0XSp6DOWjld1FJT6x56KPq5rq5YU+wqnZ2nExtF0qeDiVnL2dG9nJUrkhs6mBiynHVRdWKjSPoU1Fmr4OheFrMvdGKjSPoU1FkbZRc1y9kXra3RMEd/f+yhcRGpgII6a6Psomr2hUhx6GBiTtXVffgMchjxxEQRCZQOJtYgLSstUhwK6pzS7AuR4lBQ55RmX4gUR6yrkJvZR4D7gemAA3/p7msSrEtiaG1VMIsUQaygBv4Z+LG7f9HMjgYaRvoBERGpjhGD2sxOAOYAiwDcfR+wL9myRERkUJwx6ilAD/B9M/uFmd1vZscN3cnMFptZp5l19vT0VL1QEZGiihPURwGzgHvc/UxgF3Dr0J3cvc3dW9y9ZeLEiVUuU0SkuOIEdTfQ7e4vDtxfQRTcIiKSghGD2t3/F3jTzKYObDof2JBoVUnRNaREJIfizvq4CegYmPGxGfh
KciUlZHAVo8EFMgZXMQLNcRORoBVnrQ+teC8iAdNaH5C7BfpFRAYVJqh3frT0akXDbRcRCUXtBPUIBwq/wVJ2DTmhchcNfAOtYiQiYauNoI5xuZPvvNvKNbTRRSP9GF00cg1tfOddHUgUkbDVxsHEGAcKdSxRREJW+wcTYxwo1PrNIpJXtRHUMS53ovWbRSSvaiOoY3aXdfVsEcmj2ghqdZdFpIaFE9SVrsOR4+6yliARkSOJu9ZHsgq8DkeBmy4iMYUxPa/Ac+cK3HQROUT40/MKvA5HgZsuIjGFEdQxptfVqgI3XURiCiOoC3w2SoGbLiIxhRHUBZ5eV+Cmi0hMYRxMFBEpuPAPJoqIyLAU1CIigVNQi4gETkEtIhI4BbWISOAU1CIigVNQi4gETkEtIhI4BbWISOAU1CIigVNQi4gETkEtIhI4BbWISOAU1CIigVNQi4gELtZVyM2sC3gf6AMODLdmqoiIVF+soB7wGXd/J7FKRESkpJoZ+ujogKYmqKuLvnZ0ZF2RiEh1xA1qB1aZ2TozW1xqBzNbbGadZtbZ09NTvQpj6OiAxYthyxZwj74uXqywFpHaEOuaiWb2CXffZmYfA/4TuMndVw+3f9rXTGxqisJ5qMZG6OpKrQwRkVGr+JqJ7r5t4Ot24Ang7OqVV7mtW8vbLiKSJyMGtZkdZ2bjBr8H5gO/SrqwckyeXN52EZE8idOjPgn4bzN7BVgL/Mjdf5xsWeVZuhQaGg7f1tAQbRcRybsRp+e5+2bgjBRqGbXW1ujrkiXRcMfkyVFID24XEcmzcuZRB621VcEsIrWpZuZRi4jUKgW1iEjgFNQiIoFTUIuIBE5BLSISuFinkJf9pGY9QImTug+aABR1JT61vZiK2vaithvKb3uju08s9UAiQT0SM+ss6prWarvaXiRFbTdUt+0a+hARCZyCWkQkcFkFdVtGrxsCtb2Yitr2orYbqtj2TMaoRUQkPg19iIgETkEtIhK4RIPazBaa2W/M7HUzu7XE48eY2SMDj79oZk1J1pOmGG3/OzPbYGavmtkzZtaYRZ1JGKnth+x3qZm5mdXE9K047Tazvxj4vf+PmT2cdo1JifHvfbKZPWdmvxj4N39hFnVWm5ktN7PtZlbyYioWuXvgfXnVzGaN6oXcPZEbMAb4LXAKcDTwCjBtyD43APcOfH8Z8EhS9aR5i9n2zwANA99fX6S2D+w3DlgNvAC0ZF13Sr/zTwG/AE4cuP+xrOtOse1twPUD308DurKuu0ptnwPMAn41zOMXAv8BGDAbeHE0r5Nkj/ps4HV33+zu+4B/A74wZJ8vAP868P0K4HwzswRrSsuIbXf359x998DdF4BJKdeYlDi/d4A7gG8Be9MsLkFx2n0N8F13fw8OXoO0FsRpuwN/NPD9CcBbKdaXGI8u8v3uEXb5AvCgR14APmJmJ5f7OkkG9SeANw+53z2wreQ+7n4A2AGMT7CmtMRp+6G+SvS/bi0Yse0DH/8+6e4/SrOwhMX5nX8a+LSZ/czMXjCzhalVl6w4bb8duMLMuoGngJvSKS1z5WZBSTVzhZe8MrMrgBbgvKxrSYOZ1QF3AYsyLiULRxENf8wl+gS12sxmuPvvsywqJZcDD7j7P5rZOcBDZjbd3fuzLiwPkuxRbwM+ecj9SQPbSu5jZkcRfSTqTbCmtMRpO2Y2D1gCfN7d/5BSbUkbqe3jgOnA82bWRTRut7IGDijG+Z13Ayvdfb+7vwFsJAruvIvT9q8CjwK4+xpgLNGiRbUuVhaMJMmgfgn4lJlNMbOjiQ4Wrhyyz0rgqoHvvwg86wMj8Dk3YtvN7EzgPqKQrpWxShih7e6+w90nuHuTuzcRjc9/3t07sym3auL8e3+SqDeNmU0gGgrZnGKNSYnT9q3A+QBmdipRUPekWmU2VgJXDsz+mA3scPe3y36WhI+IXkjUa/gtsGRg298T/WFC9Mt6DHgdWAuckvVR3BTb/jTwO2D9wG1l1jWn1fY
h+z5PDcz6iPk7N6Jhnw3AL4HLsq45xbZPA35GNCNkPTA/65qr1O4fAG8D+4k+MX0VuA647pDf+XcH3pdfjvbfuk4hFxEJnM5MFBEJnIJaRCRwCmoRkcApqEVEAqegFhEJnIJaRCRwCmoRkcD9P4jSg7+0cwW/AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "y_ = linear_model(x_train)\n", "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n", "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After 100 updates, the red predictions already fit the blue ground-truth points quite well.\n", "\n", "You have now trained your first machine learning model. Keep going and try the short exercise below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.4 Exercise\n", "\n", "Restart the notebook and rerun the linear regression model above, this time trying different numbers of training iterations and different learning rates, and compare the results you get." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Polynomial Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's go one step further and look at polynomial regression. What is polynomial regression? The idea is simple. Start from the linear regression model above:\n", "\n", "$$\n", "\\hat{y} = w x + b\n", "$$\n", "\n", "This is a first-degree polynomial in x. A model this simple cannot fit more complex relationships, so we can move to a higher-degree model, for example\n", "\n", "$$\n", "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 + \\cdots\n", "$$\n", "\n", "Using higher powers of x lets the model fit more complex curves; this is the polynomial regression model. Multivariate regression has the same form, except that besides x it also uses additional input variables such as y, z, and so on. Both use the same mean squared error loss as simple linear regression." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, define the target function we want to fit. It is a cubic polynomial." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n" ] } ], "source": [ "# Define the target cubic polynomial\n", "\n", "w_target = np.array([0.5, 3, 2.4]) # coefficients of x, x^2, x^3\n", "b_target = np.array([0.9]) # constant term (bias)\n", "\n", "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n", "    b_target[0], w_target[0], w_target[1], w_target[2]) # readable string of the equation\n", "\n", "print(f_des)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's first plot this polynomial." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" 
}, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlIklEQVR4nO3deXxU9b3/8dcnG2FfQkAkQED2TYGwVcuDFheqXjdai61L1YpetbWtt7Xq7bX9aW+1trbaai1Vi1aLIGq1dSlqVVzKLpssBtmSCCQsCQSyznx+f2TgorIlM8mZmbyfj0ceM3PmzPl+zhDeOfOd7/kec3dERCQ5pQRdgIiINB6FvIhIElPIi4gkMYW8iEgSU8iLiCQxhbyISBI77pA3s8fMrNjMVh2yrJOZvWZm+ZHbjpHlZmYPmNl6M1thZiMbo3gRETm6+hzJzwAmf2bZj4E33L0f8EbkMcBXgH6Rn2nAH6IrU0REGsLqczKUmeUC/3D3oZHH64CJ7r7VzLoBb7n7ADP7Y+T+zM+ud7Ttd+7c2XNzcxu2JyIizdSSJUt2uHv24Z5Li3LbXQ8J7m1A18j97kDBIesVRpYdNeRzc3NZvHhxlCWJiDQvZrb5SM/F7ItXr/tIUO85EsxsmpktNrPFJSUlsSpHRESIPuS3R7ppiNwWR5YXAT0OWS8nsuxz3H26u+e5e1529mE/bYiISANFG/IvAldE7l8BvHDI8ssjo2zGAWXH6o8XEZHYO+4+eTObCUwEOptZIXAHcDcw28yuBjYDF0dWfxk4G1gP7AeubGiBNTU1FBYWUllZ2dBNyGdkZmaSk5NDenp60KWISCM77pB390uO8NSkw6zrwA0NLepQhYWFtG3bltzcXMwsFpts1tydnTt3UlhYSO/evYMuR0QaWdyf8VpZWUlWVpYCPkbMjKysLH0yEmkm4j7kAQV8jOn9FGk+EiLkRUSS2f2v57Ngw85G2bZCvgnk5uayY8eOoMsQkTi0oaSc37z+EQs27mqU7Svk68HdCYfDQZcRN3WISPT+Mn8z6anG1DE9jr1yAyjkj2HTpk0MGDCAyy+/nKFDh1JQUMC9997L6NGjGT58OHfcccfBdS+44AJGjRrFkCFDmD59+jG3/eqrrzJy5EhOPvlkJk2qG6T005/+lF/96lcH1xk6dCibNm36XB133nknP/zhDw+uN2PGDG688UYAnnzyScaMGcMpp5zCtddeSygUitXbISIxtK+qljmLCzl7WDe6tM1slDainbumSf3s7x+y+pM9Md3m4BPbccd/DDnqOvn5+Tz++OOMGzeOuXPnkp+fz8KFC3F3zjvvPObNm8eECRN47LHH6NSpExUVFYwePZopU6aQlZV12G2WlJRwzTXXMG/ePHr37s2uXcf+qHZoHSUlJYwfP557770XgFmzZnH77bezZs0aZs2axXvvvUd6ejrXX389Tz31FJdffnn93xwRaVR/W1bE3qpaLh/fq9HaSKiQD0qvXr0YN24cAHPnzmXu3LmMGDECgPLycvLz85kwYQIPPPAAzz//PAAFBQXk5+cfMeTnz5/PhAkTDo5V79SpU73qyM7Opk+fPsyfP59+/fqxdu1aTj31VB588EGWLFnC6NGjAaioqKBLly7RvQEiEnPuzhPvb2bIie0Y2bNjo7WTUCF/rCPuxtK6deuD992dW2+9lWuvvfZT67z11lu8/vrr/Pvf/6ZVq1ZMnDixQWPR09LSPtXffug2Dq0DYOrUqcyePZuBAwdy4YUXYma4O1dccQW/+MUv6t22iDSdBRt3sW77Xn45ZXijDmtWn3w9nXXWWTz22GOUl5cDUFRURHFxMWVlZXTs2JFWrVqxdu1a5s+ff9TtjBs3jnnz5rFx40aAg901ubm5LF26FIClS5cefP5wLrzwQl544QVmzpzJ1KlTAZg0aRJz5syhuLj44HY3bz7iLKQiEpAn/r2J9i3T+Y+
TT2zUdhLqSD4enHnmmaxZs4bx48cD0KZNG5588kkmT57Mww8/zKBBgxgwYMDBbpUjyc7OZvr06Vx00UWEw2G6dOnCa6+9xpQpU3jiiScYMmQIY8eOpX///kfcRseOHRk0aBCrV69mzJgxAAwePJi77rqLM888k3A4THp6Og8++CC9ejVen5+I1M/Wsgr++eF2rj6tNy0zUhu1rXpdGaqx5eXl+WcvGrJmzRoGDRoUUEXJS++rSHDum7uO3725nrf/60v0zGoV9fbMbIm75x3uOXXXiIg0oaraEH9duIUvD+gSk4A/FoW8iEgTenXVNnaUV3NZIw6bPFRChHw8dSklA72fIsF5/P1N5Ga1YkK/prkSXtyHfGZmJjt37lQwxciB+eQzMxvn7DoRObLlBaUs3VLKZeNzSUlpmtlg4350TU5ODoWFhegi37Fz4MpQItK0Hnl3I21bpHFxXtP9/4v7kE9PT9cVjEQk4RWVVvDyyq1cdWoubTOb7tKbcd9dIyKSDGa8V3di47dObdqDVoW8iEgj21tZw9MLCzh7WDe6d2jZpG0r5EVEGtmsRQXsrarlmi82fddzTELezL5vZh+a2Sozm2lmmWbW28wWmNl6M5tlZhmxaEtEJJHUhsL8+b1NjMntxPCcDk3eftQhb2bdge8Cee4+FEgFpgL3AL9x977AbuDqaNsSEUk0r364jaLSCr4dwFE8xK67Jg1oaWZpQCtgK/BlYE7k+ceBC2LUlohIQnB3/vTORnKzWjFpUNdAaog65N29CPgVsIW6cC8DlgCl7l4bWa0Q6H6415vZNDNbbGaLNRZeRJLJks27WV5QytWn9Sa1iU5++qxYdNd0BM4HegMnAq2Bycf7enef7u557p6Xnd00p/mKiDSFR97ZSPuW6UwZFdzJh7Horjkd2OjuJe5eAzwHnAp0iHTfAOQARTFoS0QkIWzasY9/rt7GN8f2pFVGcOedxiLktwDjzKyV1V3DahKwGngT+GpknSuAF2LQlohIQvjjvI/JSE3hyiY++emzYtEnv4C6L1iXAisj25wO3AL8wMzWA1nAo9G2JSKSCLaVVTJnSSEX5/Ugu22LQGuJyWcId78DuOMzizcAY2KxfRGRRPKndzYQdpg2oU/QpeiMVxGRWNq1r5q/LtjC+aecSI9OjX/lp2NRyIuIxNCM9zZSWRvi+oknBV0KoJAXEYmZvZU1zHh/E2cNPoG+XdoGXQ6gkBcRiZmnFmxhT2Ut138pPo7iQSEvIhITlTUhHnlnI1/s1zmQiciORCEvIhIDzywuYEd5FddP7Bt0KZ+ikBcRiVJNKMzDb29gZM8OjOvTKehyPkUhLyISpeeWFlJUWsENX+pL3Yn/8UMhLyISheraMA+8sZ7hOe358sAuQZfzOQp5EZEoPLOkgKLSCr5/Rv+4O4oHhbyISINV1Yb4/b/WM6JnByb2j8+p0hXyIiINNGtRAVvLKrn5jAFxeRQPCnkRkQaprAnx4JvrGZPbiVP7ZgVdzhEp5EVEGuCvC7awfU9V3PbFH6CQFxGpp4rqEA+99THj+2Qx/qT4PYoHhbyISL09OX8zO8rrjuLjnUJeRKQe9lXV8vDbH/PFfp0Z0zu+zm49HIW8iEg9PPruRnbuq06Io3hQyIuIHLcd5VX88e2POWtIV0b27Bh0OcdFIS8icpx+90Y+lbVhfjR5YNClHDeFvIjIcdi0Yx9PLdjC10f34KTsNkGXc9xiEvJm1sHM5pjZWjNbY2bjzayTmb1mZvmR28T4bCMichj3zl1HemoK35vUL+hS6iVWR/L3A6+6+0DgZGAN8GPgDXfvB7wReSwiknCWF5Ty0oqtXDOhD13aZQZdTr1EHfJm1h6YADwK4O7V7l4KnA88HlntceCCaNsSEWlq7s4vXllD5zYZTJvQJ+hy6i0WR/K9gRLgz2b2gZk9Ymatga7uvjWyzjag6+FebGbTzGyxmS0uKSmJQTkiIrHz1roS5m/YxXcn9aNNi7S
gy6m3WIR8GjAS+IO7jwD28ZmuGXd3wA/3Ynef7u557p6XnR2fU3WKSPMUCjt3v7KW3KxWXDKmZ9DlNEgsQr4QKHT3BZHHc6gL/e1m1g0gclscg7ZERJrMnCUFrNu+lx+eNZD01MQcjBh11e6+DSgwswGRRZOA1cCLwBWRZVcAL0TblohIU9lTWcO9/1xHXq+OnD3shKDLabBYdTB9B3jKzDKADcCV1P0BmW1mVwObgYtj1JaISKN74PV8du6rZsaVY+J6KuFjiUnIu/syIO8wT02KxfZFRJrS+uJyZry/ia/n9WBo9/ZBlxOVxOxkEhFpJO7O//vHalpmpPJfZw049gvinEJeROQQb6wpZt5HJXzv9P50btMi6HKippAXEYmoqg1x50ur6dulDZeP7xV0OTGhkBcRiXjs3U1s3rmf/zl3cMIOmfys5NgLEZEoFe+p5Pf/yuf0QV2Z0D95TsxUyIuIAHe+tIaakPPf5wwKupSYUsiLSLP31rpi/r78E274Ul9yO7cOupyYUsiLSLNWUR3iJy+sok92a66bmHizTB5L4k2pJiISQ/e/kU/BrgqenjaOFmmpQZcTczqSF5Fma+22PTzyzgYuzsthXJ+soMtpFAp5EWmWwmHn1udW0q5lOrd+Jbm+bD2UQl5EmqW/LtzCB1tK+e9zBtGxdUbQ5TQahbyINDvFeyq559W1nNo3iwtHdA+6nEalkBeRZsXd+e+/raKqNsxdFwxL6GmEj4dCXkSalReWfcLc1dv5rzP70zvJxsQfjkJeRJqN7Xsq+Z8XVjGqV0euPi35xsQfjkJeRJoFd+fHz66gOhTmV187mdSU5O6mOUAhLyLNwjNLCnlzXQm3TB7YLLppDlDIi0jS+6S0gjv/vppxfTpxxfjcoMtpUgp5EUlq7s4tz64g5M69Xz2ZlGbSTXOAQl5EktqTC7bwTv4Objt7ED06tQq6nCYXs5A3s1Qz+8DM/hF53NvMFpjZejObZWbJe0qZiMSlddv2ctc/VjOhfzbfHNsz6HICEcsj+ZuANYc8vgf4jbv3BXYDV8ewLRGRo6qsCfGdmUtpm5nOr792ctKf9HQkMQl5M8sBzgEeiTw24MvAnMgqjwMXxKItEZHjcec/VvPR9nLuu/hkstu2CLqcwMTqSP63wI+AcORxFlDq7rWRx4XAYSeIMLNpZrbYzBaXlJTEqBwRac5eWbmVpxZs4doJfZLqeq0NEXXIm9m5QLG7L2nI6919urvnuXtednbz/scQkegVlVZwy7MrODmnPTefOSDocgIXiytDnQqcZ2ZnA5lAO+B+oIOZpUWO5nOAohi0JSJyRLWhMDfN/ICwwwOXjCAjTQMIo34H3P1Wd89x91xgKvAvd/8m8Cbw1chqVwAvRNuWiMjR/Pb1fBZv3s1dFwylV1bzOav1aBrzz9wtwA/MbD11ffSPNmJbItLMvbZ6O79/cz1fG5XDBUk+R3x9xPRC3u7+FvBW5P4GYEwsty8icjgfl5Tzg1nLGNa9PXdeMDTocuKKOqxEJKGVV9Vy3V+WkJ6WwsOXjSIzPTXokuKKQl5EEpa786M5y/m4pJzfXzKC7h1aBl1S3FHIi0jCmj5vAy+v3MaPvzKQL/TtHHQ5cUkhLyIJ6b31O7jn1bWcM7wb13yxeVzlqSEU8iKScNYXl/OfTy6hb5c2/HLK8GY7L83xUMiLSELZWV7FVTMWkZGWwqNXjKZ1i5gOEkw6endEJGFU1oS45onFbN9TyaxrxzfL+eHrSyEvIgkhHHZufmY5HxSU8tA3RnJKjw5Bl5QQ1F0jIgnhV3PX8dKKrdz6lYF8ZVi3oMtJGAp5EYl7Ty/cwkNvfcw3xvbUSJp6UsiLSFx7ddVWbnt+JRP6Z/Oz84ZoJE09KeRFJG7N+6iE78z8gFN6dODhS0eSnqrIqi+9YyISlxZv2sW0vyymb5e2/PlbY2iVoXEiDaGQF5G48+EnZVw5YxHd2rfkiavG0L5VetAlJSyFvIj
ElQ0l5Vz+6ELatkjjyW+PbdYX4Y4FhbyIxI0NJeV8408LAPjLt8dqVskYUCeXiMSF/O17+cYjCwiHnSe/PZaTstsEXVJSUMiLSODWbN3DpY8sICXFeHraOPp1bRt0SUlD3TUiEqhVRWVc8qf5pKemMEsBH3M6kheRwHywZTeXP7aQdpnpzLxmHD2zNOFYrEV9JG9mPczsTTNbbWYfmtlNkeWdzOw1M8uP3HaMvlwRSRZvf1TCpY8soFPrDGZfN14B30hi0V1TC9zs7oOBccANZjYY+DHwhrv3A96IPBYRYfaiAq6asYheWa155trxGkXTiKLurnH3rcDWyP29ZrYG6A6cD0yMrPY48BZwS7TtiUjicnd++3o+97+Rz4T+2Tz0zZG00UU/GlVM310zywVGAAuArpE/AADbgK6xbEtEEktNKMxtz63kmSWFfG1UDv970TDNRdMEYhbyZtYGeBb4nrvvOXSmOHd3M/MjvG4aMA2gZ8+esSpHROJI2f4abpy5lHfyd3DTpH587/R+mk2yicTkz6iZpVMX8E+5+3ORxdvNrFvk+W5A8eFe6+7T3T3P3fOys7NjUY6IxJF12/Zy3oPvMn/DTu6ZMozvn9FfAd+EYjG6xoBHgTXuft8hT70IXBG5fwXwQrRtiUhieWnFVi586D32V4d4eto4vj5an9abWiy6a04FLgNWmtmyyLLbgLuB2WZ2NbAZuDgGbYlIAgiFnXv/uY6H3/6YkT078IdLR9G1XWbQZTVLsRhd8y5wpM9ek6LdvogklpK9Vfxg9jLeyd/BN8f25I7/GEJGmr5gDYrGLolIzLy5tpgfzlnO3spa7pkyTN0zcUAhLyJRq6wJcfcra5nx/iYGntCWv14zjv6agyYuKORFJCrrtu3lpqc/YO22vVx5ai63TB5IZnpq0GVJhEJeRBqkJhRm+rwN3P9GPu0y05hx5WgmDugSdFnyGQp5Eam3ZQWl/PjZFazdtpezh53Az84bqsv0xSmFvIgct31Vtfx67kfMeH8jXdpmMv2yUZw55ISgy5KjUMiLyDG5O6+s2sbPX1pDUWkFl43rxY8mD6BtZnrQpckxKORF5KhWFJZy5z9Ws2jTbgae0JY5140nL7dT0GXJcVLIi8hhbS2r4N5X1/HcB0V0bpPBLy4axsV5PUhN0bwziUQhLyKfsqO8iunzNvDEvzcRdvjPiSdx/cST1DWToBTyIgJA8d5Kpr+9gScXbKa6Nsz5p3TnB2f0p0cnXZYvkSnkRZq5T0orePTdjTwVCfcLRnTnxi/1pU92m6BLkxhQyIs0Q+7Oks27+fP7m3h11TYALjilOzd+uS+9O7cOuDqJJYW8SDNSWRPilVVb+fN7m1hRWEa7zDSuPq03l43rpW6ZJKWQF0ly7s6yglKeXVrIi8s+YU9lLSdlt+bOC4YyZWR3WmUoBpKZ/nVFklTBrv38fcUnPLukkI9L9pGZnsJZQ07gq6NyOPWkzqRoKGSzoJAXSRLuTn5xOa+u2sY/P9zGh5/sAWB0bkemTejD2cO6aRhkM6SQF0lg+6pqWbhxF++u38G/1hazccc+AEb27MBtZw9k8pBu9MxSX3tzppAXSSD7q2tZUVjGgg27eG/9DpZu2U1t2MlIS2Fs705cfVpvzhzclS66nqpEKORF4lRNKMzGHftYUVjGB1t288GWUtZt30so7JjBsO7tuWZCH07r25lRvTrqQh1yWAp5kYBV1oTYsms/m3fuZ31xOeu27WHttr18XFJOTcgBaJuZxik9OnDD4L6M6NmBET060KFVRsCVSyJo9JA3s8nA/UAq8Ii7393YbYrEA3dnX3WI3fuqKSmvonhPJdv3VLFtTyXb91RSuLuCzTv3sX1P1aded2L7TAac0JaJA7ow8IS2DDmxHSdlt9FoGGmQRg15M0sFHgTOAAqBRWb2oruvbsx2pensr66lZG8VO8qrKN1fw57KGvZU1LKnou7+vuoQldUhKmrqfiprQlTVhgm
FnZqQUxuqux/yuiPWyM1BqSlGih24NVJTjLQUIy01hdQUIz3VSEtJIT3VSE9NIS01cj8lhbTIsvRUiyxPIS3l/7Zx4DYlxTCrayfFDAMcCLvjXhfWYa/rPqkOhamp9YP391fXsr86xP6qEPsi98sqaijdX0NZRfXBI/FDpaUYXdq2oHvHlpzWN5vcrFb0zGpFr6zW9O7cmvYtNQJGYqexj+THAOvdfQOAmT0NnA8o5BNAKOwU7a6gYPd+Cnfvp3B3ReRnP8V7q9ixt4p91aEjvr5leiqtW6TRMiOFzLRUWmakkpmWSpsWaZGA/b8ATjE4cJxqVnfP3Qk5hMNO2L3uj0HYqY3c1oTCVNWEKQ/V1v3BCIepCdUtrwmFqQ051ZHbA8/FyoE/MK0y0miVkRr5qbvfr0sbOrTKoEOrdDq2SqdDywyy27agS7sWdG2XSadWGToqlybT2CHfHSg45HEhMLaR25R6cneKSitYVVTGR9vLWV9cTn5xORtKyqmqDR9cLzXF6NY+k+4dWnJyTgey27agc5sWdG6TQee2LejUKoN2LdNpl5lG28x0MtJSAtyrzztwRF4bPhD8dX8sDiw/cBt2JyVyZG9mWOQIP/3gJ4MUzakuCSPwL17NbBowDaBnz54BV9M8lO6vZvGm3awoLGV5YRkri8rYta/64PM5HVvSt0sbTuubRd8ubejZqTU5HVvSrX0maanxFdz1YWakGqSmpNIi8N98kabR2L/qRUCPQx7nRJYd5O7TgekAeXl5sfs8LQeV7q9mwcZdzN+wk/kbdrF22x7c647M+3Vpw+mDujAspwPDurenf9c2mstEJIk09v/mRUA/M+tNXbhPBb7RyG02e+7O2m17eWPNdt5YW8yyglLcoUVaCnm5HfnB6f0Z2yeLYd3b0zJDY6tFklmjhry715rZjcA/qRtC+Zi7f9iYbTZX4bCzcNMuXl65lTfWFFNUWgHA8Jz23DSpH6f27czwnPa0SFOoizQnjf653N1fBl5u7HaaqzVb9/C3ZUX8fdknfFJWSWZ6Cqf17cyNX+7Llwd2oatObxdp1tT5moDK9tcwZ2khsxcVsG77XtJSjAn9s7nlKwM5Y3BX9amLyEFKgwSysrCMv8zfxIvLP6GyJsyInh248/whnD2sG1ltWgRdnojEIYV8nKsNhXlp5VYee3cjywvLaJWRykUjc7h0bC8Gn9gu6PJEJM4p5ONUVW2IZ5cU8fDbH7Nl135Oym7Nz84bwoUju9NOF34QkeOkkI8z+6pqmblwC396ZwPb91Rxck57bj9nFGcM6qpT4UWk3hTycaImFObpRQXc/3o+O8qr+MJJWdx38Sl84aSsg3O5iIjUl0I+YO7Oq6u2ce8/17Fhxz7G5Hbij5eNZFSvTkGXJiJJQCEfoCWbd3HXS2v4YEsp/bq04ZHL85g0qIuO3EUkZhTyAdi1r5q7X1nD7MWFdG3XgnumDGPKyJyEnvxLROKTQr4JhcPOM0sKuPuVteytrOXaCX347qR+tNaUiCLSSJQuTWTdtr3c/vxKFm/ezejcjtx1wTAGnNA26LJEJMkp5BtZKOz86Z0N3Df3I1q3SOWXU4bz1VE5Gg4pIk1CId+Ituzcz83PLGPRpt2cNaQr/3vhME0/ICJNSiHfCNydWYsKuPMfq0kx476LT+bCEd01akZEmpxCPsbK9tdw8zPLeX3Ndr5wUhb3fu1kundoGXRZItJMKeRjaFVRGf/51BK2lVXyk3MHc+UXctX3LiKBUsjHgLszc2EBP/37h3RuncHsa8czomfHoMsSEVHIR6uiOsTtf1vJc0uL+GK/ztw/dQSdWmcEXZaICKCQj0pRaQVXz1jEuu17uWlSP747qR+p6p4RkTiikG+gFYWlXP34YiqrQzz2rdF8aUCXoEsSEfkchXwDvLpqK9+btYys1i146vqx9O+qM1dFJD5FNSOWmd1rZmvNbIWZPW9mHQ557lYzW29m68zsrKgrjQPuzsNvf8x
1Ty5lULd2/O2GUxXwIhLXop328DVgqLsPBz4CbgUws8HAVGAIMBl4yMxSo2wrULWhMLc9v5K7X1nLOcO7MfOacWS31dmrIhLfogp5d5/r7rWRh/OBnMj984Gn3b3K3TcC64Ex0bQVpKraEN+Z+QEzFxZww5dO4ndTR5CZntB/s0SkmYhln/xVwKzI/e7Uhf4BhZFln2Nm04BpAD179oxhObFRUR3i2ieXMO+jEn5y7mCuPq130CWJiBy3Y4a8mb0OnHCYp2539xci69wO1AJP1bcAd58OTAfIy8vz+r6+Me2prOHqGYtYsnk3v5wynItH9wi6JBGRejlmyLv76Ud73sy+BZwLTHL3AyFdBByaiDmRZQljZ3kVV/x5Ieu27eV3l4zknOHdgi5JRKTeoh1dMxn4EXCeu+8/5KkXgalm1sLMegP9gIXRtNWUSvZW8fXp88nfXs70y/MU8CKSsKLtk/890AJ4LTKN7nx3v87dPzSz2cBq6rpxbnD3UJRtNYld+6q59JEFFO2u4PGrxjCuT1bQJYmINFhUIe/ufY/y3M+Bn0ez/aZWVlHDZY8uYNPOffz5W6MV8CKS8KIdJ5809lbWcMVjC8nfXs4fLxvFF/p2DrokEZGoKeSB/dW1XDVjEauKyvj9N0YwUfPQiEiSaPYhX1kT4tuPL2bJ5t3cP3UEZw453GhREZHE1KwnKAuHnZufWc77H+/kvotP1igaEUk6zfpI/n9fXsNLK7Zy29kDuWhkzrFfICKSYJptyD/67kYeeXcj3/pCLtd8sU/Q5YiINIpmGfIvrdjKXS+tZvKQE/jJuYOJjPEXEUk6zS7kF27cxfdnL2NUz478duopulyfiCS1ZhXyG0rK+fbji8jp2JI/XZ6n6YJFJOk1m5DfU1nDNU8sJi01hcevHEPH1hlBlyQi0uiaRciHws73nl7G5p37eeibI+nRqVXQJYmINIlmEfK/nruOf60t5o7zhmg+GhFpVpI+5P++/BMeeutjLhnTk0vHxt+Vp0REGlNSh/yqojJ+OGc5eb068rPzhmiopIg0O0kb8jvLq7j2L0vo2CqDP1w6ioy0pN1VEZEjSsq5a8Jh5/uzl1NSXsWc68aT3bZF0CWJiAQiKQ9vH573MfM+KuEn5w5meE6HoMsREQlM0oX8ok27+PXcjzhneDd90SoizV5ShfyufdV8568fkNOxJXdfNExftIpIs5c0ffLhsPOD2cvYta+a567/Am0z04MuSUQkcElzJP/HeRt4a10JPzl3EEO7tw+6HBGRuBCTkDezm83Mzaxz5LGZ2QNmtt7MVpjZyFi0cySLN+3iV3PXcc6wblw6rldjNiUiklCiDnkz6wGcCWw5ZPFXgH6Rn2nAH6Jt52gy01M5tW9nfjFF/fAiIoeKxZH8b4AfAX7IsvOBJ7zOfKCDmTXaBVSHdm/PE1eNoZ364UVEPiWqkDez84Eid1/+mae6AwWHPC6MLDvcNqaZ2WIzW1xSUhJNOSIi8hnHHF1jZq8DJxzmqduB26jrqmkwd58OTAfIy8vzY6wuIiL1cMyQd/fTD7fczIYBvYHlkX7wHGCpmY0BioAeh6yeE1kmIiJNqMHdNe6+0t27uHuuu+dS1yUz0t23AS8Cl0dG2YwDytx9a2xKFhGR49VYJ0O9DJwNrAf2A1c2UjsiInIUMQv5yNH8gfsO3BCrbYuISMMkzRmvIiLyeQp5EZEkZnU9K/HBzEqAzQ18eWdgRwzLCZL2JT4ly74ky36A9uWAXu6efbgn4irko2Fmi909L+g6YkH7Ep+SZV+SZT9A+3I81F0jIpLEFPIiIkksmUJ+etAFxJD2JT4ly74ky36A9uWYkqZPXkREPi+ZjuRFROQzkirkzezOyJWolpnZXDM7MeiaGsrM7jWztZH9ed7MOgRdU0OZ2dfM7EMzC5tZwo2EMLPJZrYucqWzHwddT0OZ2WNmVmxmq4KuJVpm1sPM3jSz1ZHfrZuCrqkhzCzTzBaa2fL
Ifvws5m0kU3eNmbVz9z2R+98FBrv7dQGX1SBmdibwL3evNbN7ANz9loDLahAzGwSEgT8C/+XuiwMu6biZWSrwEXAGdZPwLQIucffVgRbWAGY2ASin7oI+Q4OuJxqRixB1c/elZtYWWAJckGj/LlY3hW9rdy83s3TgXeCmyMWWYiKpjuQPBHxEaz59taqE4u5z3b028nA+ddM1JyR3X+Pu64Kuo4HGAOvdfYO7VwNPU3fls4Tj7vOAXUHXEQvuvtXdl0bu7wXWcIQLE8WzyNXzyiMP0yM/Mc2tpAp5ADP7uZkVAN8E/ifoemLkKuCVoItopo77KmcSDDPLBUYACwIupUHMLNXMlgHFwGvuHtP9SLiQN7PXzWzVYX7OB3D32929B/AUcGOw1R7dsfYlss7tQC11+xO3jmdfRGLNzNoAzwLf+8wn+YTh7iF3P4W6T+tjzCymXWmNNZ98oznSlaoO4ynq5rW/oxHLicqx9sXMvgWcC0zyOP/ypB7/LolGVzmLU5E+7GeBp9z9uaDriZa7l5rZm8BkIGZfjifckfzRmFm/Qx6eD6wNqpZomdlk4EfAee6+P+h6mrFFQD8z621mGcBU6q58JgGKfGH5KLDG3e8Lup6GMrPsAyPnzKwldV/wxzS3km10zbPAAOpGcmwGrnP3hDzqMrP1QAtgZ2TR/AQeKXQh8DsgGygFlrn7WYEWVQ9mdjbwWyAVeMzdfx5sRQ1jZjOBidTNdrgduMPdHw20qAYys9OAd4CV1P1/B7jN3V8Orqr6M7PhwOPU/W6lALPd/f/FtI1kCnkREfm0pOquERGRT1PIi4gkMYW8iEgSU8iLiCQxhbyISBJTyIuIJDGFvIhIElPIi4gksf8P49VH+I9HxDQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the curve of this function\n", "x_sample = np.arange(-3, 3.1, 0.1)\n", "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n", "\n", "plt.plot(x_sample, y_sample, label='real curve')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we build the dataset of inputs x and targets y. Because the target is a cubic polynomial, each input row holds the features $x,\\ x^2,\\ x^3$." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Build the data x and y\n", "# x is a matrix whose rows are [x, x^2, x^3]\n", "# y is a column of function values [y]\n", "\n", "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n", "x_train = torch.from_numpy(x_train).float() # convert to a float tensor\n", "\n", "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # convert to a float tensor of shape (N, 1)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([61, 3])\n" ] } ], "source": [ "print(x_train.size())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we define the parameters to optimize: the weights $w_i$ of the function above, plus a separate bias $b$ for the constant term." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Define the parameters and the model\n", "w = Variable(torch.randn(3, 1), requires_grad=True)\n", "b = Variable(torch.zeros(1), requires_grad=True)\n", "\n", "# Wrap x and y as Variables (since PyTorch 0.4 Variable is deprecated and simply returns a tensor)\n", "x_train = Variable(x_train)\n", "y_train = Variable(y_train)\n", "\n", "def multi_linear(x):\n", "    # linear model on the polynomial features: x @ w + b\n", "    return torch.mm(x, w) + b\n", "\n", "def get_loss(y_, y):\n", "    # mean squared error\n", "    return torch.mean((y_ - y) ** 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the model before any updates against the true function for comparison." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAprklEQVR4nO3deXhUVbb38e8iBMI8S6OgwRaRQREJGLQVGhRQEURFaW2cRbr1tlM74tt4Ha4DjiiKKIheuSCDA22LAiqiImrACQEBUSGIEEBBIECG/f6xq0jAAEmqKqeq8vs8z3lODafOWZVh1a599lnbnHOIiEhyqhJ0ACIiEjtK8iIiSUxJXkQkiSnJi4gkMSV5EZEkpiQvIpLESp3kzWycma03s0XFHmtoZrPMbHlo3SD0uJnZSDNbYWZfmdlxsQheRET2rywt+fFAn70euxV4xznXCngndB/gNKBVaBkCPB1ZmCIiUh5WlouhzCwdeMM51z50/1ugu3NurZk1A+Y451qb2TOh2xP33m5/+2/cuLFLT08v3zsREamkFixYsME516Sk56pGuO+mxRL3z0DT0O1DgNXFtssOPbbfJJ+enk5WVlaEIYmIVC5m9uO+novaiVfnvxKUuUaCmQ0xsywzy8rJyYlWOCIiQuRJfl2om4bQen3o8TVAi2LbNQ899jvOuTHOuQznXEaTJiV+2xARkXKKNMlPBy4O3b4YeL3Y4xeFRtlkApsP1B8vIiLRV+o+eTObCHQHGptZNjAcuB+YbGaXAz8C54U2fxM4HVgBbAcuLW+AeXl5ZGdns2PHjvLuQiKQlpZG8+bNSU1NDToUESmHUid559xf9vFUzxK2dcDV5Q2quOzsbOrUqUN6ejpmFo1dSik559i4cSPZ2dm0bNky6HBEpBzi/orXHTt20KhRIyX4AJgZjRo10rcokQQW90keUIIPkH72IoktIZK8iEgyu+sumDs3NvtWki+FkSNH0qZNGy688EKmT5/O/fffD8Brr73G4sWLd283fvx4fvrpp933r7jiij2eFxHZ27JlMHw4vP9+bPYf6RWvlcJTTz3F7Nmzad68OQD9+vUDfJLv27cvbdu2BXySb9++PQcffDAAzz33XDABF5Ofn0/Vqvo1i8Srp56C1FS48srY7F8t+QMYOnQoK1eu5LTTTuPRRx9l/PjxXHPNNcybN4/p06dz0003ceyxx/LAAw+QlZXFhRdeyLHHHktubi7du3ffXaahdu3aDBs2jA4dOpCZmcm6desA+O6778jMzOToo4/mjjvuoHbt2iXG8eKLL3LMMcfQoUMHBg8eDMAll1zC1KlTd28Tfu2cOXM46aST6NevH23btuXWW29l1KhRu7e78847eeihhwAYMWIEnTt35phjjmH48OHR/wGKyD5t3QrPPw8DB8If/hCbYyRWE++66+CLL6K7z2OPhcce2+fTo0eP5q233uK9996jcePGjB8/HoATTjiBfv360bdvX84991wAZsyYwUMPPURGRsbv9rNt2zYyMzO59957ufnmm3n22We54447uPbaa7n22mv5y1/+wujRo0uM4ZtvvuGee+5h3rx5NG7cmE2bNh3wbS1cuJBFixbRsmVLPv/8c6677jquvtqPap08eTJvv/02M2fOZPny5Xz66ac45+jXrx9z587l5JNPPuD+RSRyEybAli1wdVQGnJdMLfkKUq1aNfr27QtAp06d+OGHHwD4+OOPGThwIAAXXHBBia999913GThwII0bNwagYcOGBzxely5ddo9t79ixI+vXr+enn37iyy+/pEGDBrRo0YKZM2cyc+ZMOnbsyHHHHcfSpUtZvnx5pG9VRErBOXjySejYEbp2jd1xEqslv58Wd7xLTU3dPRwxJSWF/Pz8iPdZtWpVCgsLASgsLGTXrl27n6tVq9Ye2w4cOJCpU6fy888/c/755wP+YqfbbruNq666KuJYRKRs5s6FRYtg7FiI5UhlteQjUKdOHX777bd
93i+NzMxMpk2bBsCkSZNK3KZHjx5MmTKFjRs3AuzurklPT2fBggUATJ8+nby8vH0e5/zzz2fSpElMnTp19zeH3r17M27cOLZu3QrAmjVrWL9+/T73ISLR8+ST0KABDBoU2+MoyUdg0KBBjBgxgo4dO/Ldd99xySWXMHTo0N0nXkvjscce45FHHuGYY45hxYoV1KtX73fbtGvXjmHDhtGtWzc6dOjADTfcAMCVV17J+++/T4cOHfj4449/13rfex+//fYbhxxyCM2aNQOgV69eXHDBBXTt2pWjjz6ac889t8wfUiJSdtnZ8OqrcPnlULNmbI9VppmhYi0jI8PtPWnIkiVLaNOmTUARxd727dupUaMGZsakSZOYOHEir7/++oFfWIGS/XcgUtH+9S+45x5YsQIOPzzy/ZnZAufc70d8kGh98klowYIFXHPNNTjnqF+/PuPGjQs6JBGJoZ074Zln4IwzopPgD0RJPmAnnXQSX375ZdBhiEgFmTYN1q+P7bDJ4tQnLyJSgZ58Eo44Anr1qpjjKcmLiFSQzz6Djz/2rfgqFZR9leRFRCrII49A3bpw2WUVd0wleRGRCrBqFUyZ4guR1a1bccdVkq8A6enpbNiwIegwRCRAI0f69T/+UbHHVZIvA+fc7jICikNESmvLFnj2WV9t8tBDK/bYSvIH8MMPP9C6dWsuuugi2rdvz+rVq/dZnvess86iU6dOtGvXjjFjxhxw32+99RbHHXccHTp0oGdPPx968TLAAO3bt+eHH374XRx33303N9100+7twiWQAV566SW6dOnCsccey1VXXUVBQUG0fhwiUg5jx/pEf+ONFX/sqIyTN7PrgSsAB3wNXAo0AyYBjYAFwGDn3K597qQUAqg0DMDy5ct54YUXyMzM3G953nHjxtGwYUNyc3Pp3Lkz55xzDo0aNSpxnzk5OVx55ZXMnTuXli1blqp8cPE4cnJy6Nq1KyNGjADg5ZdfZtiwYSxZsoSXX36Zjz76iNTUVP7+978zYcIELrroojL+ZEQkGvLz4fHH4aSToIQq5DEXcZI3s0OAfwBtnXO5ZjYZGAScDjzqnJtkZqOBy4GnIz1eEA477DAyMzMB9ijPC7B161aWL1/OySefzMiRI3n11VcBWL16NcuXL99nkp8/fz4nn3zy7nLApSkfXDyOJk2acPjhhzN//nxatWrF0qVLOfHEExk1ahQLFiygc+fOAOTm5nLQQQdF9gMQkXJ75RX48Uef6IMQrSteqwI1zCwPqAmsBXoA4QLpLwB3EmGSD6rScPHCX/sqzztnzhxmz57Nxx9/TM2aNenevTs7duwo87GKlw8G9tjH3gXIBg0axOTJkznqqKMYMGAAZoZzjosvvpj77ruvzMcWkehyDh5+2F/8FJpOosJF3CfvnFsDPASswif3zfjumV+dc+Gi6dnAISW93syGmFmWmWXl5OREGk7M7as87+bNm2nQoAE1a9Zk6dKlzJ8/f7/7yczMZO7cuXz//ffAnuWDFy5cCPjZncLPl2TAgAG8/vrrTJw4kUGheqU9e/Zk6tSpu0sGb9q0iR9//DGyNy0i5TJvHnz6KVx/PaSkBBNDNLprGgD9gZbAr8AUoE9pX++cGwOMAV+FMtJ4Yq1Xr14sWbKErqGpXGrXrs1LL71Enz59GD16NG3atKF169a7u1X2pUmTJowZM4azzz6bwsJCDjroIGbNmsU555zDiy++SLt27Tj++OM58sgj97mPBg0a0KZNGxYvXkyXLl0AaNu2Lffccw+9evWisLCQ1NRURo0axWGHHRa9H4KIlMojj/ia8RdfHFwMEZcaNrOBQB/n3OWh+xcBXYGBwB+cc/lm1hW40znXe3/7qoylhhOBfgciZbdiBRx5JNx6K/zP/8T2WPsrNRyNIZSrgEwzq2l+fruewGLgPeDc0DYXA/FVJF1EJIYefBCqV4drrw02jmj0yX8CTAUW4odPVsF3v9wC3GBmK/DDKMdGeiwRkUS
wZg2MH+9r1DRtGmwsURld45wbDgzf6+GVQJco7X/3JNhSseJp5jCRRPHww1BYCMWuVwxM3F/xmpaWxsaNG5VsAuCcY+PGjaSlpQUdikjC2LDBz/x04YWQnh50NAkwM1Tz5s3Jzs4mEYZXJqO0tDSaN28edBgiCWPkSMjN9Sdc40HcJ/nU1NTdV4WKiMSzLVvgiSdgwACIlwFpcd9dIyKSKEaPhl9/hdtuCzqSIkryIiJRkJvrL37q1SuYQmT7oiQvIhIFzz8P69bFVyselORFRCKWl+cvfuraFbp1CzqaPcX9iVcRkXj34ou+nPCTT0K8XdKjlryISAR27YK77/b98GecEXQ0v6eWvIhIBJ5/3rfin346/lrxoJa8iEi57dwJ99wDmZnQp9QF1iuWWvIiIuX03HOQne1b8/HYige15EVEyiU319eJP+kk6Nkz6Gj2TS15EZFyGDMGfvoJJkyI31Y8qCUvIlJm27fDfffBn/8M3bsHHc3+qSUvIlJGTz/tr26dMiXoSA5MLXkRkTLYuhUeeABOPdX3x8c7JXkRkTJ49FHIyYG77go6ktJRkhcRKaX1632NmgED/Nj4RKAkLyJSSnff7YdO3ndf0JGUnpK8iEgprFjhJwW54gpo3TroaEovKknezOqb2VQzW2pmS8ysq5k1NLNZZrY8tG4QjWOJiARh2DCoVg2GDw86krKJVkv+ceAt59xRQAdgCXAr8I5zrhXwTui+iEjC+ewzmDwZ/vlPaNYs6GjKJuIkb2b1gJOBsQDOuV3OuV+B/sALoc1eAM6K9FgiIhXNObj5ZjjoIJ/kE000WvItgRzgeTP73MyeM7NaQFPn3NrQNj8DTUt6sZkNMbMsM8vKycmJQjgiItEzYwbMmQP/+hfUqRN0NGUXjSRfFTgOeNo51xHYxl5dM845B7iSXuycG+Ocy3DOZTRp0iQK4YiIREdBAdxyCxxxBAwZEnQ05RONJJ8NZDvnPgndn4pP+uvMrBlAaL0+CscSEakw48fDokW+2mRqatDRlE/ESd459zOw2szCg4p6AouB6cDFoccuBl6P9FgiIhVl82a4/XY48UQ499ygoym/aBUo+y9ggplVA1YCl+I/QCab2eXAj8B5UTqWiEjM3XWXL18wY0Z8lxI+kKgkeefcF0BGCU/FcSl9EZGSLV0KI0fC5ZfDcccFHU1kdMWriEgxzsF110GtWnDvvUFHEznVkxcRKeaNN+Dtt321yYMOCjqayKklLyISsnMnXH89tGkDV18ddDTRoZa8iEjIY4/Bd9/5lnyiDpncm1ryIiLA2rVwzz3Qrx/06hV0NNGjJC8iAtxwA+zaBQ8/HHQk0aUkLyKV3ltvwaRJvpzwEUcEHU10KcmLSKW2fTv8/e9+IpBbbgk6mujTiVcRqdTuugu+/95XmqxePehook8teRGptL7+2vfBX3YZdOsWdDSxoSQvIpVSYaEvH1y/Pjz4YNDRxI66a0SkUhozBubPhxdfhEaNgo4mdtSSF5FKZ+1auPVW6NkT/vrXoKOJLSV5EalUnPOjaXbsgKefTuwywqWh7hoRqVT+7//gtddgxAho1SroaGJPLXkRqTR++gmuuQZOOMEXIqsMlORFpFJwDq680leaHD8eUlKCjqhiqLtGRCqF8ePhzTfh8ccrRzdNmFryIpL0Vq/2sz117+67ayoTJXkRSWrO+blaCwpg3DioUsmynrprRCSpjR4Ns2b54ZItWwYdTcWL2meamaWY2edm9kbofksz+8TMVpjZy2ZWLVrHEhEpjUWLfJ343r3hqquCjiYY0fzici2wpNj9B4BHnXNHAL8Al0fxWCIi+5WbC4MGQb168MILyX/R075EJcmbWXPgDOC50H0DegBTQ5u8AJwVjWOJiJTGDTfAN9/42jRNmwYdTXCi1ZJ/DLgZKAzdbwT86pzLD93PBg4p6YVmNsTMsswsKycnJ0rhiEhlNm2a74u/6abkmq+1PCJO8mbWF1jvnFtQntc
758Y45zKccxlNmjSJNBwRqeRWrYIrroDOnf3E3JVdNEbXnAj0M7PTgTSgLvA4UN/MqoZa882BNVE4lojIPuXnwwUX+OGSEydCNQ33iLwl75y7zTnX3DmXDgwC3nXOXQi8B5wb2uxi4PVIjyUisj933gkffeSHS/7xj0FHEx9ieVnALcANZrYC30c/NobHEpFKbvp0uPdeuPRSuPDCoKOJH1G9GMo5NweYE7q9EugSzf2LiJTk229h8GDo1AmeeiroaOJLJbvAV0SSzW+/wdln+/73V16BtLSgI4ovKmsgIgnLObjsMli61JcuOPTQoCOKP0ryIpKwHnoIpk71szz16BF0NPFJ3TUikpDeecdPxn3eeXDjjUFHE7+U5EUk4SxdCueeC23awNixlbcuTWkoyYtIQsnJgTPO8Cda33gDatcOOqL4pj55EUkYubnQv7+fkPv99yE9PeiI4p+SvIgkhMJCuOQSmD8fpkyBLroKp1SU5EUkIdxxB0ye7EfSnHNO0NEkDvXJi0jce+45uO8+P7uTRtKUjZK8iMS1V17xyb13b3jiCY2kKSsleRGJWzNn+in8jj/eTwSSmhp0RIlHSV5E4tJHH8FZZ0HbtvCf/0CtWkFHlJiU5EUk7nzxhR8L37w5vP02NGgQdESJS0leROLKsmV+Xta6dWH27Mo9CXc0KMmLSNxYtqyo0JiqSkaHkryIxIXFi6FbN9i1y7fgW7cOOqLkoCQvIoH76ivo3t3fnjMHjjkmyGiSi5K8iARq4UL48599wbH33/ejaSR6lORFJDCffOL74OvUgblz4cgjg44o+USc5M2shZm9Z2aLzewbM7s29HhDM5tlZstDaw2CEpHd3n4bTjkFGjf2Cf7ww4OOKDlFoyWfD9zonGsLZAJXm1lb4FbgHedcK+Cd0H0REcaN8+PgjzgCPvhAo2hiKeIk75xb65xbGLr9G7AEOAToD7wQ2uwF4KxIjyUiic05uPNOuPxy34qfOxeaNQs6quQW1VLDZpYOdAQ+AZo659aGnvoZ0CUNIpVYXp4vNPb883DppfDMM6pFUxGiduLVzGoD04DrnHNbij/nnHOA28frhphZlpll5eTkRCscEYkjv/wCffv6BD98uJ+XVQm+YkQlyZtZKj7BT3DOvRJ6eJ2ZNQs93wxYX9JrnXNjnHMZzrmMJk2aRCMcEYkjixZB587w3nu+Lvydd6pccEWKxugaA8YCS5xzjxR7ajpwcej2xcDrkR5LRBLLlCmQmQnbtvmLnC6/POiIKp9otORPBAYDPczsi9ByOnA/cKqZLQdOCd0XkUqgoABuvRXOO89fvbpgAZxwQtBRVU4Rn3h1zn0I7OvLV89I9y8iiWXdOhg82BcYGzoUHn/cX80qwdBE3iISNW++6UfObNni+9/VPRM8lTUQkYjt2AH/+Ie/wKlpU8jKUoKPF0ryIhKRRYugSxc/yfa118Knn0K7dkFHJWFK8iJSLnl5cN99kJHh++FnzIDHHoO0tKAjk+LUJy8iZfbpp3Dllb4O/LnnwpNPapq+eKWWvIiU2tatcP310LUrbNwIr73mx8IrwccvJXkROSDnYOpUaN/ed8kMHeqn6+vfP+jI5EDUXSMi+5WV5VvvH34IRx/t1yeeGHRUUlpqyYtIibKz4aKLfN2ZZctgzBj4/HMl+ESjlryI7GH9ehgxAkaNgsJCX57gttugbt2gI5PyUJIXEQB+/tkn96efhp074YIL4O67IT096MgkEkryIpXc6tXw6KMwerRP7n/9Kwwbpkm1k4WSvEgl5BzMmwcjR8K0af6xcHJv1SrY2CS6lORFKpEdO/xQyMcf96Nm6tf3I2euvlrdMslKSV4kyTnnr1B94QWYOBF+/RWOOgqeesqPnqlVK+gIJZaU5EWS1A8/wKRJPrkvXQo1asCAAXDJJdCzJ1TRAOpKQUleJEk4569CfeUVePVVP6Yd4E9/8rXdBw7UMMjKSEleJIFt3Qpz58Ls2fDGG7B
8uX+8a1c/HPLss+Hww4ONUYKlJC+SQLZt8ydM33/fJ/aPP4b8fKheHbp1gxtu8PVkmjULOlKJF0ryInEqL8+XE8jKgvnz/fL1136SbDPo1An++U845RQ/SXaNGkFHLPFISV4kYLm5sHIlfPcdLFniE/nXX/vbeXl+m3r14Pjj/Tj2zEx/u2HDYOOWxBDzJG9mfYDHgRTgOefc/bE+pkg8cM73mW/c6EsG/PSTX9as8esffoAVK/zt4lq08NUeTzvNrzt29EMeNRpGyiOmSd7MUoBRwKlANvCZmU13zi2O5XGl4mzb5hPYunWwaZMfg1182boVtm8vWnJz/QU5+fm+lZqX528XFPj9Obfn/lNSipYqVaBqVb+kphatw0u1anuuS7odfn3xpUqVPRczH0dhYdG6sNDHunMn7Nrll507/fvfts2/z/Dyyy/+Z7FpU1FLvLiqVX2f+WGHwamnwhFHwB//6Jcjj/QXKIlES6xb8l2AFc65lQBmNgnoD0Q3yWdnw0cf+SZQixb+P6iqeqIiVVAAP/4I33/vW53Fl7VrfWLfunXfr69ZE+rU8esaNYrWder8PkGHkysUrZ3zMRQU+CRbUOA/EMJLXp7/wNiypegDY9euktfhJVpSUhzVUh21ahRSu0ZB0TqtgLZN82l0ZB4N6+TTqG4eDevm84eGuzi48S4ObpJH4/r5VEmxok+V8CdYSgqsrQob9voEq1bNn1mtXl3NeSmzWGfCQ4DVxe5nA8dH/SgffOBL5oVVqeITffPmft2sGRx8cNHtgw4qWjTrMM7BqlWwYAF8843vC168GL791ifRsJQU/xl62GHQpQv84Q9+2rfw0rixb4XWr+/7kKtVi0JgeXlFXwH2/koQXu9r2bGjaL1jB257LoU788jPzSNvRwH5O/LJ35GP25VH4a78oiWvgCq7dlCFAqpQiOGoQiHV2EU1dpFKHikFhVAA7AB+ifB9lkXVqv5vNi1tz0/OGjX8pau1a/ulTp2i2/Xq7bnUrw8NGvhO/bp1iz5VJSkF3tw1syHAEIBDDz20fDvp39+fqVq92rfqV68uur18uR9IvGlTya+tU8cn+0aN/NKwYdG6QYM9s1b9+v6fIvwPVL16Qv6DbNrkv/h89plfsrJgw4ai59PToU0bP2qjTRvfjZCe7j8zq1bFN6P3lVhX5cLSvZLw3kl674S9v8fD/ThlVb26T3zhZJiWhlWvTkpaGilpaVSvHUqU4RZyuLVcrVrR7b37e8J9ROFWdvG+pOJ9SuGvJcUX8B9aey/Fv6aEv6oU78cKfx3ZuXPPJfwBVvxnt3Vr0der8JKbu/+fU0qK/7tu2BCaNNlzOegg/0lefKlXLyH/5iszc3t3gkZz52ZdgTudc71D928DcM7dV9L2GRkZLisrKzbB7NjhO4/XrvWzIuTk+HV42bjRL5s2+fWWLQfeZ9WqRQm/Zs2ipVYtn1iqVy9qdYWTyd6dxeGksXeyCCeH4gkj/Lsqvg4niXCiKCws6s8IJY1NW6oyd2Vz5qw8lDmrWvJVTjMcVUixAtrVyyaj/ndk1PmWTjWX0K7acmrlby5KJOF18aW8iTclZc+fT/GfWbglWtJjxbct3not3ootvqSlqVsjLD/f/y1v3ly0/PrrnicOwn/zGzYU/W9s2FDy7zktDQ45xH/ih9fNm8Ohh/qveOnpOqkQADNb4JzLKPG5GCf5qsAyoCewBvgMuMA5901J28c0yZdVXt6e/xTh9ZYt8Ntvey57tz63bft9YszNLeogzs+PaegO+Jqj+Tdn8gZ9+YTjcVQhjVxOTJlP99R5dEv7hE41l1CzhitqvVartucHUvh28e6BvW8XT6wlJd/iyTk1NabvW6KosNB/EKxb5xtH4WXtWj88KDvbL2vW+L/r4urW9Qn/8MOLziiHl/R0nS+LgcCSfOjgpwOP4YdQjnPO3buvbeMqycdSuK+5+NCS8Dq8FP9KHx7mUfyrf3gdavkXUoUPPq3OlP/
U4N8zq7MqOwWAjE6F9D0Dep5ahc6dfc4WiZrCQt/yX7XKn6X/4Yei9cqVfil+Yic11Sf71q2LljZtoF07FdaJQKBJviwqTZKPoq++ggkTfAnZ1at9g/mUU+DMM+GMM/z5ZpHAFBb61v933/mLApYt82f0v/3W3y8+5KlFC5/s27XzFwgce6z/AIj4DH7y21+S1/emBPTLL7587NixsGiR//bbuzfcf78/B6364BI3qlTxffeHHAInn7znc/n5vsW/eLEf1hVe3nvPnwsC3/Jv08Yn/I4dISPD365du4LfSOJSSz6BLFjgJ3qYONF38WdmwuDBvoRskyZBRycSJQUFflTcl1/CF18Urdeu9c+b+cSfkQGdO/t/hA4dKvU5H3XXJLD8fJgyxU+0/NlnvpX+17/C3/7m/65FKo21a31LZ8ECP+73s8/8iWHwJ/4zMnzCP+EEX0S/ErV8lOQT0M6dvkvmgQf8uaujjvLzcA4e7Icqi1R6zvkTUeESnfPn+w+A8Gif1q3hpJP80q2bH/GTpJTkE8jWrTBmDDz8sC9c1bkz3H479Ounod8iB7Rzp0/0H3zgl48+8kOfwQ/p7NHDL3/+s7+4K0koySeAvDw/Rdt//7f/Btqjh0/uPXroAkORciss9KMT5szxJ3TnzClK+m3b+hELvXv7k8IJXJBfST6OOefn5Lz9dj+67KST/CiZE04IOjKRJFRQ4E/ivvsuzJrlS57s3On79E8+Gfr0gb59oVWroCMtEyX5ODVvHtx4o+9KbNvWJ/e+fdVyF6kw27f7RP/WW/D227B0qX/8yCP9xSZ9+8KJJ8b9yB0l+TizYQPccguMG+cvVrrrLrj4Yl3tLRK477+H//wH/v1v37Wza5evxXPmmX5W9F69fJmOOKMkHycKC+H5532C37wZrr8e/vUvXdchEpd++8136Uyf7pdffvH99qed5hP+mWfGTSkGJfk4sGgRDB3qT/b/6U/w9NPQvn3QUYlIqeTl+W6dV16B117zQ9+qV4fTT4dBg3wNkQAvNd9fktegvBgrKIAHH4ROnXx339ix8P77SvAiCSU1FXr2hFGj/Nj8jz6Cq67yJ9TOP9/X3h80yHfzRHMKsihQSz6GVq70fe0ffggDBsAzz1Sqi/BEkl9BgR+P//LLMHWqP+HWpIlP+IMH+6twK2AkhVryFcw5P+a9QwdfJfLFF2HaNCV4kaSTkgLdu/v+159+8n333bv7Kxq7dPHD5h54wNfiD4iSfJT98gucdRZceaX/HX/9tf9A17BIkSSXmupPxk6e7JP6s8/6iY9vvdXPntW/v/8QiPGkQXtTko+ihQt93/uMGb6g2KxZflY0Ealk6teHK67wXTnffgv//Cd8+qlP9C1awLBhfnKVCqAkHwXO+W9nJ5zgP6Q/+ACuu061ZkQEf2HV/ff72bNef90XpLr/fl9L58wz4c03yz9vcikoDUVo+3a45BJ/or1bN9+aP/74oKMSkbiTmuorDU6f7i+6uu02Xy75jDPgiCP8ybsYUJKPwKpV0LUr/O//wvDh/gO5ceOgoxKRuHfooXDPPT6JvPyyn+B8+/aYHEoX0pdTVpb/prV9u78K+rTTgo5IRBJOtWpw3nl+idFwdrXky+GVV3zBuurVfZExJXgRiViMhuBFlOTNbISZLTWzr8zsVTOrX+y528xshZl9a2a9I440Djjnr1495xw/Bv6TT/zE8iIi8SrSlvwsoL1z7hhgGXAbgJm1BQYB7YA+wFNmlhLhsQKVn+9Prt5yi/9m9e670LRp0FGJiOxfREneOTfTORce2T8faB663R+Y5Jzb6Zz7HlgBdInkWEHaudNfpfzss35yj4kTE3oSGRGpRKLZJ38ZMCN0+xBgdbHnskOP/Y6ZDTGzLDPLysnJiWI40bF9u79+Ydo0f4HTvfdq/LuIJI4Djq4xs9lASTPeDnPOvR7aZhiQD0woawDOuTHAGPAFysr6+ljavNlPDDNvnq8eedllQUc
kIlI2B0zyzrlT9ve8mV0C9AV6uqKSlmuAFsU2ax56LGHk5PjpHr/+GiZNgoEDg45IRKTsIh1d0we4GejnnCs+kn86MMjMqptZS6AV8Gkkx6pI69b5q1cXL/ZXISvBi0iiivRiqCeB6sAs82M85zvnhjrnvjGzycBifDfO1c652BVniKING+CUU3ztoLfe8sleRCRRRZTknXNH7Oe5e4F7I9l/Rfv1Vz9P74oV/ipWJXgRSXQqaxCyZYvvg//mG99F06NH0BGJiEROSR7Yts0XgluwwM/g1adP0BGJiERHpU/yO3b46p/z5vlRNP37Bx2RiEj0VOokX1joJ9p+911fylmjaEQk2VTqazdvuslPxzhihJ+HVUQk2VTaJP/YY/DII/Bf/wU33hh0NCIisVEpk/yUKXDDDXD22b4eTYzKOIuIBK7SJfkPPvBdMyecAC+9BCkJXQBZRGT/KlWSX7bMj6RJT/dj4VUuWESSXaVJ8ps3++GRVavCjBnQqFHQEYmIxF6lGEJZUAAXXujLFcyeDS1bBh2RiEjFqBRJ/v/9P1+L5qmnVI9GRCqXpO+ueflluO8+GDIEhg4NOhoRkYqV1En+88/h0kvhxBPhiSc0VFJEKp+kTfI5OXDWWf4E67RpUK1a0BGJiFS8pOyTLyz0Y+HXrYMPP4SmTYOOSEQkGEmZ5B98EN5+259ozcgIOhoRkeAkXXfNhx/CHXfAeefpRKuISFIl+Q0bYNAgf0Xrs8/qRKuISNJ01xQWwkUX+ROu8+dD3bpBRyQiErykSfIjRvhyBaNGQceOQUcjIhIfotJdY2Y3mpkzs8ah+2ZmI81shZl9ZWbHReM4+/LRRzBsmJ/Z6W9/i+WRREQSS8RJ3sxaAL2AVcUePg1oFVqGAE9Hepz9qVkTTjlF/fAiInuLRkv+UeBmwBV7rD/wovPmA/XNrFkUjlWijh3hrbegXr1YHUFEJDFFlOTNrD+wxjn35V5PHQKsLnY/O/RYSfsYYmZZZpaVk5MTSTgiIrKXA554NbPZwB9KeGoYcDu+q6bcnHNjgDEAGRkZ7gCbi4hIGRwwyTvnTinpcTM7GmgJfGm+I7w5sNDMugBrgBbFNm8eekxERCpQubtrnHNfO+cOcs6lO+fS8V0yxznnfgamAxeFRtlkApudc2ujE7KIiJRWrMbJvwmcDqwAtgOXxug4IiKyH1FL8qHWfPi2A66O1r5FRKR8kqp2jYiI7ElJXkQkiZnvWYkPZpYD/FjOlzcGNkQxnCDpvcSnZHkvyfI+QO8l7DDnXJOSnoirJB8JM8tyziXFFCF6L/EpWd5LsrwP0HspDXXXiIgkMSV5EZEklkxJfkzQAUSR3kt8Spb3kizvA/ReDihp+uRFROT3kqklLyIie0mqJG9md4dmovrCzGaa2cFBx1ReZjbCzJaG3s+rZlY/6JjKy8wGmtk3ZlZoZgk3EsLM+pjZt6GZzm4NOp7yMrNxZrbezBYFHUukzKyFmb1nZotDf1vXBh1TeZhZmpl9amZfht7Hf0f9GMnUXWNmdZ1zW0K3/wG0dc4NDTiscjGzXsC7zrl8M3sAwDl3S8BhlYuZtQEKgWeAfzrnsgIOqdTMLAVYBpyKL8L3GfAX59ziQAMrBzM7GdiKn9CnfdDxRCI0CVEz59xCM6sDLADOSrTfi/kSvrWcc1vNLBX4ELg2NNlSVCRVSz6c4ENqsedsVQnFOTfTOZcfujsfX645ITnnljjnvg06jnLqAqxwzq10zu0CJuFnPks4zrm5wKag44gG59xa59zC0O3fgCXsY2KieBaaPW9r6G5qaIlq3kqqJA9gZvea2WrgQuBfQccTJZcBM4IOopIq9SxnEgwzSwc6Ap8EHEq5mFmKmX0BrAdmOeei+j4SLsmb2WwzW1TC0h/AOTfMOdcCmABcE2y0+3eg9xLaZhiQj38/cas070Uk2sysNjANuG6vb/IJwzlX4Jw7Fv9tvYuZRbUrLVb15GNmXzNVlWACvq7
98BiGE5EDvRczuwToC/R0cX7ypAy/l0SjWc7iVKgPexowwTn3StDxRMo596uZvQf0AaJ2cjzhWvL7Y2atit3tDywNKpZImVkf4Gagn3Nue9DxVGKfAa3MrKWZVQMG4Wc+kwCFTliOBZY45x4JOp7yMrMm4ZFzZlYDf4I/qnkr2UbXTANa40dy/AgMdc4lZKvLzFYA1YGNoYfmJ/BIoQHAE0AT4FfgC+dc70CDKgMzOx14DEgBxjnn7g02ovIxs4lAd3y1w3XAcOfc2ECDKicz+xPwAfA1/v8d4Hbn3JvBRVV2ZnYM8AL+b6sKMNk5d1dUj5FMSV5ERPaUVN01IiKyJyV5EZEkpiQvIpLElORFRJKYkryISBJTkhcRSWJK8iIiSUxJXkQkif1/Z7y6+/bBNA8AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出更新之前的模型\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以发现,这两条曲线之间存在差异,我们计算一下他们之间的误差" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(1144.2655, grad_fn=)\n" ] } ], "source": [ "# 计算误差,这里的误差和一元的线性模型的误差是相同的,前面已经定义过了 get_loss\n", "loss = get_loss(y_pred, y_train)\n", "print(loss)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# 自动求导\n", "loss.backward()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ -94.7455],\n", " [-139.1247],\n", " [-629.8584]])\n", "tensor([-25.7413])\n" ] } ], "source": [ "# 查看一下 w 和 b 的梯度\n", "print(w.grad)\n", "print(b.grad)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# 更新一下参数\n", "w.data = w.data - 0.001 * w.grad.data\n", "b.data = b.data - 0.001 * b.grad.data" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAApLElEQVR4nO3de5yV4/7/8denaWo6n0tEE0IHElMmNqVI6FuiCJuctbEdt2PsfB2+DiFySpTy0y4dHNo2SUibClMOpaJEGqVGqVRTzeH6/XGt1UyZqZlZa+Zea837+Xhcj3W+789qNZ91reu+7s9lzjlERCQxVQk6ABERKT9K8iIiCUxJXkQkgSnJi4gkMCV5EZEEpiQvIpLASpzkzWyMma0zs0WF7mtoZu+b2bLQZYPQ/WZmI8xsuZl9Y2bHlEfwIiKyd6XpyY8Feu1x3x3AB8651sAHodsApwOtQ+0q4PnIwhQRkbKw0pwMZWapwNvOufah298B3Zxza8ysOTDLOXe4mb0Quj5hz+ftbfuNGzd2qampZXsnIiKV1Pz5839zzjUp6rGqEW67WaHE/SvQLHT9AGBVoedlhu7ba5JPTU0lIyMjwpBERCoXM1tZ3GNRO/Dq/E+CUtdIMLOrzCzDzDKysrKiFY6IiBB5kl8bGqYhdLkudP8vwIGFntcidN+fOOdGOefSnHNpTZoU+WtDRETKKNIkPw0YFLo+CHir0P0Xh2bZpAOb9jUeLyIi0VfiMXkzmwB0AxqbWSYwFHgYmGRmlwMrgXNDT38HOANYDmwDLi1rgDk5OWRmZrJ9+/aybkIikJKSQosWLUhOTg46FBEpgxIneefc+cU81KOI5zrg2rIGVVhmZiZ16tQhNTUVM4vGJqWEnHOsX7+ezMxMWrVqFXQ4IlIGMX/G6/bt22nUqJESfADMjEaNGulXlEgci/kkDyjBB0j/9iLxLS6SvIhIIrvvPpg9u3y2rSRfAiNGjKBNmzZceOGFTJs2jYcffhiAN998k8WLF+963tixY1m9evWu21dcccVuj4uI7On772HoUPj44/LZfqRnvFYKzz33HDNnzqRFixYA9OnTB/BJvnfv3rRt2xbwSb59+/bsv//+ALz00kvBBFxIbm4uVavqYxaJVc89B8nJcOWV5bN99eT3YfDgwaxYsYLTTz+d4cOHM3bsWK677jrmzJnDtGnTuPXWWzn66KN55JFHyMjI4MILL+Too48mOzubbt267SrTULt2bYYMGUKHDh1IT09n7dq1APzwww+kp6dz5JFHcvfdd1O7du0i43jllVc46qij6NChAxdddBEAl1xyCVOmTNn1nPBrZ82axYknnkifPn1o27Ytd9xxB88+++yu591777089thjAAwbNoxOnTpx1FFHMXTo0Oj/A4pIsbZsgZdfhgEDYL/9ymcf8dXFu/FG+Oqr6G7z6KPhySeLfXjkyJFMnz6djz76iMaNGzN27FgAjj/+ePr06UPv3r3p378/AO+++y6PPfYYaWlpf9rO1q1bSU9P58EHH+S2227jxRdf5O677+aGG27ghhtu4Pzzz2fkyJFFxvDtt9/ywAMPMGfOHBo3bsyGDRv2+bYWLFjAokWLaNWqFV9++SU33ngj117rZ7VOmjSJ9957jxkzZrBs2TI+//xznHP06dOH2bNnc9JJJ+1z+yISufHjYfNmuDYqE86Lpp58BalWrRq9e/cG4Nhjj+Wnn34CYO7cuQwYMACACy64oMjXfvjhhwwYMIDGjRsD0LBhw33ur3Pnzrvmtnfs2JF169axevVqvv76axo0aMCBBx7IjBkzmDFjBh07duSYY45h6dKlLFu2LNK3KiIl4Bw88wx07AhdupTffuKrJ7+XHnesS05O3jUdMSkpidzc3Ii3WbVqVfLz8wHIz89n586dux6rVavWbs8dMGAAU6ZM4ddff+W8884D/MlOd955J1dffXXEsYhI6cyeDYsWwejRUJ4zldWTj0CdOnX4448/ir1
dEunp6UydOhWAiRMnFvmc7t27M3nyZNavXw+wa7gmNTWV+fPnAzBt2jRycnKK3c95553HxIkTmTJlyq5fDqeddhpjxoxhy5YtAPzyyy+sW7eu2G2ISPQ88ww0aAADB5bvfpTkIzBw4ECGDRtGx44d+eGHH7jkkksYPHjwrgOvJfHkk0/yxBNPcNRRR7F8+XLq1av3p+e0a9eOIUOG0LVrVzp06MDNN98MwJVXXsnHH39Mhw4dmDt37p9673tu448//uCAAw6gefPmAPTs2ZMLLriALl26cOSRR9K/f/9Sf0mJSOllZsIbb8Dll0PNmuW7r1KtDFXe0tLS3J6LhixZsoQ2bdoEFFH527ZtGzVq1MDMmDhxIhMmTOCtt97a9wsrUKJ/BiIV7Z//hAcegOXL4eCDI9+emc13zv15xgfxNiafgObPn891112Hc4769eszZsyYoEMSkXK0Ywe88AKceWZ0Evy+KMkH7MQTT+Trr78OOgwRqSBTp8K6deU7bbIwjcmLiFSgZ56BQw+Fnj0rZn9K8iIiFeSLL2DuXN+Lr1JB2VdJXkSkgjzxBNStC5ddVnH7VJIXEakAP/8Mkyf7QmR161bcfpXkK0Bqaiq//fZb0GGISIBGjPCX119fsftVki8F59yuMgKKQ0RKavNmePFFX23yoIMqdt9K8vvw008/cfjhh3PxxRfTvn17Vq1aVWx53rPOOotjjz2Wdu3aMWrUqH1ue/r06RxzzDF06NCBHj38euiFywADtG/fnp9++ulPcdx///3ceuutu54XLoEM8Oqrr9K5c2eOPvporr76avLy8qL1zyEiZTB6tE/0t9xS8fuOyjx5M7sJuAJwwELgUqA5MBFoBMwHLnLO7Sx2IyUQQKVhAJYtW8a4ceNIT0/fa3neMWPG0LBhQ7Kzs+nUqRPnnHMOjRo1KnKbWVlZXHnllcyePZtWrVqVqHxw4TiysrLo0qULw4YNA+C1115jyJAhLFmyhNdee41PP/2U5ORkrrnmGsaPH8/FF19cyn8ZEYmG3Fx46ik48UQoogp5uYs4yZvZAcD1QFvnXLaZTQIGAmcAw51zE81sJHA58Hyk+wtCy5YtSU9PB9itPC/Ali1bWLZsGSeddBIjRozgjTfeAGDVqlUsW7as2CQ/b948TjrppF3lgEtSPrhwHE2aNOHggw9m3rx5tG7dmqVLl3LCCSfw7LPPMn/+fDp16gRAdnY2TZs2jewfQETK7PXXYeVKn+iDEK0zXqsCNcwsB6gJrAG6A+EC6eOAe4kwyQdVabhw4a/iyvPOmjWLmTNnMnfuXGrWrEm3bt3Yvn17qfdVuHwwsNs29ixANnDgQCZNmsQRRxxBv379MDOccwwaNIiHHnqo1PsWkehyDh5/3J/8FFpOosJFPCbvnPsFeAz4GZ/cN+GHZzY658JF0zOBA4p6vZldZWYZZpaRlZUVaTjlrrjyvJs2baJBgwbUrFmTpUuXMm/evL1uJz09ndmzZ/Pjjz8Cu5cPXrBgAeBXdwo/XpR+/frx1ltvMWHCBAaG6pX26NGDKVOm7CoZvGHDBlauXBnZmxaRMpkzBz7/HG66CZKSgokhGsM1DYC+QCtgIzAZ6FXS1zvnRgGjwFehjDSe8tazZ0+WLFlCl9BSLrVr1+bVV1+lV69ejBw5kjZt2nD44YfvGlYpTpMmTRg1ahRnn302+fn5NG3alPfff59zzjmHV155hXbt2nHcccdx2GGHFbuNBg0a0KZNGxYvXkznzp0BaNu2LQ888AA9e/YkPz+f5ORknn32WVq2bBm9fwQRKZEnnvA14wcNCi6GiEsNm9kAoJdz7vLQ7YuBLsAAYD/nXK6ZdQHudc6dtrdtVcZSw/FAn4FI6S1fDocdBnfcAf/3f+W7r72VGo7GFMqfgXQzq2l+fbsewGLgI6B/6DmDgNgqki4iUo4efRSqV4cbbgg2jmiMyX8GTAEW4KdPVsEPv9wO3Gxmy/HTKEdHui8RkXjwyy8
wdqyvUdOsWbCxRGV2jXNuKDB0j7tXAJ2jtP1di2BLxYqllcNE4sXjj0N+PhQ6XzEwMX/Ga0pKCuvXr1eyCYBzjvXr15OSkhJ0KCJx47ff/MpPF14IqalBRxMHK0O1aNGCzMxM4mF6ZSJKSUmhRYsWQYchEjdGjIDsbH/ANRbEfJJPTk7edVaoiEgs27wZnn4a+vWDWJmQFvPDNSIi8WLkSNi4Ee68M+hICijJi4hEQXa2P/mpZ89gCpEVR0leRCQKXn4Z1q6NrV48KMmLiEQsJ8ef/NSlC3TtGnQ0u4v5A68iIrHulVd8OeFnnoFYO6VHPXkRkQjs3An33+/H4c88M+ho/kw9eRGRCLz8su/FP/987PXiQT15EZEy27EDHngA0tOhV4kLrFcs9eRFRMropZcgM9P35mOxFw/qyYuIlEl2tq8Tf+KJ0KNH0NEUTz15EZEyGDUKVq+G8eNjtxcP6smLiJTatm3w0ENw8snQrVvQ0eydevIiIqX0/PP+7NbJk4OOZN/UkxcRKYUtW+CRR+DUU/14fKxTkhcRKYXhwyErC+67L+hISkZJXkSkhNat8zVq+vXzc+PjgZK8iEgJ3X+/nzr50ENBR1JySvIiIiWwfLlfFOSKK+Dww4OOpuSikuTNrL6ZTTGzpWa2xMy6mFlDM3vfzJaFLhtEY18iIkEYMgSqVYOhQ4OOpHSi1ZN/CpjunDsC6AAsAe4APnDOtQY+CN0WEYk7X3wBkybBP/4BzZsHHU3pRJzkzawecBIwGsA5t9M5txHoC4wLPW0ccFak+xIRqWjOwW23QdOmPsnHm2j05FsBWcDLZvalmb1kZrWAZs65NaHn/Ao0K+rFZnaVmWWYWUZWVlYUwhERiZ5334VZs+Cf/4Q6dYKOpvSikeSrAscAzzvnOgJb2WNoxjnnAFfUi51zo5xzac65tCZNmkQhHBGR6MjLg9tvh0MPhauuCjqasolGks8EMp1zn4VuT8En/bVm1hwgdLkuCvsSEakwY8fCokW+2mRyctDRlE3ESd459yuwyszCk4p6AIuBacCg0H2DgLci3ZeISEXZtAnuugtOOAH69w86mrKLVoGyvwPjzawasAK4FP8FMsnMLgdWAudGaV8iIuXuvvt8+YJ3343tUsL7EpUk75z7Ckgr4qEYLqUvIlK0pUthxAi4/HI45pigo4mMzngVESnEObjxRqhVCx58MOhoIqd68iIihbz9Nrz3nq822bRp0NFETj15EZGQHTvgppugTRu49tqgo4kO9eRFREKefBJ++MH35ON1yuSe1JMXEQHWrIEHHoA+faBnz6CjiR4leRER4OabYedOePzxoCOJLiV5Ean0pk+HiRN9OeFDDw06muhSkheRSm3bNrjmGr8QyO23Bx1N9OnAq4hUavfdBz/+6CtNVq8edDTRp568iFRaCxf6MfjLLoOuXYOOpnwoyYtIpZSf78sH168Pjz4adDTlR8M1IlIpjRoF8+bBK69Ao0ZBR1N+1JMXkUpnzRq44w7o0QP++tegoylfSvIiUqk452fTbN8Ozz8f32WES0LDNSJSqfzrX/DmmzBsGLRuHXQ05U89eRGpNFavhuuug+OP94XIKgMleRGpFJyDK6/0lSbHjoWkpKAjqhgarhGRSmHsWHjnHXjqqcoxTBOmnryIJLxVq/xqT926+eGaykRJXkQSmnN+rda8PBgzBqpUsqyn4RoRSWgjR8L77/vpkq1aBR1NxYvad5qZJZnZl2b2duh2KzP7zMyWm9lrZlYtWvsSESmJRYt8nfjTToOrrw46mmBE84fLDcCSQrcfAYY75w4Ffgcuj+K+RET2KjsbBg6EevVg3LjEP+mpOFFJ8mbWAjgTeCl024DuwJTQU8YBZ0VjXyIiJXHzzfDtt742TbNmQUcTnGj15J8EbgPyQ7cbARudc7mh25nAAUW90MyuMrMMM8vIysqKUjgiUplNnerH4m+9NbHWay2LiJO8mfUG1jnn5pfl9c65Uc6
5NOdcWpMmTSINR0QquZ9/hiuugE6d/MLclV00ZtecAPQxszOAFKAu8BRQ38yqhnrzLYBforAvEZFi5ebCBRf46ZITJkA1TfeIvCfvnLvTOdfCOZcKDAQ+dM5dCHwE9A89bRDwVqT7EhHZm3vvhU8/9dMlDzkk6GhiQ3meFnA7cLOZLceP0Y8ux32JSCU3bRo8+CBceilceGHQ0cSOqJ4M5ZybBcwKXV8BdI7m9kVEivLdd3DRRXDssfDcc0FHE1sq2Qm+IpJo/vgDzj7bj7+//jqkpAQdUWxRWQMRiVvOwWWXwdKlvnTBQQcFHVHsUZIXkbj12GMwZYpf5al796CjiU0arhGRuPTBB34x7nPPhVtuCTqa2KUkLyJxZ+lS6N8f2rSB0aMrb12aklCSF5G4kpUFZ57pD7S+/TbUrh10RLFNY/IiEjeys6FvX78g98cfQ2pq0BHFPiV5EYkL+flwySUwbx5MngyddRZOiSjJi0hcuPtumDTJz6Q555ygo4kfGpMXkZj30kvw0EN+dSfNpCkdJXkRiWmvv+6T+2mnwdNPayZNaSnJi0jMmjHDL+F33HF+IZDk5KAjij9K8iISkz79FM46C9q2hf/8B2rVCjqi+KQkLyIx56uv/Fz4Fi3gvfegQYOgI4pfSvIiElO+/96vy1q3LsycWbkX4Y4GJXkRiRnff19QaExVJaNDSV5EYsLixdC1K+zc6Xvwhx8edESJQUleRAL3zTfQrZu/PmsWHHVUkNEkFiV5EQnUggVw8sm+4NjHH/vZNBI9SvIiEpjPPvNj8HXqwOzZcNhhQUeUeCJO8mZ2oJl9ZGaLzexbM7shdH9DM3vfzJaFLjUJSkR2ee89OOUUaNzYJ/iDDw46osQUjZ58LnCLc64tkA5ca2ZtgTuAD5xzrYEPQrdFRBgzxs+DP/RQ+O9/NYumPEWc5J1za5xzC0LX/wCWAAcAfYFxoaeNA86KdF8iEt+cg3vvhcsv97342bOhefOgo0psUS01bGapQEfgM6CZc25N6KFfAZ3SIFKJ5eT4QmMvvwyXXgovvKBaNBUhagdezaw2MBW40Tm3ufBjzjkHuGJed5WZZZhZRlZWVrTCEZEY8vvv0Lu3T/BDh/p1WZXgK0ZUkryZJeMT/Hjn3Ouhu9eaWfPQ482BdUW91jk3yjmX5pxLa9KkSTTCEZEYsmgRdOoEH33k68Lfe6/KBVekaMyuMWA0sMQ590Shh6YBg0LXBwFvRbovEYkvkydDejps3epPcrr88qAjqnyi0ZM/AbgI6G5mX4XaGcDDwKlmtgw4JXRbRCqBvDy44w4491x/9ur8+XD88UFHVTlFfODVOfcJUNyPrx6Rbl9E4svatXDRRb7A2ODB8NRT/mxWCYYW8haRqHnnHT9zZvNmP/6u4ZngqayBiERs+3a4/np/glOzZpCRoQQfK5TkRSQiixZB585+ke0bboDPP4d27YKOSsKU5EWkTHJy4KGHIC3Nj8O/+y48+SSkpAQdmRSmMXkRKbXPP4crr/R14Pv3h2ee0TJ9sUo9eREpsS1b4KaboEsXWL8e3nzTz4VXgo9dSvIisk/OwZQp0L69H5IZPNgv19e3b9CRyb5ouEZE9iojw/feP/kEjjzSX55wQtBRSUmpJy8iRcrMhIsv9nVnvv8eRo2CL79Ugo836smLyG7WrYNhw+DZZyE/35cnuPNOqFs36MikLJTkRQSAX3/1yf3552HHDrjgArj/fkhNDToyiYSSvEglt2oVDB8OI0f65P7Xv8KQIVpUO1EoyYtUQs7BnDkwYgRMnervCyf31q2DjU2iS0lepBLZvt1PhXzqKT9rpn59P3Pm2ms1LJOolORFEpxz/gzVceNgwgTYuBGOOAKee87PnqlVK+gIpTwpyYskqJ9+gokTfXJfuhRq1IB+/eCSS6BHD6iiCdSVgpK8SIJwzp+F+vrr8MYbfk47wF/+4mu7DxigaZCVkZK8SBzbsgVmz4aZM+Htt2HZMn9
/ly5+OuTZZ8PBBwcbowRLSV4kjmzd6g+YfvyxT+xz50JuLlSvDl27ws03+3oyzZsHHanECiV5kRiVk+PLCWRkwLx5vi1c6BfJNoNjj4V//ANOOcUvkl2jRtARS4k559dIXL0a1qzxl4cf7mtIRJmSvEjAsrNhxQr44QdYssQn8oUL/fWcHP+cevXguOP8PPb0dH+9YcNg45Y95OX5+stZWbu3dev86cRr1xa0NWv8B1/YLbfEZ5I3s17AU0AS8JJz7uHy3qdILHDOj5mvX+//xlev9u2XX/zlTz/B8uX+emEHHuirPZ5+ur/s2NFPedRsmAqQkwObNvle9qZNBe33333buLHg+vr1vm3Y4C83bvQfelEaN/ZF95s189/QzZvD/vvvfnnAAeXylso1yZtZEvAscCqQCXxhZtOcc4vLc79ScbZuLeikbNjg/58Xblu2wLZtBS0725+Qk5vr/55ycvz1vDy/vT3/RpKSClqVKlC1qm/JyQWX4Vat2u6XRV0Pv75wq1Jl92bm48jPL7jMz/ex7tgBO3f6tmOHf/9bt/r3GW6//+7/LTZsKOiJF1a1qv+bbtkSTj0VDj0UDjnEt8MO8ycoSSHO+f8gO3b8uW3fvnvLzi5ohf/jhT+owh/Y1q3wxx8FbfNmv829MfM/q+rXh0aNfDv44ILrjRtDkya7t8aN/X/AgJR3T74zsNw5twLAzCYCfQEl+TiQlwcrV8KPP/peZ+G2Zo1P7Fu2FP/6mjWhTh1/WaNGwWWdOn9O0OHkCgWX4b/tvDyfZPPy/BdCuOXk+L/rzZsLvjB27iz6MtyiJSnJUa1qPrVq5FM7JZdaKXn+sloubevvpFGLHTSsuZ1GtbbTsOZ29quzlf3rbWX/On/QuOY2qri83b9F1ubDmnyYHbqv8LdMSVr4H6zw9X3dV9TtvbX8PWILt/AHFf6QwveHP7DwhxhuhT/EvLzdP6DCLfxtunNn8T3kkkpJgdq1/Zlf4ctatfw3bZ06Ba1uXZ/Ew5fhVr8+NGjg709KiiyWClbeSf4AYFWh25nAceW8Tykl5+Dnn2H+fPj2Wz8WvHgxfPedT6JhSUl+KKFlS+jcGfbbr+AXaLNmvsNSv75v9er5nnPU5eX5b5bNm3fvhYV7ZuHu9Natu/fktm3Dbd1GfvYOcrNzyNmWQ252Drnbc3Hbd5C/M5f8HTm+5eRShfxdzXBUIZ9q7KQaO0kmh6S8fMgD9tHxqzCFvyH3/Lbc231F3S6qhb+Fw9cL/+wx2/3nVrhVrbr7T7GkJH9ftWr+Gz98X+Fv+8I/uwq35GQ/haioVqOGT+Lhy5QUn8DDvYoaNSr1WFfgB17N7CrgKoCDDjoo4Ggqhw0b4NNP4YsvfMvIgN9+K3g8NRXatPGzNtq08cMIqanQooX/G41YTk7BOGbhcc3wWGd4vGPjxt3HRTdu9Mm7pMLJJNSsRg2SatYkKSWF6k1r7p4UwgkjJcW/rnr1ggRTvXrRiSjcCo/9hBNZUQkunAT3TIaFE+aelyVpIntR3kn+F+DAQrdbhO7bxTk3ChgFkJaWFuFvMinKhg3+hJlZs3z75hvfe09KgnbtoE8fSEvzU/LatStDLZPwdLDw4Hzhy3Xr/jzTYNOm4rdVpYr/WdygQcHPgv333/2nc+Gf1nXq+J/fdeoU/BSvXdsn9qh8I4nEt/L+K/gCaG1mrfDJfSBwQTnvs9Jzzk/B+/e//VmQn33m70tJ8Uu33XefP3Hm2GN9LtyrvDyfsFet8i0z07fwFJFw27btz69NSio4+NS0qd9h+EBU+EBVuDVs6FudOuqdikRRuSZ551yumV0HvIefQjnGOfdtee6zssrPh//+FyZP9sn955/9/WlpMHSoL0jVqZMfedhNXp5P2CtW+LZypT+yunKlb5mZ/gBZYSkpfuxm//39DsJTwJo3Lxio328/n7Qr8VioSCwo99+zzrl3gHfKez+V1TffwPjxvoTsqlV+mPmUU+Duu+HMM33
+JSfHT5GZucwXN1m2zE/QDif1wtNOqlTxL2rZ0p9G2bIlHHSQT+oHHugvGzZUb1skTmjQMg79/rsvHzt6NCxa5IeeTzsNHr5nC31bLaTWz0t8bdnBS/3lihUFE9HBj2sfeigccwz07+/n+R58MLRq5RN5gHN6RSS6lOTjyPz5fqGHCRMc2dlGequ1PNttNgPyJtJkwVz4z5qCJ1ev7s+sOfpoOO88v6ZbuDVqpJ64SCWhJB/LnCN3xc9Mfm4dwyfuzxerD6CWbeVi9//4G8/T4cdvYG1NaNsWevb0U2PatvXzHlu2jLuTNkQk+pTkY4VzflglIwPmz2dHxkLGfXYEj2y7jhV04giW8HTTMVx0/HLqHdsajrrfFzZp2VIHN0WkWEryQcnK8rVjP/us4IykDRvYQi1GJV3D41XGsTqnKZ1aruXxa76nzzUtqVL7nqCjFpE4oyRfEfLy/DSYOXN8Yp8719eVhV1nJOX07c9L2//K/85IZ+36ZLp3hVfugu7dm2HWLNj4RSRuKcmXh23b4PPP4ZNPfJszx9dXAT9/vEsXuPpqSE/HHXMsr0+vyV13+QUiTjwRXn/Yz14UEYmUknw0ZGf73nm4bsBnn/nKeWbQvj389a9+NeUTTvBzzkMzW+bMgVtO8Z37tm1h2jTo3VsTX0QkepTkyyI314+jz5zp27x5PqlXqeLPAL3xRt8lP+EEX4NlD7/9BrffDmPG+POOXnoJBg1SqRURiT6llZJatgymT/dJfdYsX5DLzJ9QdP31cPLJvrdet26xm8jPh5df9gl+0ya49Vb45z99PS0RkfKgJF+crVvho498Yp8+veBA6SGHwPnn+9oBJ5/sTywqgUWLYPBgX+L3L3+B55/3IzkiIuVJSb6wlSt92cZ//9sn+J07fZnGHj3g5pt97YBDDinVJvPy4PHH4Z57fIHF0aPhkks0tV1EKkblTvL5+X5++ptv+uS+cKG//7DD4Lrr4IwzfLf7T6UbS2bFCj/W/skn0K8fvPCCr7QrIlJRKl+Sz8mBjz+GN96At97yZXaTkuCkk3yXu3dvn+Qj4Jzvsd90k++xv/KKn2CjWTMiUtEqR5LfudMfMJ082Sf233/3NXl79fJd7DPP9OVzo+D33/1wzLRp0L27P9CqVQ1FJCiJm+R37oT33y9I7Bs3+hK7ffvC2WfDqaeWYFmk0lmwwFfuzcyE4cP9pBuNvYtIkBIryYeXR/rXv2DKFL+4ab16cNZZMGCAnxFTxvH1vXEOXnzRJ/WmTX0Ixx0X9d2IiJRaYiT5Zcv8Uc2JE/0Ye61avsd+/vm+x14OiT1s2zb429/8uHvPnn6VpsaNy213IiKlkhhJfskSGDECTj8dHnsM/ud/fKIvZz//7He1cKFfR/Wee1TCXURiS2Ik+V694Ndfo3bwtCQyMnyC37YN/vMf//0iIhJrEuOwYLVqFZrgX3/dz7isXt0XGVOCF5FYFVGSN7NhZrbUzL4xszfMrH6hx+40s+Vm9p2ZnRZxpDHAOXj0UTjnHOjQwRebbNcu6KhERIoXaU/+faC9c+4o4HvgTgAzawsMBNoBvYDnzCyuR6tzc30J+Ntvh3PPhQ8/hGZay0NEYlxESd45N8M5lxu6OQ9oEbreF5jonNvhnPsRWA50jmRfQdqxAwYO9NMk77oLJkzw51KJiMS6aI7JXwa8G7p+ALCq0GOZofv+xMyuMrMMM8vIysqKYjjRsW2bn405dao/wenBB3WCk4jEj33OrjGzmcB+RTw0xDn3Vug5Q4BcYHxpA3DOjQJGAaSlpbnSvr48bdrkS9nMmeNr0Vx2WdARiYiUzj6TvHPulL09bmaXAL2BHs65cJL+BTiw0NNahO6LG1lZfmbmwoX+HKsBA4KOSESk9CKdXdMLuA3o45zbVuihacBAM6tuZq2A1sDnkeyrIq1dC127wuLFvuyNEryIxKtIT4Z6BqgOvG++ju4859xg59y3ZjYJWIwfxrnWOZcX4b4qxG+
/+RI3K1f6BaG6dg06IhGRsosoyTvnDt3LYw8CD0ay/Yq2caOvP7N8uT+LVQleROJdYpQ1iILNm/0Y/Lff+iGa7t2DjkhEJHJK8vg1u888E+bP9xWKe/UKOiIRkeio9El++3bo08dPk5w40c+JFxFJFJU6yefn+4W2P/zQ14PXLBoRSTSV+tzNW2+FSZNg2DC46KKgoxERib5Km+SffBKeeAL+/ne45ZagoxERKR+VMslPngw33+zX8x4+HPwUfxGRxFPpkvx//+uHZo4/Hl59Vcv1iUhiq1RJ/vvv/Uya1FQ/F17lgkUk0VWaJL9pk58eWbUqvPsuNGoUdEQiIuWvUkyhzMuDCy/05QpmzoRWrYKOSESkYlSKJH/PPb4WzXPPqR6NiFQuCT9c89pr8NBDcNVVMHhw0NGIiFSshE7yX34Jl14KJ5wATz+tqZIiUvkkbJLPyoKzzvIHWKdOhWrVgo5IRKTiJeSYfH6+nwu/di188gk0axZ0RCIiwUjIJP/oo/Dee/5Aa1pa0NGIiAQn4YZrPvkE7r4bzj1XB1pFRBIqyf/2Gwwc6M9offFFHWgVEUmY4Zr8fLj4Yn/Add48qFs36IhERIKXMEl+2DBfruDZZ6Fjx6CjERGJDVEZrjGzW8zMmVnj0G0zsxFmttzMvjGzY6Kxn+J8+ikMGeJXdvrb38pzTyIi8SXiJG9mBwI9gZ8L3X060DrUrgKej3Q/e1OzJpxyisbhRUT2FI2e/HDgNsAVuq8v8Irz5gH1zax5FPZVpI4dYfp0qFevvPYgIhKfIkryZtYX+MU59/UeDx0ArCp0OzN0X1HbuMrMMswsIysrK5JwRERkD/s88GpmM4H9inhoCHAXfqimzJxzo4BRAGlpaW4fTxcRkVLYZ5J3zp1S1P1mdiTQCvja/EB4C2CBmXUGfgEOLPT0FqH7RESkApV5uMY5t9A519Q5l+qcS8UPyRzjnPsVmAZcHJplkw5scs6tiU7IIiJSUuU1T/4d4AxgObANuLSc9iMiInsRtSQf6s2Hrzvg2mhtW0REyiahateIiMjulORFRBKY+ZGV2GBmWcDKMr68MfBbFMMJkt5LbEqU95Io7wP0XsJaOueaFPVATCX5SJhZhnMuIZYI0XuJTYnyXhLlfYDeS0louEZEJIEpyYuIJLBESvKjgg4givReYlOivJdEeR+g97JPCTMmLyIif5ZIPXkREdlDQiV5M7s/tBLVV2Y2w8z2DzqmsjKzYWa2NPR+3jCz+kHHVFZmNsDMvjWzfDOLu5kQZtbLzL4LrXR2R9DxlJWZjTGzdWa2KOhYImVmB5rZR2a2OPR/64agYyoLM0sxs8/N7OvQ+/jfqO8jkYZrzKyuc25z6Pr1QFvn3OCAwyoTM+sJfOicyzWzRwCcc7cHHFaZmFkbIB94AfiHcy4j4JBKzMySgO+BU/FF+L4AznfOLQ40sDIws5OALfgFfdoHHU8kQosQNXfOLTCzOsB84Kx4+1zMl/Ct5ZzbYmbJwCfADaHFlqIioXry4QQfUovdV6uKK865Gc653NDNefhyzXHJObfEOfdd0HGUUWdguXNuhXNuJzARv/JZ3HHOzQY2BB1HNDjn1jjnFoSu/wEsoZiFiWJZaPW8LaGbyaEW1byVUEkewMweNLNVwIXAP4OOJ0ouA94NOohKqsSrnEkwzCwV6Ah8FnAoZWJmSWb2FbAOeN85F9X3EXdJ3sxmmtmiIlpfAOfcEOfcgcB44Lpgo927fb2X0HOGALn49xOzSvJeRKLNzGoDU4Eb9/glHzecc3nOuaPxv9Y7m1lUh9LKq558uSlupaoijMfXtR9ajuFEZF/vxcwuAXoDPVyMHzwpxecSb7TKWYwKjWFPBcY7514POp5IOec2mtlHQC8gagfH464nvzdm1rrQzb7A0qBiiZSZ9QJuA/o457YFHU8l9gXQ2sxamVk1YCB+5TMJUOiA5WhgiXPuiaDjKSszaxKeOWdmNfAH+KOatxJtds1
U4HD8TI6VwGDnXFz2usxsOVAdWB+6a14czxTqBzwNNAE2Al85504LNKhSMLMzgCeBJGCMc+7BYCMqGzObAHTDVztcCwx1zo0ONKgyMrO/AP8FFuL/3gHucs69E1xUpWdmRwHj8P+3qgCTnHP3RXUfiZTkRURkdwk1XCMiIrtTkhcRSWBK8iIiCUxJXkQkgSnJi4gkMCV5EZEEpiQvIpLAlORFRBLY/wdVsP4jv7Ev2wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 画出更新一次之后的模型\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "因为只更新了一次,所以两条曲线之间的差异仍然存在,我们进行 100 次迭代" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 20, Loss: 65.56586\n", "epoch 40, Loss: 15.41177\n", "epoch 60, Loss: 3.70702\n", "epoch 80, Loss: 0.97122\n", "epoch 100, Loss: 0.32874\n" ] } ], "source": [ "# 进行 100 次参数更新\n", "for e in range(100):\n", " y_pred = multi_linear(x_train)\n", " loss = get_loss(y_pred, y_train)\n", " \n", " w.grad.data.zero_()\n", " b.grad.data.zero_()\n", " loss.backward()\n", " \n", " # 更新参数\n", " w.data = w.data - 0.001 * w.grad.data\n", " b.data = b.data - 0.001 * b.grad.data\n", " if (e + 1) % 20 == 0:\n", " print('epoch {}, Loss: {:.5f}'.format(e+1, loss.data.item()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以看到更新完成之后 loss 已经非常小了,我们画出更新之后的曲线对比" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqW0lEQVR4nO3de5yMdf/H8ddnT5Zdp3UKK7tKzhYtrYqUcogciogcuivc3UK5cxfuXyUqN6XIoRVFlFMlFTkfUlQoOW05RBZrt3Vclt3Z+f7+mKGtFrs7M3vNzH6ej8f1mJlrrrm+n2vx9t3vXNf3EmMMSiml/FOA1QUopZTyHA15pZTyYxrySinlxzTklVLKj2nIK6WUH9OQV0opP5brkBeRmSKSLCI7s62LEJGVIrLX+VjauV5EZKKI7BORn0SkkSeKV0opdXV56cm/B7T5y7pngdXGmOrAaudrgLZAdefSD5jqWplKKaXyQ/JyMZSIRAGfG2PqOl//DLQwxhwTkYrAOmNMDRF52/n8w79ud7X9ly1b1kRFReXvSJRSqpDaunXr78aYcjm9F+TivitkC+4koILzeWXgcLbtEp3rrhryUVFRbNmyxcWSlFKqcBGRQ1d6z21fvBrHrwR5niNBRPqJyBYR2ZKSkuKucpRSSuF6yB93DtPgfEx2rj8CVMm2XaRz3d8YY+KNMbHGmNhy5XL8bUMppVQ+uRryS4A+zud9gE+zre/tPMsmDjh9rfF4pZRS7pfrMXkR+RBoAZQVkUTgeeBVYIGIPAocAh50br4UuBfYB5wHHslvgZmZmSQmJnLhwoX87kK5IDQ0lMjISIKDg60uRSmVD7kOeWPMQ1d4q2UO2xrgX/ktKrvExESKFy9OVFQUIuKOXapcMsaQmppKYmIi0dHRVpejlMoHr7/i9cKFC5QpU0YD3gIiQpkyZfS3KKV8mNeHPKABbyH92Svl23wi5JVSyp+NGgUbNnhm3xryuTBx4kRq1apFz549WbJkCa+++ioAixcvZvfu3Ze3e++99zh69Ojl14899tif3ldKqb/65Rd4/nlYv94z+3f1itdCYcqUKaxatYrIyEgAOnToADhCvn379tSuXRtwhHzdunWpVKkSAO+88441BWdjs9kICtI/ZqW81ZQpEBxg4/E63wG3un3/2pO/hgEDBnDgwAHatm3LhAkTeO+99xg4cCDffPMNS5Ys4ZlnnqFBgwaMHTuWLVu20LNnTxo0aEB6ejotWrS4PE1DeHg4I0aMICYmhri4OI4fPw7A/v37iYuLo169eowcOZLw8PAc65g9ezb169cnJiaGXr16AdC3b18WLVp0eZtLn123bh3NmjWjQ4cO1K5dm2effZbJkydf3u6FF15g/PjxAIwbN47GjRtTv359nn/+eff/AJVSV5SWBu/OyKKrfT7X7V7jkTZ8q4s3ZAj8+KN799mgAbzxxhXfnjZtGl9++SVr166lbNmyvPfeewDceuutdOjQgfbt29OlSxcAli1bxvjx44mNjf3bfs6dO0dcXBxjxoxh2LBhTJ8+nZEjRzJ48GAGDx7MQw89xLRp03KsYdeuXYwePZpvvvmGsmXLcuLEiWse1rZt29i5cyfR0dH88MMPDBkyhH/9y3FW64IFC1i+fDkrVqxg7969fPfddxhj6NChAxs2bKB58+bX3L9SynVz5sCZtEAGBsdD/0XX/kA+aE++gISEhNC+fXsAbr75Zg4ePAjApk2b6Nq1KwA9evTI8bNr1qyha9eulC1bFoCIiIhrttekSZPL57Y3bNiQ5ORkjh49yvbt2yldujRVqlRhxYoVrFixgoYNG9KoUSMSEhLYu3evq4eqlMoFY+CtN7NoJD8Q16MaeGhaF9/qyV+lx+3tgoODL5+OGBgYiM1mc3mfQUFB2O12AOx2OxkZGZffCwsL+9O2Xbt2ZdGiRSQlJdGtWzfAcbHTc889R//+/V2uRSmVN+vXw66EQGYyERk8yGPtaE/eBcWLF+f
s2bNXfJ0bcXFxfPTRRwDMmzcvx23uuusuFi5cSGpqKsDl4ZqoqCi2bt0KwJIlS8jMzLxiO926dWPevHksWrTo8m8OrVu3ZubMmaSlpQFw5MgRkpOTr7gPpZT7vDXJEBFwku63HoaGDT3Wjoa8C7p37864ceNo2LAh+/fvp2/fvgwYMODyF6+58cYbb/D6669Tv3599u3bR8mSJf+2TZ06dRgxYgR33HEHMTExPP300wA8/vjjrF+/npiYGDZt2vS33vtf93H27FkqV65MxYoVAWjVqhU9evSgadOm1KtXjy5duuT5PymlVN4dPgyLFxses8dT9Ol/erStPN0ZytNiY2PNX28asmfPHmrVqmVRRZ53/vx5ihYtiogwb948PvzwQz799NNrf7AA+fufgVIFbeRIeHmMnQOVmhF1aD24eJqziGw1xvz9jA98bUzeD23dupWBAwdijKFUqVLMnDnT6pKUUh508SLET7FxH18QNbijywF/LRryFmvWrBnbt2+3ugylVAFZuBBSTgYxMGQ6PDbb4+1pyCulVAF6641MasgBWvaJhFycDu0q/eJVKaUKyObN8O3WYJ4wkwkYNLBA2tSevFJKFZDXx9spKWd5pPmvULdugbSpPXmllCoAv/4KH30M/c00iv+74C5A1JAvAFFRUfz+++9Wl6GUstCbbxgCTBZP3vgl3HtvgbWrIZ8HxpjL0whoHUqp3Dp1CmZMz6I784gc3hsCCi56NeSv4eDBg9SoUYPevXtTt25dDh8+fMXpeTt16sTNN99MnTp1iI+Pv+a+v/zySxo1akRMTAwtWzruh559GmCAunXrcvDgwb/V8dJLL/HMM89c3u7SFMgAc+bMoUmTJjRo0ID+/fuTlZXlrh+HUiofpk+HtPQgni77PlxhIkJPccsXryLyFPAYYIAdwCNARWAeUAbYCvQyxmRccSe5YMFMwwDs3buXWbNmERcXd9XpeWfOnElERATp6ek0btyYBx54gDJlyuS4z5SUFB5//HE2bNhAdHR0rqYPzl5HSkoKTZs2Zdy4cQDMnz+fESNGsGfPHubPn8/XX39NcHAwTzzxBHPnzqV37955/MkopdwhMxMmvpbBnWyk4TN3Q5EiBdq+yyEvIpWBQUBtY0y6iCwAugP3AhOMMfNEZBrwKDDV1fasULVqVeLi4gD+ND0vQFpaGnv37qV58+ZMnDiRTz75BIDDhw+zd+/eK4b85s2bad68+eXpgHMzfXD2OsqVK0e1atXYvHkz1atXJyEhgdtuu43JkyezdetWGjduDEB6ejrly5d37QeglMq3hQsh8XgI04pOhf4Ff7c4d51CGQQUFZFMoBhwDLgLuPR7ySzgBVwMeatmGs4+8deVpuddt24dq1atYtOmTRQrVowWLVpw4cKFPLeVffpg4E/7+OsEZN27d2fBggXUrFmTzp07IyIYY+jTpw+vvPJKnttWSrmXMfDaKxepyQHa/jMKcpiA0NNcHpM3xhwBxgO/4Qj30ziGZ04ZYy5Nmp4IVM7p8yLST0S2iMiWlJQUV8vxuCtNz3v69GlKly5NsWLFSEhIYPPmzVfdT1xcHBs2bODXX38F/jx98LZt2wDH3Z0uvZ+Tzp078+mnn/Lhhx/SvXt3AFq2bMmiRYsuTxl84sQJDh065NpBK6XyZf162LazCE8HvEnAU4MtqcEdwzWlgY5ANHAKWAi0ye3njTHxQDw4ZqF0tR5Pa9WqFXv27KFp06aA476qc+bMoU2bNkybNo1atWpRo0aNy8MqV1KuXDni4+O5//77sdvtlC9fnpUrV/LAAw8we/Zs6tSpwy233MJNN910xX2ULl2aWrVqsXv3bpo0aQJA7dq1GT16NK1atcJutxMcHMzkyZOpWrWq+34ISqlceX1sBuU4zcPdbRAZaUkNLk81LCJdgTbGmEedr3sDTYGuwHXGGJuINAVeMMa0vtq+CuNUw75A/wyUyrs9e6B2bXieF3jhpwegXj2PtXW1qYb
dcQrlb0CciBQTx/3tWgK7gbVAF+c2fQDvmiRdKaU86NUxNopxnoEtEzwa8NfijjH5b4FFwDYcp08G4Bh++Q/wtIjsw3Ea5QxX21JKKV9w8CDM/TCA/kyj7MgBltbilrNrjDHPA8//ZfUBoImb9n/5JtiqYHnTncOU8hXjXs0iwJ7F0MZfwR1PWVqL11/xGhoaSmpqqoaNBYwxpKamEhoaanUpSvmMpCSYMcPQl/eoPPqfYHEH1eunGo6MjCQxMRFfOL3SH4WGhhJp0VkBSvmiCeOzyLQJw+ovh3sWWV2O94d8cHDw5atClVLKm504AVMm2+nGQm4c84jlvXjwgeEapZTyFW9NtJN2IZhnayyGdu2sLgfwgZ68Ukr5grQ0ePO1TO5jOfXHdPOKXjxoT14ppdwi/m07J9KKMDx6HnTubHU5l2lPXimlXHTxIowfc5E72UTcmPsK9KYg16Ihr5RSLpoebzh2siizK8+GB73ruk8NeaWUckF6Orz8/AWa8T0tX2oBgYFWl/QnGvJKKeWCt6faOXayKB9ETkd6vWt1OX+jIa+UUvl07hy88uJF7uIbWoxvD0HeF6ne8+2AUkr5mCmTskg+U5QXb3gfuna1upwced9/O0op5QPOnoWxozNpxSpun/CAV51Rk513VqWUUl7urTcyST0Xyqg6C6B9e6vLuSLtySulVB6dPu2YTrgdy7nlzR5ec3VrTrQnr5RSefTmuAxOng9lVOPPoWVLq8u5Kg15pZTKg5Mn4fXX7HTiExpN7Gt1OdekIa+UUnnw6gsXOHMhhBear4G4OKvLuSYNeaWUyqXffoM3JwfyMHOImfiY1eXkioa8Ukrl0v89nQZZWbz0wHaIibG6nFzRkFdKqVz46SeY/VExngyaStU3rL05d164JeRFpJSILBKRBBHZIyJNRSRCRFaKyF7nY2l3tKWUUlZ4dsApSnGK4YPPgw/d99hdPfk3gS+NMTWBGGAP8Cyw2hhTHVjtfK2UUj5nzWrDsk2lGB4+idLPD7K6nDxxOeRFpCTQHJgBYIzJMMacAjoCs5ybzQI6udqWUkoVNLsdhvU/xfUcYuCrkVC8uNUl5Yk7evLRQArwroj8ICLviEgYUMEYc8y5TRJQIacPi0g/EdkiIltSUlLcUI5SSrnPgg9sbN1fmtGVpxHav4/V5eSZO0I+CGgETDXGNATO8ZehGWOMAUxOHzbGxBtjYo0xseXKlXNDOUop5R4XL8LwIeeI4Ud6vt3cK6cSvhZ3hHwikGiM+db5ehGO0D8uIhUBnI/JbmhLKaUKzFv/O8+vqSUZ23A+Afe2sbqcfHE55I0xScBhEanhXNUS2A0sAS79btMH+NTVtpRSqqAcPw6jRgv3spTW73b36knIrsZdv3s8CcwVkRDgAPAIjv9AFojIo8Ah4EE3taWUUh43fMAJ0jPCmdDje4i51+py8s0tIW+M+RGIzeEt756eTSmlcvD9d4aZiyN4JnQSN0160upyXOJ73yIopZQH2e0wqOfvVCCLkWOLQ0SE1SW5RKc1UEqpbOZMT2fzvnKMjX6bEgN7W12Oy7Qnr5RSTmfPwn+GZtKEn+j1QVuvvW9rXvj+ESillJuMHvI7SedKMLHDagLimlhdjltoyCulFLD3F8OEd0vSJ/gDbnnncavLcRsNeaVUoWcMPNHlOKEmnVdezAA/uvpeQ14pVejNfTuNVTuu45UqU6k4rJfV5biVhrxSqlBLTYWnnjLcwrcM+KQ1BAZaXZJbacgrpQq1YQ8f5eSFosT/YzOBNzewuhy305BXShVa61dcZOaXlRhaaib13+pndTkeoefJK6UKpYsXof9Dp4niHM9/UAOKFrW6JI/QnrxSqlAaO+QYP58oz9TWiynW9g6ry/EYDXmlVKHz8y4bY94uQ/cin9Dmw75Wl+NRGvJKqUIlKwv+0S6JYuYcEyYFQenSVpfkURrySqlCZcKwo3xzKJJJsbO57rH2VpfjcRrySqlCY/ePGYy
cUIZORZbS84sePnu3p7zQs2uUUoWCzQZ97z1OuCnKtPhApLz/TF1wNdqTV0oVCmOfOMT3x6ow9a6FVOjd2upyCoyGvFLK723fdJ4Xp1ekW7HP6Lr4YavLKVA6XKOU8msZGdDnvlQiCGbyvDJQvLjVJRUot/XkRSRQRH4Qkc+dr6NF5FsR2Sci80UkxF1tKaVUbr3Yax/bU6sQ32kZZe671epyCpw7h2sGA3uyvR4LTDDG3AicBB51Y1tKKXVNaxb8zisLqvGP0h/T4cOHrC7HEm4JeRGJBNoB7zhfC3AXsMi5ySygkzvaUkqp3EhJyuLhXlBDfmHiytoQGmp1SZZwV0/+DWAYYHe+LgOcMsbYnK8Tgcpuaksppa7KGHik+X5OZIQz78VfCLu5ptUlWcblkBeR9kCyMWZrPj/fT0S2iMiWlJQUV8tRSineHPgLX+y9ifFNFhLz3w5Wl2Mpd/TkbwM6iMhBYB6OYZo3gVIicunsnUjgSE4fNsbEG2NijTGx5fzovopKKWtsW5nKsClRdAxfxb9Wdba6HMu5HPLGmOeMMZHGmCigO7DGGNMTWAt0cW7WB/jU1baUUupq0s7Y6d4pnfKkMGNpJaR4uNUlWc6TF0P9B3haRPbhGKOf4cG2lFKFnDEwoPku9p+vyNyh2yjTrLbVJXkFt14MZYxZB6xzPj8ANHHn/pVS6komDdjF3O31eKn+Qu4Y1+XaHygkdFoDpZTP2zDnN56Or0HHkmsZ/nW7QjG7ZG5pyCulfFrinrN07VuMGwIPMuurGwgIL2Z1SV5FQ14p5bMuptvpcttRzmcVYfGME5Ssd73VJXkdDXmllM8adPs2vj1Zg1mPfkWtPvoVYE405JVSPmn6Ez8Qvy2WZ+t9wf3T21pdjtfSkFdK+ZzVkxN4YmpdWpfcxOhNLfWL1qvQkFdK+ZTdyw/zwJMVqRl8gPnf30BgWOGceCy3NOSVUj7j+J4T3Ns+gKKk88XyIEpWL291SV5PQ14p5RPOn7hAhyZJpNhK8Vl8EtffeYPVJfkEDXmllNez2+z0qr+d79Nq8sG/fyD2sQZWl+QzNOSVUl7N2A3/jvuKj4/cwusd1tFx3O1Wl+RTNOSVUl7t5VbrmLD1Dp6sv57Bn9xpdTk+R0NeKeW1JnVew8jVd9Kr2te8sbUZEqCnSuaVhrxSyivN6rOGQYvvolOl75i56xYCgjSu8kN/akopr/PxoHX8Y/Yd3F32B+YlNCAo1K2zohcqGvJKKa+y4r9f0X3SrdxSIoHFCbUoUjzE6pJ8moa8UsprLH9uHR1Hx1Kr2CG+2FmVsDJ6NaurNOSVUl7hs0Er6fBqU2qGJbJ6V0VKV9H7s7qDhrxSynIfPbaM+ye1IKbEQVYnVKZslAa8u2jIK6Us9cFDn9Ftxj00idjHyr1RRETqnZ3cSUNeKWUNY3i3wyc8PK8dzcr/zPJ9N1KyfBGrq/I7Loe8iFQRkbUisltEdonIYOf6CBFZKSJ7nY+lXS9XKeUPzMUMXon9iH981pl7Ku/mi301CS8dbHVZfskdPXkbMNQYUxuIA/4lIrWBZ4HVxpjqwGrna6VUIWdLPc0/b1jB8G1d6FnvJz7bX4dixQOtLstvuRzyxphjxphtzudngT1AZaAjMMu52Sygk6ttKaV827mfE+kc/QNvH2nPc+1+4v3t9QkpolMVeJJbx+RFJApoCHwLVDDGHHO+lQRUcGdbSinfcnzVDlrU+52lZ5sxddAeXv68vt61rwC4LeRFJBz4CBhijDmT/T1jjAHMFT7XT0S2iMiWlJQUd5WjlPIi2/+3nLhWxdltu4nFkxIZ8GYtq0sqNNwS8iISjCPg5xpjPnauPi4iFZ3vVwSSc/qsMSbeGBNrjIktV66cO8pRSnkLm40P2s2l6X+akRlcjHVfnOe+gVWtrqpQccfZNQLMAPYYY17P9tYSoI/zeR/gU1fbUkr5DtvRZIZGf0zPpT2JrXiUrftK0bhtWavLKnTcMbX
bbUAvYIeI/OhcNxx4FVggIo8Ch4AH3dCWUsoHpHy5lW6dLrD24oM8eU8Cr31Rk2A9Q9ISLoe8MWYjcKWvT1q6un+llA+x2/lq4Hx6Tr2NFCnPrFGH6P3fmlZXVajpJM1KKbfI/DWRUXet5eWDPYgOS2bjkovcfJeOv1tNpzVQSrls/1vLaFY9idEHe9H7tv38cPQ6br6rpNVlKbQnr5RygTl5itkdFjFwYzeCAg3z3zjGg4NvsroslY2GvFIq74zh0LRl/HNIEZZlPEbz63/l/TWRXH9DCasrU3+hIa+UypOsg4eZ1H45I3d1h4AAJgw9zJNjownU6We8ko7JK6Vyx2Zj+7Mf0vSG4zy16zGa10xm1y8hDBlfRQPei2lPXil1TamL1vJi/6NMOdGNiJA0PvhfMt0HVdO5Z3yAhrxS6ooyftzN5Ic2MiqhK2dozuP3HGLMB9GUKavp7it0uEYp9TfmWBKL275NnYbBPJ3QjyY3nWb71iymraimAe9jNOSVUpeZo8f4omM8t1ROpPOX/QkpHc6yeadY/nMUdRuFWF2eygcNeaUU5shRPrsvniaRR2m/pB8pYVWZPiaZ7ckVadOtlNXlKRfomLxShVjmDzv5aOg3jFsXyzbTj2rFk5kxPJleQ8vrhGJ+QkNeqcLGbiflg5XEP3+EKQdac5R+VC95nHeHJ9PzKQ13f6Mhr1QhYY4e47uXVxE/O5S5Z+/jIqG0qn6A6aPO0ObBCgTo4K1f0pBXyp9dvEjie6uYMyGFWT/fQgK9KBaQziP3JDJofFVq1a9mdYXKwzTklfI3GRn8vngjSyYfZv43VVhpa4shgGZVfuXf/ZPo+uR1lChxo9VVqgKiIa+UPzh/nt8+2Mjit4/zyY9RbLDdgZ1AosKS+W+3A/T+v2huuCna6iqVBTTklfJFdjtpm3fy1Ts/s2o1rDpcg59MKwDqlEpkeKt9dB4SRcO48oiUt7hYZSUNeaV8QWYmv6/fxbcLDvHt15ms31uJTZmxZFKfInKR26scYmyrX+g0JJqb6kRaXa3yIhrySnkbm43kTfvZsSyRHd+eZ+vOImxOqcY+0wBoQCA2GkT8xtNxP3N3r4rc1rEsRYvqjTpUzjTklbLIhZPpHNiQyP7NKezfcZ79+2HPkRLsOFuVZGoANQCoGJxCXLWjPH7LduI6V+TmtuUJC9OzYlTueDzkRaQN8CYQCLxjjHnV020q75WZCRcvOh5ttj8es7Ic7xvz5+0DAyEg4I/HoCDHEhz8x+It090aA+fPw6mThtSDZzlx4BSph9I4cSSdpMMZHD0CR1OCOHI6nKPpERyzVwCqOxcoLmepEX6EdnUOUS/mMPVaRFCvbRUqRJYDyll5aMqHeTTkRSQQmAzcAyQC34vIEmPMbk+2qzwrPR2OHYPjxyEp6Y/HEyfg1Kk/L2lpjuBzLIasLPcncmCAneAgQ0iQneBAOyFBhuCgbOuCjPO5ISjw70uA2AkQQwCGADEIBmO3Y7cZ7HbjeMwyZGZCRgZkZMLFzAAyMoVzGcGkZYRwzhZKmr0ohgBAgBLO5Q9lSaFSSCqVws8QUzGFqlX2cEOdUG6ILc0NzStTNqo4IjXd/vNRhZune/JNgH3GmAMAIjIP6AhoyHsxYyAlBfbscSy//goHD/6xJCfn/LlSoemUCj5HqcCzlOI01ewnKZ51imK2MxTLPE0x+1mKkk4oFwjCRjCZlx8DsCM4uvGXHg2CnQCyCCSLQOwEYCMIG0FkEkwmwdgIIsMeQmZGMJkZwWQQQiZ/f7z0PItAbASR7vysjaDL7WRfsr8SDAHYCSHDsYiN8MAsQgKzCAvJJLykjbCidsLDDGHhULpMIGUqBBFRKZQy14cREVWC8nXLU6RyWRDtkauC5emQrwwczvY6Ebgl+wYi0g/oB3D99dd7uBz1V5mZsHMnbNkCW7fCrl2OYE9N/WObkKAsqpY4SVTIUTqaA1QttpvK53+hAse5jiQqcJxypBCcYYf
wCChT5o+lVCkoXjzbUh7CwqBIEQgN/WO5NPZyaTwmKMgxPiPy98WYPxYAu93xPPtjTuv+NBZkBzLAXPxjLEjE8RgQACEhjnouPQYHQ7FiULSoozalfITlf1uNMfFAPEBsbKy5xubKRUlJsH49fPUVfP89bN/uGCMHKF3sAnVLJfJAkQRqF99MrbPfUYs9VLYdIeB0AFStCtWqOZYqN0Kl5lC5MlSqBBUrQkQEOgGKUt7F0yF/BKiS7XWkc50qICdOwMqVsG6dY0lIcKwPD80ktuxBBpX5jtiUZcRmfkP0+V8RCYM6daBuXajTBuoMhZtugipVtAerlA/y9L/a74HqIhKNI9y7Az083GahZowjyD//HD77DL7+2jFSUbxoJs3L7eHRssto8ftCGlz4kaDUEGjUCLo1gSYvQ+PGEB2tvXGl/IhHQ94YYxORgcByHKdQzjTG7PJkm4XVzp3wwQewYAHs3+9Y1+C6Ywwv/wXtk6Zzc/pWgk4Wg+bN4c6HoMXbEBOjvXOl/JzH/4UbY5YCSz3dTmF06BDMmwdz58KOHRAYYLi78h6GlppD+1PvU+X4EYiLgyc7wt2THL12DXWlChX9F+9jbDbHMMyUKbBqlWNd00oHeavUO3Q9FU/51HPQqhV0GAXt2kF5nZxKqcJMQ95HHD0K77wD8fFw5AhUKXWGF8vN5uGU16mWfBjatIGeE6FDB8epfkophYa819u5E159FebNc1wt2rrSDiYHj6LdqU8IatoERg2DLl2gbFmrS1VKeSENeS+1eTO88gosWQJhRTIZVGY+TyS/wI1njsNjvWDANqhf3+oylVJeTkPey6xfDy++CGvXQkSxC7xQcipPnn6JiHKV4IWh8PDDjitHlVIqFzTkvcSOHfDss7B0KVQsfpbXwsbR79zrhN8aA8/NdnyJ6i3TLSqlfIZe9WKx336Dvn0hJsbw9ZoLjC36AvvPlufpZt8Tvn4pbNwI7dtrwCul8kV78hY5fx5Gj4bXXzeQZWdo+HSeOzuciFaN4eWNcPPNVpeolPIDGvIW+OwzePJJx8VMD5f8nNGnB1K1RnkYuwjuusvq8pRSfkRDvgAdOgSDBjnOmKld/DfW8zDNyx2D6eMdp0HqkIxSys10TL4AZGXBa69B7dqGVcsyGBs0gh9t9Wj+yr2OCdy7dtWAV0p5hPbkPezAAejTx/H96X1ha5mU+QhVO98Mb+wAvUmKUsrDNOQ9xBjHNARPPWUIzLzALPrTq8I3yFvToG1bq8tTShUSOlzjAUlJcN990K8f3GLfxI6MmvQeUgbZ8ZMGvFKqQGlP3s1WrYIePQxnT9p4k38zsMLnBMx63zGPu1JKFTDtybuJ3e44771VK0PZMwfYaoth0IAMAnZs14BXSllGe/JukJoKvXrBsmXQI3ABbxd/lvDFUx3T/yqllIU05F30/ffQ5QE7SUeymMIgBjTdhcz/GipVsro0pZTS4RpXzJsHzW63Q1ISG+238s/nSiNr12jAK6W8hoZ8PhgDL70EDz0EjbM2szXsDhp/8SK8/LLeQ1Up5VU0kfLo4kV4/HHD++8LvXif6bUnUuTz1Xphk1LKK7nUkxeRcSKSICI/icgnIlIq23vPicg+EflZRFq7XKkX+P13uOduO++/L7zESGZ1+Igi36zVgFdKeS1Xh2tWAnWNMfWBX4DnAESkNtAdqAO0AaaISKCLbVnq4EFoeksW331jYx7dGPlsFvLJxxAebnVpSil1RS6FvDFmhTHG5ny5GYh0Pu8IzDPGXDTG/ArsA5q40paVEhLg9qY2fj+YxpqAe+g2u73jBqwB+pWGUsq7uTOl/gEscz6vDBzO9l6ic93fiEg/EdkiIltSUlLcWI57bNsGzW61YUs+wfqwdty6dozjpHillPIB1/ziVURWAdfl8NYIY8ynzm1GADZgbl4LMMbEA/EAsbGxJq+f96SNG6FdGxul04+yqkx3blwTD3XrWl2WUkrl2jVD3hhz99XeF5G
+QHugpTHmUkgfAapk2yzSuc5nfPkl3N8pi+sz9rPy+seosnYuREdbXZZSSuWJq2fXtAGGAR2MMeezvbUE6C4iRUQkGqgOfOdKWwVp2TLo0D6Lmhd/4qs6/6TKt4s04JVSPsnV8+TfAooAK8VxZ6PNxpgBxphdIrIA2I1jGOdfxpgsF9sqEKtWQeeOWdTL+pHVt/4fpZZ+AiVLWl2WUkrli0shb4y58SrvjQHGuLL/grZhA3RoZ+OmzN2saD6GUssWQrFiVpellFL5ple8Om3aBO1aZ1I1Yx+rbnuBMsvmaMArpXyehjywZQu0aZnBdRcOsTpuJOWXv68Br5TyC4X+ap7du6FViwwi0o+wpslzVFo1G8LCrC5LKaXcolD35I8ehbYt0gk5d5rVsf+hyur3NOCVUn6l0Ib8mTPQ9o5znEixs77mUKqtfkfnoVFK+Z1CGfIZGXB/63Ps3hfC5xX70Wj9BChRwuqylFLK7QpdyBsDjz50jtWbw3i3xGBab/wvlC9vdVlKKeURhe6L1+FPpTPn4zBeCnmJvusfgWrVrC5JKaU8plCF/LtvZ/Dqm0XpF/AOI5bdDg0aWF2SUkp5VKEJ+U3fGAY8IdzNSibPLYXcdafVJSmllMcVipA/cgTub51GpP035o/cSVD3LlaXpJRSBcLvQz49HTrdcYK0NFjSfjoRo4ZYXZJSShUYvz67xhjo1+UEW/ZHsPimYdRZOAocs2UqpVSh4Nc9+ddfOMOcpRGMKjGejhuGQmio1SUppVSB8tuQX70sg2GjwugS+DEj190NFSpYXZJSShU4vxyuSUqCng+kU4MjvDs7CGnYwOqSlFLKEn7Xk8/Kgh53HuNMejALH1tBeI8OVpeklFKW8buQf2ngcdYmVGRKzUnUmTrQ6nKUUspSfhXyq5ekMWpaOfoUnU/fdX0hyC9Ho5RSKtf8JuSTjhl6PphJTRKYvDhSv2hVSincFPIiMlREjIiUdb4WEZkoIvtE5CcRaeSOdq4kKwt6ND/MmYtFWPjv7whrdZsnm1NKKZ/hcsiLSBWgFfBbttVtgerOpR8w1dV2rmbm8H2s3Xc9UxrNoM7/+niyKaWU8inu6MlPAIYBJtu6jsBs47AZKCUiFd3QVo76dkljfv0x9F3TW69oVUqpbFz6ZlJEOgJHjDHb5c/hWhk4nO11onPdMVfau5Lgxg14cHsDT+xaKaV82jVDXkRWAdfl8NYIYDiOoZp8E5F+OIZ0uP76613ZlVJKqb+4ZsgbY+7Oab2I1AOigUu9+Ehgm4g0AY4AVbJtHulcl9P+44F4gNjYWJPTNkoppfIn32PyxpgdxpjyxpgoY0wUjiGZRsaYJGAJ0Nt5lk0ccNoY45GhGqWUUlfmqauFlgL3AvuA88AjHmpHKaXUVbgt5J29+UvPDfAvd+1bKaVU/vjNFa9KKaX+TkNeKaX8mIa8Ukr5MXEMn3sHEUkBDuXz42WB391YjpX0WLyTvxyLvxwH6LFcUtUYUy6nN7wq5F0hIluMMbFW1+EOeizeyV+OxV+OA/RYckOHa5RSyo9pyCullB/zp5CPt7oAN9Jj8U7+ciz+chygx3JNfjMmr5RS6u/8qSevlFLqL/wq5EXkJeftBn8UkRUiUsnqmvJLRMaJSILzeD4RkVJW15RfItJVRHaJiF1EfO5MCBFpIyI/O29n+azV9eSXiMwUkWQR2Wl1La4SkSoislZEdjv/bg22uqb8EJFQEflORLY7j+NFt7fhT8M1IlLCGHPG+XwQUNsYM8DisvJFRFoBa4wxNhEZC2CM+Y/FZeWLiNQC7MDbwL+NMVssLinXRCQQ+AW4B8dMq98DDxljdltaWD6ISHMgDcdd2+paXY8rnHeaq2iM2SYixYGtQCdf+3MRxzztYcaYNBEJBjYCg5131HMLv+rJXwp4pzD+fEtCn2KMWWGMsTlfbsYxJ79PMsbsMcb8bHUd+dQE2GeMOWCMyQDm4bi9pc8xxmwATlhdhzsYY44
ZY7Y5n58F9uC4+5xPcd4iNc35Mti5uDW3/CrkAURkjIgcBnoC/2d1PW7yD2CZ1UUUUle6laXyEiISBTQEvrW4lHwRkUAR+RFIBlYaY9x6HD4X8iKySkR25rB0BDDGjDDGVAHmAgOtrfbqrnUszm1GADYcx+O1cnMsSrmbiIQDHwFD/vKbvM8wxmQZYxrg+G29iYi4dSjNUzcN8Zgr3Y4wB3Nx3LzkeQ+W45JrHYuI9AXaAy2Nl395koc/F1+T61tZqoLlHMP+CJhrjPnY6npcZYw5JSJrgTaA274c97me/NWISPVsLzsCCVbV4ioRaQMMAzoYY85bXU8h9j1QXUSiRSQE6I7j9pbKQs4vLGcAe4wxr1tdT36JSLlLZ86JSFEcX/C7Nbf87eyaj4AaOM7kOAQMMMb4ZK9LRPYBRYBU56rNPnymUGdgElAOOAX8aIxpbWlReSAi9wJvAIHATGPMGGsryh8R+RBogWO2w+PA88aYGZYWlU8icjvwFbADx793gOHGmKXWVZV3IlIfmIXj71YAsMAYM8qtbfhTyCullPozvxquUUop9Wca8kop5cc05JVSyo9pyCullB/TkFdKKT+mIa+UUn5MQ14ppfyYhrxSSvmx/wd/D+ouaWFrvwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the fitted result after the updates\n", "y_pred = multi_linear(x_train)\n", "\n", "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n", "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, after 100 updates the fitted curve almost completely overlaps the real curve." ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## 4. Exercise\n", "\n", "The example above is a cubic polynomial. Try fitting it with a quadratic polynomial instead and see how good a fit you can get.\n", "\n", "**Hint: use the parameter `w = torch.randn(2, 1)`, and rebuild the x dataset to match.**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 2 }