|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893 |
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 线性模型的PyTorch实现\n",
- "\n",
- "本节简单回顾一下线性回归模型,并演示一下如何使用PyTorch来对线性回归模型进行建模和模型参数计算。"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 1. 一元线性回归\n",
- "一元线性回归模型比较简单,假设有变量 $x_i$ 和目标 $y_i$,每个 i 对应于一个数据点,希望建立一个模型\n",
- "\n",
- "$$\n",
- "\\hat{y}_i = w x_i + b\n",
- "$$\n",
- "\n",
- "$\\hat{y}_i$ 是预测的结果,希望通过 $\\hat{y}_i$ 来拟合目标 $y_i$,通俗来讲就是找到这个函数拟合 $y_i$ 使得误差最小,即最小化\n",
- "\n",
- "$$\n",
- "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
- "$$"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "那么如何最小化这个误差呢?"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 2. 梯度下降法\n",
- "\n",
- "在梯度下降法中,首先要明确梯度的概念,梯度在数学上就是导数,如果是一个多元函数,那么梯度就是偏导数。比如一个函数$f(x, y)$,那么 $f$ 的梯度就是 \n",
- "\n",
- "$$\n",
- "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n",
- "$$\n",
- "\n",
- "可以称为 grad f(x, y) 或者 $\\nabla f(x, y)$。具体某一点 $(x_0,\\ y_0)$ 的梯度就是 $\\nabla f(x_0,\\ y_0)$。\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "梯度有什么意义呢?从几何意义来讲,一个点的梯度值是这个函数变化最快的地方。具体来说,对于函数 f(x, y),在点 $(x_0, y_0)$ 处,沿着梯度 $\\nabla f(x_0,\\ y_0)$ 的方向,函数增加最快,也就是说沿着梯度的方向,能够更快地找到函数的极大值点,或者反过来沿着梯度的反方向,能够更快地找到函数的最小值点。\n",
- "\n",
- "针对一元线性回归问题,就是沿着梯度的反方向,不断改变 $w$ 和 $b$ 的值,最终找到一组最好的 $w$ 和 $b$ 使得误差最小。\n",
- "\n",
- "在更新的时候,需要决定每次更新的幅度就是每次往下走的那一步的长度,这个长度称为学习率,用 $\\eta$ 表示。不同的学习率都会导致不同的结果,学习率太小会导致下降非常缓慢;学习率太大又会导致跳动非常明显。\n",
- "\n",
- "最后我们的更新公式就是\n",
- "\n",
- "$$\n",
- "w = w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n",
- "b = b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n",
- "$$\n",
- "\n",
- "通过不断地迭代更新,最终我们能够找到一组最优的 $w$ 和 $b$。"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 3. PyTorch实现\n",
- "\n",
- "上面是原理部分,下面通过一个例子来进一步学习线性模型"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<torch._C.Generator at 0x7fb8b00453f0>"
- ]
- },
- "execution_count": 1,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import torch\n",
- "import numpy as np\n",
- "\n",
- "torch.manual_seed(2021)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[<matplotlib.lines.Line2D at 0x7fb7aef5dc70>]"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAORUlEQVR4nO3db4hs913H8c9ncwl2Y2tC7lo08e42Qmu1GIxjrakU2/gvqTQIeRCcIAZhEbQWn9jqQhHkgoIPoogtQ7BFsjZiTAVF04qlKtSkzLZp/jQqaXp3m0TN3ihVsg/iTb4+OLPZzWTuzpmbOb/znTPvFyyzc+bc4bu/2fu5v3vO748jQgCAvFbaLgAAcDKCGgCSI6gBIDmCGgCSI6gBILlTTbzp6dOnY2Njo4m3BoBO2tnZOR8Ra5NeaySoNzY2NBwOm3hrAOgk27sXe41LHwCQHEENAMkR1ACQHEENAMkR1ACQHEENJLG9LW1sSCsr1eP2dtsVIYtGhucBmM32trS5KR0cVM93d6vnktTvt1cXcqBHDSSwtXUU0ocODqrjAEENJLC3N9txLBeCGkjgzJnZjmO5ENRAAmfPSqurrz62ulodBwhqIIF+XxoMpPV1ya4eBwNuJKLCqA8giX6fYMZk9KgBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGugglkwtq+n2ZsIL0DEsmVpWifZ2RMznnY7p9XoxHA7n/r4AptvYqMJi3Pq6dO5c6Wq6b17tbXsnInqTXuPSB9AxLJlaVon2JqiBjmHJ1LJKtDdBDXQMS6aWVaK9CWqgY1gytawS7c3NRABIgJuJQAsYy4x5qRXUtn/N9uO2H7P9Kdvf0nRhwCI7HFu7uytFVI933CGdPk1gY3ZTg9r2NZJ+VVIvIt4h6TJJtzddGLDItraOJkAc9/zzVYAT1phF3UsfpyS9wfYpSauSnm2uJGDxnTSG9uCgCvJxXCrBxUwN6oh4RtLvSdqT9O+SvhkRnx0/z/am7aHt4f7+/vwrBRbItDG040E+6VIJPW8cqnPp4ypJt0p6i6TvlHSF7TvGz4uIQUT0IqK3trY2/0qBBTJpbO1x40E+6VLJxXreWD51Ln38uKSvR8R+RPyfpPsl3dhsWcBiOxxbe/XVr31t0mQIpn3jJHWCek/Su2yv2rakmyQ90WxZWCZdvTbb70vnz0v33DN9MgTTvnGSOteoH5J0n6QvSXp09GcGDdeFJbEM12b7/WoVtZdfrh4nzVhj2jdOwsxEtIolOY9sb1fXpPf2qp702bNM+14mJ81MJKjRqpWVqic9zq56oMCyYAo50uLaLDAdQY1WcW0WmI6gRqtYkhOYjs1t0bp+n2AGTkKPGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6ixkLq62UAX8NnMH0GNhbMImw0sa1gtwmeziFiPGgsn+2YDh2F1fLPa1dXlWGwq+2eTGRsHoFOybzawzGGV/bPJjI0D0CnZNxtY5h3Fs382i4qgxsLJvtnAModV9s9mURHUWDjZNxtY5rDK/tksKq5RAw1gR3HM6qRr1OzwAjSAXWswT1z6AIDkCGoASI6gBoDkCGoASI6gBoDkCGoASI6gBoDkCGoASI6gBoDkCGoASI6gBoDkCGoASC51UC/rvnOZ8ZkA5U1dPc/22yT92bFD10n6aETc1VRR0mv3nTvcJFNiVbK28JkA7ZhpPWrbl0l6RtIPR8SEXeEq81iPepn3ncuKzwRozjz3TLxJ0tdOCul5WeZ957LiMwHaMWtQ3y7pU5NesL1pe2h7uL+//7oLW+Z957LiMwHaUTuobV8u6QOS/nzS6xExiIheRPTW1tZed2HLvO9cVnwmQDtm6VHfLOlLEfGfTRVzHJtk5sNnArSj9s1E2/dK+kxEfGLauWxuCwCzed03E22vSvoJSffPszAAwHS1gjoiDiLi6oj4ZtMFAcCiKDUBbOqEFwDAa5WcAJZ6CjkAZLW1dRTShw4OquPzRlADwCUoOQGMoAaAS1ByAhhBDQCXoOQEMIIaAC5ByQlgjPoAgEvU75eZmUuPGkBndHVjC3rUADqhyxtb0KMG0AklxzWXRlAD6IQub2xBUAPohC5vbEFQA+iELm9sQVAD6IQub2zBqA8AnVFqXHNp9KgBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6iBGXR1dTbkxjhqoKYur86G3OhRAzV1eXU25EZQAzV1eXU25EZQAzV1eXU25EZQt4gbU4uly6uzITeCuiWHN6Z2d6WIoxtThHVeXV6dDbk5Iub+pr1eL4bD4dzft0s2NqpwHre+Lp07V7oaAG2zvRMRvUmv0aNuCTemANRFULeEG1MA6iKoW8KNKQB1EdQt4cYUgLqYQt6irm4bBGC+6FEDQHIENQAkVyuobV9p+z7b/2L7Cds/0nRhAIBK3WvUvy/pgYi4zfblklan/QEAwHxMDWrbb5L0Hkm/IEkR8aKkF5stCwBwqM6lj+sk7Uv6hO0v277b9hXjJ9netD20Pdzf3597oQCwrOoE9SlJN0j6WET8gKQXJH1k/KSIGERELyJ6a2trcy4TAJZXnaB+WtLTEfHQ6Pl9qoIbAFDA1KCOiP+Q9A3bbxsduknSVxutCgDwirqjPj4oaXs04uMpSXc2VxIA4LhaQR0RD0uauE4qAKBZzEwEgOQIagBIjqAGgOQIagBIjqAGxmxvV5sPr6xUj+wMj7axcQBwzPa2tLkpHRxUz3d3q+cSmzygPfSogWO2to5C+tDBQXUcaAtBDRyztzfbcaAEgho45syZ2Y4DJRDUwDFnz0qrY9tirK5Wx4G2ENQdwmiF16/flwYDaX1dsqvHwYAbiWgXoz46gtEK89Pv02bIhR51RzBaAegugrojGK0AdBdB3RGMVgC6i6DuCEYrAN1FUHcEoxWA7mLUR4cwWgHoJnrUAJAcQQ0AyRHUAJAcQQ0AyRHUAJAcQQ0AyRHUAJAcQQ0AyRHUAJAcQY1LwiYFQDlMIcfM2KQAKIseNWbGJgVAWQQ1ZsYmBUBZBDVmxiYFQFkENWbGJgVAWQQ1ZsYmBUBZjPrAJWGTAqAcetQAkFytHrXtc5L+V9JLki5ERK/JogAAR2a59PHeiDjfWCUAgIm49AEAydUN6pD0Wds7tjcnnWB70/bQ9nB/f39+FQLAkqsb1O+OiBsk3Szpl22/Z/yEiBhERC8iemtra3MtEgCWWa2gjohnR4/PSfq0pHc2WRQA4MjUoLZ9he03Hn4v6SclPdZ0YQCASp1RH2+W9Gnbh+f/aUQ80GhVAIBXTA3qiHhK0vUFagEATMDwPABIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOTSBPX2trSxIa2sVI/b221XBAA5nGq7AKkK5c1N6eCger67Wz2XpH6/vboAIIMUPeqtraOQPnRwUB0HgGWXIqj39mY7DgDLJEVQnzkz23EAWCYpgvrsWWl19dXHVler4wCw7GoHte3LbH/Z9l/Pu4h+XxoMpPV1ya4eBwNuJAKANNuojw9JekLSm5oopN8nmAFgklo9atvXSnq/pLubLQcAMK7upY+7JP26pJcvdoLtTdtD28P9/f151AYAUI2gtv0zkp6LiJ2TzouIQUT0IqK3trY2twIBYNnV6VG/W9IHbJ+TdK+k99m+p9GqAACvmBrUEfEbEXFtRGxIul3S5yLijsYrAwBIamitj52dnfO2d2f4I6clnW+ilgVDO1RoB9rg0DK1w/rFXnBElCxkchH2MCJ6bdfRNtqhQjvQBodoh0qKmYkAgIsjqAEguSxBPWi7gCRohwrtQBscoh2U5Bo1AODisvSoAQAXQVADQHLFgtr2T9v+V9tP2v7IhNdt+w9Grz9i+4ZStZVUox36o5//EdtfsH19G3U2bVo7HDvvh2y/ZPu2kvWVUqcdbP+Y7YdtP277H0rXWEKNvxffZvuvbH9l1A53tlFnayKi8S9Jl0n6mqTrJF0u6SuSvnfsnFsk/a0kS3qXpIdK1Fbyq2Y73CjpqtH3Ny9rOxw773OS/kbSbW3X3dLvw5WSvirpzOj5t7ddd0vt8JuSfnf0/Zqk/5J0edu1l/oq1aN+p6QnI+KpiHhR1Zoht46dc6ukP4nKg5KutP0dheorZWo7RMQXIuK/R08flHRt4RpLqPP7IEkflPQXkp4rWVxBddrh5yTdHxF7khQRXWyLOu0Qkt5o25K+VVVQXyhbZntKBfU1kr5x7PnTo2OznrPoZv0Zf1HV/zK6Zmo72L5G0s9K+njBukqr8/vwVklX2f687R3bP1+sunLqtMMfSnq7pGclPSrpQxFx0WWXu6aRtT4m8IRj4+MC65yz6Gr/jLbfqyqof7TRitpRpx3ukvThiHip6kR1Up12OCXpByXdJOkNkv7Z9oMR8W9NF1dQnXb4KUkPS3qfpO+W9He2/yki/qfh2lIoFdRPS/quY8+vVfUv46znLLpaP6Pt71e1m87NEfF8odpKqtMOPUn3jkL6tKRbbF+IiL8sUmEZdf9enI+IFyS9YPsfJV0vqUtBXacd7pT0O1FdpH7S9tclfY+kL5YpsWWFbhackvSUpLfo6GbB942d8369+mbiF9u+gN9SO5yR9KSkG9uut812GDv/k+rmzcQ6vw9vl/T3o3NXJT0m6R1t195CO3xM0m+Nvn+zpGcknW679lJfRXrUEXHB9q9I+oyqO7x/HBGP2/6l0esfV3Vn/xZVIXWg6l/QTqnZDh+VdLWkPxr1Ji9Ex1YPq9kOnVenHSLiCdsPSHpE1VZ4d0fEY+1VPX81fx9+W9InbT+qqjP34YhYluVPmUIOANkxMxEAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkvt/nElIdlbTfhoAAAAASUVORK5CYII=\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# 生层测试数据\n",
- "x_train = np.random.rand(20, 1)\n",
- "y_train = x_train * 3 + 4 + 3*np.random.rand(20,1)\n",
- "\n",
- "# 画出图像\n",
- "import matplotlib.pyplot as plt\n",
- "%matplotlib inline\n",
- "\n",
- "plt.plot(x_train, y_train, 'bo')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# 转换成 Tensor\n",
- "x_train = torch.from_numpy(x_train)\n",
- "y_train = torch.from_numpy(y_train)\n",
- "\n",
- "# 定义参数 w 和 b\n",
- "w = torch.randn(1, requires_grad=True) # 随机初始化\n",
- "b = torch.zeros(1, requires_grad=True) # 使用 0 进行初始化"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# 构建线性回归模型\n",
- "def linear_model(x):\n",
- " return x * w + b\n",
- "\n",
- "def logistc_regression(x):\n",
- " return torch.sigmoid(x*w+b) "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "y_ = linear_model(x_train)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "经过上面的步骤我们就定义好了模型,在进行参数更新之前,我们可以先看看模型的输出结果长什么样"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7aee58d00>"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWs0lEQVR4nO3df2xd5X3H8c/XiSE4BYoSDwGp7XQaawKJl8RZQ4vISiAJBK0g+KNgBmmpIggwNA3RsqgrUpp1SFPa0hWoxTI04g41oaXalBVWfpSt5ZdDHVoSSFiwgwMTjkEZxESJ4+/+OL754V7b59r3nPPcc98vybq5517uffzc5MNzn5/m7gIAhKsm6wIAAEZHUANA4AhqAAgcQQ0AgSOoASBwk5N40enTp3tTU1MSLw0AubR169Z97l5f7LFEgrqpqUkdHR1JvDQA5JKZdY/0GF0fABA4ghoAAkdQA0DgEumjLubw4cPq6enRwYMH03rL3JsyZYpmzJih2trarIsCIEGpBXVPT49OPfVUNTU1yczSetvccnf19fWpp6dHM2fOzLo4ABKUWtfHwYMHNW3aNEK6TMxM06ZN4xtKjrS3S01NUk1NdNvennWJEIrUWtSSCOkyoz7zo71dWrVK6u+P7nd3R/clqbU1u3IhDAwmAgFYs+ZYSBf090fXAYI6pqamJu3bty/rYiCn9uwp7TqqS7BBnWR/nbtrcHCwfC8ITFBDQ2nXUV2CDOpCf113t+R+rL9uImHd1dWlWbNmafXq1Zo/f77Wrl2rhQsXau7cufrmN7959HlXXnmlFixYoPPOO09tbW1l+G2Asa1bJ9XVnXitri66DgQZ1En1173xxhu64YYbdO+992rv3r166aWX1NnZqa1bt+q5556TJG3YsEFbt25VR0eH7rvvPvX19U3sTYEYWlultjapsVEyi27b2hhIRCTVWR9xJdVf19jYqEWLFunOO+/Uk08+qXnz5kmSPvroI+3atUsXXXSR7rvvPv30pz+VJL399tvatWuXpk2bNrE3BmJobSWYUVyQLeqk+uumTp0qKeqjvvvuu9XZ2anOzk69+eabuummm/Tss8/qF7/4hZ5//nlt27ZN8+bNY54yKhJzstOVdH0HGdRJ99ctW7ZMGzZs0EcffSRJ2rt3r9577z3t379fZ5xxhurq6vT666/rhRdeKM8bAilKYowHI0ujvoMM6qT765YuXarrrrtOF1xwgebMmaNrrrlGH374oZYvX66BgQHNnTtX3/jGN7Ro0aLyvCGQIuZkpyuN+jZ3L9+rDWlpafHhBwfs2LFDs2bNKvt7VTvqFcPV1EQtu+HMJGalll+56tvMtrp7S9H3iPkCf2Vmr5nZ78zsX81sSvy3B6pTVv3EzMlOVxr1PWZQm9k5kv5SUou7ny9pkqQvla8IQP4U67e8/npp+vTkA5s52elKo77j9lFPlnSKmU2WVCfpnfIVAcifYv2WktTXN/JAU7la4MzJTlca9T1mULv7Xkn/IGmPpHcl7Xf3J8tXBFS7PE4lG23Of7GBpnLPHGhtlbq6oj7Sri5COmlJ13ecro8zJH1R0kxJZ0uaambXF3neKjPrMLOO3t7e8pYSuZXXqWRj9U8OD3JmamA0cbo+LpH0lrv3uvthST+R9LnhT3L3NndvcfeW+vr6cpcTOZXXgCrWb3m84UHO7nkYTZyg3iNpkZnVWbRT/RJJO5ItVrYefvhhvfPOsW74r371q9q+ffuEX7erq0s/+tGPSv7vVq5cqc2bN0/4/UOU14Aq9FsW232g2EATMzUwmjh91C9K2izpFUm/Hfpvkt9WLsOOy+FB/dBDD2n27NkTft3xBnWe5TmgWlulffukjRvHHmhipgZG5e5l/1mwYIEPt3379t+7NqKNG93r6tyjbsvop64uuj4BjzzyiC9cuNCbm5t91apVPjAw4DfeeKOfd955fv755/v69et906ZNPnXqVD/33HO9ubnZ+/v7ffHixf7yyy+7u/vUqVP9rrvu8vnz5/uSJUv8xRdf9MWLF/vMmTP9Zz/7mbu7v/XWW37hhRf6vHnzfN68ef6rX/3K3d0/+9nP+mmnnebNzc2+fv16HxgY8DvvvNNbWlp8zpw5/uCDD7q7++DgoN96660+a9Ysv/zyy/2yyy7zTZs2Ff2dSqrXACX0UVekjRvdGxvdzaLbaqyDaiapw0fI1DCDurHxxH+5hZ/GxtJ+82Hvf8UVV/ihQ4fc3f2WW27xe+65xy+55JKjz/nggw/c3U8I5uH3JfmWLVvc3f3KK6/0Sy+91A8dOuSdnZ3e3Nzs7u4HDhzwjz/+2N3dd+7c6YX6eOaZZ3zFihVHX/eHP/yhr1271t3dDx486AsWLPDdu3f7Y4895pdccokPDAz43r17/fTTT89tULsTUID76EEd5DanSXRcPvXUU9q6dasWLlwoSfr444+1fPly7d69W7fffrtWrFihpUuXjvk6J510kpYvXy5JmjNnjk4++WTV1tZqzpw56urqkiQdPnxYt912mzo7OzVp0iTt3Lmz6Gs9+eSTevXVV4/2P+/fv1+7du3Sc889p2uvvVaTJk3S2WefrYsvvnjcv3clYHtPYHRhBnVDQzRPq9j1cXJ33Xjjjfr2t799wvV169bpiSee0A9+8AP9+Mc/1oYNG0Z9ndra2qOnf9fU1Ojkk08++ueBgQFJ0ne+8x2deeaZ2rZtmwYHBzVlSvEV9+6u73//+1q2bNkJ17ds2cIJ4wCOCnL3vCRGVpYsWaLNmzfrvffekyS9//776u7u1uDgoK6++mqtXbtWr7zyiiTp1FNP1Ycffjju99q/f7/OOuss1dTU6JFHHtGRI0eKvu6yZcv0wAMP6PDhw5KknTt36sCBA7rooov06KOP6siRI3r33Xf1zDPPjLsseZXHRTJ5wWdTfmG2qAvfg9esibo7GhqikJ7A9+PZs2frW9/6lpYuXarBwUHV1tZq/fr1uuqqq44edFtoba9cuVI333yzTjnlFD3//PMlv9fq1at19dVXa9OmTfrCF75w9MCCuXPnavLkyWpubtbKlSt1xx13qKurS/Pnz5e7q76+Xo8//riuuuoqPf3005ozZ47OPfdcLV68eNy/dx4VFskU5l8XFslI4XShtLeX9a9vxaiEz6YSsc1phavGem1qKt4z1tgYLd/N2vCwkqIvhNWw30bon03IJrzNKRCS0BfJ5HW1ZRyhfzaViqBGxQl9kUw1h1Xon02lSjWok+hmqWbVWp+hr+Kr5rAK/bOpVKkF9ZQpU9TX11e14VJu7q6+vr4Rp/7lWej7LVdzWIX+2VSq1AYTDx8+rJ6eHh08eLDs71etpkyZohkzZqi2tjbromCYap31gfEbbTAxtaAGAIyMWR8AUMEIagAIXNBBzVLU8PCZAOkLcwm5WIoaIj4TIBvBDiayFDU8fCZAcipyMLGaV3eFis8EyEawQV3Nq7tCxWcCZCPYoK7m1V2h4jMBTpTW4HqwQc1S1PDwmQDHFAbXu7ujQ10Lg+tJhHWwg4kAELJyD65X5GAiAIQszcF1ghpAbqS5ICvNwXWCGkAupNlnLKU7uE5QA8iFtI9AS3NwncFEALlQUxO1pIczkwYH0y9PqRhMBJB7eV6QRVADyIU8L8giqAHkQp4XZBHUQAnYjztsra3RYpPBweg2DyEtBbwfNRAa9uNGVmhRAzGlPf0LKCCoM8TX6MrCftzICkGdkbRXUWHi8jz9C2EjqDPC1+jKk+fpXwgbQZ0RvkZXnjxP/0LYmPWRkYaG4nvZ8jU6bK2tBDPSF6tFbWafNLPNZva6me0wswuSLlje8TUaQFxxuz6+J+nn7v4ZSc2SdiRXpOrA12gAcY25e56ZnSZpm6RPe8yt9tg9DwBKM9Hd8z4tqVfSP5vZb8zsITObWuRNVplZh5l19Pb2TrDIAICCOEE9WdJ8SQ+4+zxJByR9ffiT3L3N3VvcvaW+vr7MxQSA6hUnqHsk9bj7i0P3NysKbgBACsYManf/X0lvm9kfD11aIml7oqUCABwVd9bH7ZLazexVSX8i6e8SKxGQMfZgQWhiLXhx905JRUcjgTxhK1OEiCXkOUJLcOLYgwUhYgl5TtASLA/2YEGIaFHnBC3B8mArU4SIoM4JWoLlwR4sCBFBnRO0BMuDPVgQIoI6J2gJlk9eT7JG5SKoc4KWIJBfzPrIETa1B/KJFjUABI6gxriwuAZID10fKBmLa4B00aJGyVhcA6SLoEbJWFwDpIugRslYXAOki6BGyVhcA6SLoEbJWFwDpItZHxgXFtcA6aFFDQCBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIELJqjb26WmJqmmJrptb8+6RAAQhiCO4mpvl1atkvr7o/vd3dF9ieOeACCIFvWaNcdCuqC/P7oOANUudlCb2SQz+42Z/Xu5C7FnT2nXAaCalNKivkPSjiQK0dBQ2nUAqCaxgtrMZkhaIemhJAqxbp1UV3fitbq66DoAVLu4LervSrpL0uBITzCzVWbWYWYdvb29JRWitVVqa5MaGyWz6LatjYFEAJBiBLWZXSHpPXffOtrz3L3N3VvcvaW+vr7kgrS2Sl1d0uBgdEtIA0AkTov685L+3My6JD0q6WIz25hoqQAAR40Z1O5+t7vPcPcmSV+S9LS7X594yQAAkgKZRw0AGFlJKxPd/VlJzyZSEgBAUbSoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIEjqAEgcAQ1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIEjqAEgcAQ1AASOoAaAwBHUADBe7e1SU5NkJk2eHN02NUXXy2hyWV8NAKpFe7u0apXU3x/dP3Ikuu3ujq5LUmtrWd6KFjUAjMeaNcdCerj+/ujxMhkzqM3sU2b2jJntMLPXzOyOsr07AFSqPXsm9ngJ4rSoByT9tbvPkrRI0q1mNrtsJQCAStTQMLHHSzBmULv7u+7+ytCfP5S0Q9I5ZSsBAJRLYXCvpiaRQb0TrFsn1dUVf6yuLnq8TErqozazJknzJL1Y5LFVZtZhZh29vb1lKh4AxFQY3OvultyPDeolFdatrVJbm9TYGN2fNCm6bWyMrpdpIFGSzN3jPdHsE5J+KWmdu/9ktOe2tLR4R0dHGYoHADE1NUXhPFxjo9TVlXZpSmZmW929pdhjsVrUZlYr6TFJ7WOFNABkYqTBuzIO6mUlzqwPk/RPkna4+/rkiwQA4zDS4F0ZB/WyEqdF/XlJfyHpYjPrHPq5POFyAahUaQ7oHa/Y4F6ZB/WyMubKRHf/b0mWQlkAVLrhq/USWKU3osLrr1kTdXc0NEQhnfT7piD2YGIpGEwEqlSFD+hlacKDiQAQS44H9LJEUAMonxwP6GWJoAbyiAG9XCGogbxJe4Xe8Y5frWeWyCq9asRgIpA3DOhVJAYTgUqwevWxU0ImT47ujwcDerlDUANJi9NfvHq19MADx04JOXIkuj+esGZAL3cIaiAJ7e3S9OlR6/j668fuL25rK/46I10fDQN6uUNQA+XW3i595StSX1/xx4sd01RoSQ830vXRMKCXOwQ1UIo43Rhr1kiHDo3+OsP7iwt7GQ830vWxtLZGA4eDg9EtIV3RCGpguJHCOO60tziDdsP7iwv7YQw30nVUFYIaON5oYVzs1Oli3RhjDdoV6y++/37plluOtaAnTYru33//xH4f5ALzqIHjjTYHec+eKLyHM4u6GAoKfdTFuj+mTZO+9z26IvB7mEeN/ElqifRoc5DjTntrbZU2bIhCuWDaNGnjRmnfPkIaJSOoUXmSXCI9WhiXMu2ttTUKZffoh4DGBBDUyF6preO4fcXjMVoYM+0NGaGPGtkafiKIFAXjaAFYUxOvr3giZcrhKSEI22h91AQ1sjWeDYTYdAg5xGAiwjWeDYRYIo0qQ1CjNIX+5MIOb2YTm3Uxng2E6CtGlSGoEd/xsy2kY/tQTGTWxXhbxyyRRhUhqBFfsdkWBeOddUHrGBgTg4mIb6TZFgXlmnUBVCEGE6tF0geajrWHBRvTA4kgqPMijQNNi/UnFzDrAkgMQZ0XSa7WKzi+P1k6ttMb/cpAogjqLJWzqyKtA00Lsy3cpYGB6JZZF0CiCOqslLurggNNgdwiqLNS7q4KVusBuUVQZ6XcXRXMRwZya3LWBahaDQ3FNxaaSFdFayvBDOQQLeqs0FUBICaCOit0VQCIia6PLNFVASAGWtQAELiwgzrpvSsAoALECmozW25mb5jZm2b29URKMjyUV69Ofu8KAKgAY25zamaTJO2UdKmkHkkvS7rW3beP9N+UvM1psQNOzYpvqcm5eAByaKLbnP6ppDfdfbe7H5L0qKQvlrOARVfpjfQ/kHLvXQEAgYsT1OdIevu4+z1D105gZqvMrMPMOnp7e0srRSnhy94VAKpMnKC2Itd+r7nr7m3u3uLuLfX19aWVYqTwtWFvzYIQAFUoTlD3SPrUcfdnSHqnrKUYaZXezTezIARA1Yuz4OVlSX9kZjMl7ZX0JUnXlbUUhfBdsybqBmloiMKbUAaAsYPa3QfM7DZJT0iaJGmDu79W9pKwSg8Aioq1hNzdt0jaknBZAABFhL0yEQBAUANA6AhqAAgcQQ0AgRtzr49xvahZr6Qi50yNaLqkfWUvSOWhHiLUA3VQUE310OjuRVcLJhLUpTKzjpE2I6km1EOEeqAOCqiHCF0fABA4ghoAAhdKULdlXYBAUA8R6oE6KKAeFEgfNQBgZKG0qAEAIyCoASBwqQX1WAfkWuS+ocdfNbP5aZUtTTHqoXXo93/VzH5tZs1ZlDNpcQ9MNrOFZnbEzK5Js3xpiVMPZvZnZtZpZq+Z2S/TLmMaYvy7ON3M/s3Mtg3Vw5ezKGdm3D3xH0Xbo/6PpE9LOknSNkmzhz3nckn/oehEmUWSXkyjbGn+xKyHz0k6Y+jPl1VrPRz3vKcV7dx4TdblzujvwyclbZfUMHT/D7Iud0b18DeS7h36c72k9yWdlHXZ0/pJq0Ud54DcL0r6F4+8IOmTZnZWSuVLy5j14O6/dvcPhu6+oOhEnbyJe2Dy7ZIek/RemoVLUZx6uE7ST9x9jyS5ex7rIk49uKRTzcwkfUJRUA+kW8zspBXUcQ7IjXWIboUr9Xe8SdG3jLwZsx7M7BxJV0l6MMVypS3O34dzJZ1hZs+a2VYzuyG10qUnTj38o6RZio4B/K2kO9x9MJ3iZS/WwQFlEOeA3FiH6Fa42L+jmX1BUVBfmGiJshGnHr4r6WvufsSGH3KcH3HqYbKkBZKWSDpF0vNm9oK770y6cCmKUw/LJHVKuljSH0r6TzP7L3f/v4TLFoS0gjrOAbnJH6KbvVi/o5nNlfSQpMvcvS+lsqUpTj20SHp0KKSnS7rczAbc/fFUSpiOuP8u9rn7AUkHzOw5Sc2S8hTUcerhy5L+3qNO6jfN7C1Jn5H0UjpFzFhKgwWTJe2WNFPHBgvOG/acFTpxMPGlrDvwM6qHBklvSvpc1uXNsh6GPf9h5XMwMc7fh1mSnhp6bp2k30k6P+uyZ1APD0i6Z+jPZyo6aHt61mVP6yeVFrWPcECumd089PiDikb2L1cUUv2K/g+aKzHr4W8lTZN0/1BrcsBztntYzHrIvTj14O47zOznkl6VNCjpIXf/XXalLr+Yfx/WSnrYzH6rqDH3NXevlu1PWUIOAKFjZSIABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIH7f906gbeq7S1IAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
- "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "这个时候需要计算我们的误差函数,也就是\n",
- "\n",
- "$$\n",
- "E = \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
- "$$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(704.5194, dtype=torch.float64, grad_fn=<SumBackward0>)\n"
- ]
- }
- ],
- "source": [
- "# 计算误差\n",
- "def get_loss(y_, y):\n",
- " return torch.sum((y_ - y) ** 2)\n",
- "\n",
- "loss = get_loss(y_, y_train)\n",
- "print(loss)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "定义好了误差函数,接下来我们需要计算 $w$ 和 $b$ 的梯度了,这时得益于 PyTorch 的自动求导,不需要手动去算梯度就可以得到计算好的梯度值。手动计算的$w$ 和 $b$ 的梯度分别是\n",
- "\n",
- "$$\n",
- "\\frac{\\partial}{\\partial w} = \\frac{2}{n} \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n",
- "\\frac{\\partial}{\\partial b} = \\frac{2}{n} \\sum_{i=1}^n (w x_i + b - y_i)\n",
- "$$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# 自动求导\n",
- "loss.backward()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([-117.3280])\n",
- "tensor([-234.3059])\n"
- ]
- }
- ],
- "source": [
- "# 查看 w 和 b 的梯度\n",
- "print(w.grad)\n",
- "print(b.grad)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# 更新一次参数\n",
- "w.data = w.data - 1e-2 * w.grad.data\n",
- "b.data = b.data - 1e-2 * b.grad.data"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "更新完成参数之后,我们再一次看看模型输出的结果"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac5d4a30>"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAW40lEQVR4nO3dfXBc1X3G8ednScYIHMLICkPiWoLOQG3wC7agTtKAAcc4mElgyB+hIsRJGUM8ULcdUkg9U5hxmJSZTFxIw4uGcUkjJ0wxCZOmNHXCS00HgyMRmRAbbGJkkKG1LIiDjV3b0q9/XEm2lZX2rrX33rN3v5+ZndXuXu2ePd59dHzueTF3FwAgXBOyLgAAYGwENQAEjqAGgMAR1AAQOIIaAAJXm8STTpkyxZubm5N4agDIpc7Ozj3u3ljosUSCurm5WR0dHUk8NQDkkpntHO0xuj4AIHAENQAEjqAGgMAl0kddyOHDh9XT06ODBw+m9ZK5N2nSJE2dOlV1dXVZFwVAglIL6p6eHk2ePFnNzc0ys7ReNrfcXX19ferp6dFZZ52VdXEAJCi1ro+DBw+qoaGBkC4TM1NDQwP/Q8mRtWul5mZpwoToeu3arEuEUKTWopZESJcZ9Zkfa9dKy5ZJH3wQ3d65M7otSa2t2ZULYeBkIhCAlSuPhvSQDz6I7gcI6piam5u1Z8+erIuBnHrzzdLuR3UJNqiT7K9zdw0MDJTvCYFxmjattPtRXYIM6qH+up07Jfej/XXjCevu7m5Nnz5dy5cv19y5c7Vq1SpdeOGFmjVrlu68887h466++mrNmzdP5513ntra2srwboDi7r5bqq8//r76+uh+IMigTqq/7rXXXtMNN9yge+65R7t27dKmTZvU1dWlzs5ObdiwQZK0Zs0adXZ2qqOjQ/fdd5/6+vrG96JADK2tUlub1NQkmUXXbW2cSEQk1VEfcSXVX9fU1KT58+frtttu0/r163XBBRdIkvbt26ft27fr4osv1n333acf//jHkqS33npL27dvV0NDw/heGIihtZVgRmFBBvW0aVF3R6H7x+OUU06RFPVRf/3rX9dNN9103OPPPvusfvGLX2jjxo2qr6/XggULGKcMIHNBdn0k3V93xRVXaM2aNdq3b58kadeuXdq9e7f27t2r008/XfX19Xr11Vf1wgsvlOcFAWAcgmxRD/33b+XKqLtj2rQopMv138JFixZp69at+vjHPy5JOvXUU9Xe3q7FixfrwQcf1KxZs3Tuuedq/vz55XlBABgPdy/7Zd68eT7Sli1b/uA+jB/1ikLa292bmtzNouv29qxLlG/lqG9JHT5KpgbZogZw4piOnq406jvIPmoAJ47p6OlKo74JaiBnmI6erjTqm6AGcobp6OlKo74JaiAhWa0vzXT0dKVR37GC2sz+2sx+Y2avmNkPzWxS+YoA5E+h9Wquv16aMiX5wGY6errSqO+iQW1mH5P0l5Ja3P18STWSvlC+IoTnkUce0dtvvz18+8Ybb9SWLVvG/bzd3d36wQ9+UPLvLV26VOvWrRv36yM9hU4wSVJf3/gXGIujtVXq7pYGBqJrQjpZSdd33K6PWkknm1mtpHpJbxc5fvwy3JdoZFA//PDDmjFjxrif90SDGpVnrBNJo40IYCsujKZoULv7LknfkvSmpHck7XX39SOPM7NlZtZhZh29vb3jK1US65xKam9v10UXXaQ5c+bopptuUn9/v5YuXarzzz9fM2fO1OrVq7Vu3Tp1dHSotbVVc+bM0YEDB7RgwQJ1dHRIimYx3n777Zo3b54WLlyoTZs2acGCBTr77LP1k5/8RFIUyJ/61Kc0d+5czZ07V88//7wk6Y477tBzzz2nOXPmaPXq1erv79fXvva14eVWH3roIUnRJKRbbrlFM2bM0JIlS7R79+5xvW+kr9iJpJFBntBHHnkx2kyYoYuk0yU9LalRUp2kJyRdP9bvjHtmYlOTe/R5Pf7S1FTqZJ/jXv+qq67yQ4cOubv7V7/6Vb/rrrt84cKFw8e899577u5+ySWX+C9/+cvh+4+9LcmffPJJd3e/+uqr/dOf/rQfOnTIu7q6fPbs2e7uvn//fj9w4IC7u2/bts2H6uOZZ57xJUuWDD/vQw895KtWrXJ394MHD/q8efN8x44d/vjjj/vChQv9yJEjvmvXLj/ttNP8scceG/V9Vbo8zqJrb3evry/8MS70UU7gI48Ko3HOTFwo6Q1375UkM/uRpE9Iai//n41BCQxMfOqpp9TZ2akLL7xQknTgwAEtXrxYO3bs0K233qolS5Zo0aJFRZ9n4sSJWrx4sSRp5syZOumkk1RXV6eZM2equ7tbknT48GHdcsst6urqUk1NjbZt21bwudavX6+XX355uP9579692r59uzZs2KDrrrtONTU1+uhHP6rLLrvshN936PI6i26o7CtWRP3Sxyo0IoCxzxhLnD7qNyXNN7N6i7a9vlzS1kRLlcDARHfXl770JXV1damrq0uvvfaa7r33Xm3evFkLFizQd7/7Xd14441Fn6eurm549+8JEybopJNOGv75yJEjkqTVq1frjDPO0ObNm9XR0aFDhw6NWqbvfOc7w2V64403hv9YVMsO43meRdfaKu3ZI7W3Fx8RwNhnjCVOH/WLktZJeknSrwd/J9k9qhIYmHj55Zdr3bp1w/297777rnbu3KmBgQFde+21WrVqlV566SVJ0uTJk/X++++f8Gvt3btXZ555piZMmKDvf//76u/vL/i8V1xxhR544AEdPnxYkrRt2zbt379fF198sR599FH19/frnXfe0TPPPHPCZQldNbQk44wIYOwzxhJrUSZ3v1PSnUUPLJcE1jmdMWOGvvGNb2jRokUaGBhQXV2dvv3tb+uaa64Z3uj2m9/8pqRoONzNN9+sk08+WRs3biz5tZYvX65rr71Wjz32mC699NLhDQtmzZql2tpazZ49W0uXLtWKFSvU3d2tuXPnyt3V2NioJ554Qtdcc42efvppzZw5U+ecc44uueSSE37foUtqk4hKk/TSvqhsFvVhl1dLS4sPjZIYsnXrVk2fPr3sr1XtKr1eR/ZRS1FLkgkaqDZm1unuLYUeYwo5MsUsOqA41qNG5tjUFRhbqi3qJLpZqhn1CVSH1IJ60qRJ6uvrI1zKxN3V19enSZNYHwvIu9S6PqZOnaqenh6Ne3o5hk2aNElTp07NuhgAEpZaUNfV1emss85K6+UAIDcY9QEAgSOoUZFCXxI09PIlqZrfe1IYnoeKE/pCTqGXL0nV/N6TlNrMRKBcmpsLTztvaorW0sha6OVLUjW/9/FiZiJyJfSFnEIvX5Kq+b0niaBGxQl9SdDQy5ekan7vSSKoUXFCXxI09PIlqZrfe5IIalSc0BdyCr18Sarm954kTiYCQAA4mQgAFYygBoDAEdQAEDiCGgACR1ADQOAIagAIHEGNkrAyGpC+okFtZueaWdcxl9+b2V+lUDZCITBDK6Pt3Cm5H10ZjX8XIFklTXgxsxpJuyT9qbsXWCMrUo4JLyOXS5SiqajMcsoOK6MBySnnhJfLJf12rJAul5Urjw9pKbq9cmXSr4zRsDIakI1Sg/oLkn5Y6AEzW2ZmHWbWUY4NbAmF8LAyGpCN2EFtZhMlfVbSY4Ued/c2d29x95bGxsZxF4xQCA8rowHZKKVF/RlJL7n7/yZVmGMRCuFhZTQgG6UE9XUapdsjCYRCmFpboxOHAwPRNf8eqGZpjUyLNerDzOolvSXpbHffW+x4ljkFkHflHpk27lEf7v6BuzfECWkAqAZpjkxjZiIAnIA0R6YR1ABwAtIcmUZQA8AJSHNkGkENIDfSXB8ozZFpteV/SgBI38hRGEOLhknJDSNtbU1niCotagC5kOf1gQhqALmQ5/WBCGoAuZDn9YEIagC5kOf1gQhqALmQ5/WBGPUBIDfSGoWRNlrUABA4ghooARsuIwt0fQAxZTGhApBoUQOx5XlCBcJGUAMx5XlCBcJGUGeI/s7KkucJFQgbQZ2Rof7OnTsl96P9nYR1uPI8oQJhI6gzQn9n5cnzhAqELdbmtqVic9viJkyIWtIjmUU7fAOoLuPe3BblR38ngLgI6ozQ3wkgLoI6I/R3Aogr1sxEM/uwpIclnS/JJX3F3TcmWK6qkNcFZACUV9wp5PdK+pm7f97MJkqqL/YLAIDyKBrUZvYhSRdLWipJ7n5I0qFkiwUAGBKnj/psSb2S/tnMfmVmD5vZKSMPMrNlZtZhZh29vb1lLygAVKs4QV0raa6kB9z9Akn7Jd0x8iB3b3P3FndvaWxsLHMxAaB6xQnqHkk97v7i4O11ioIbAJCCokHt7v8j6S0zO3fwrsslbUm0VACAYXFHfdwqae3giI8dkr6cXJEAAMeKFdTu3iWp4Bx0AECymJkIAIEjqIER2NABoWFzW+AYbGCLENGiBo7Bhg4IEUENHIMNbBEigjpH6FsdPzZ0wAlJ+MtHUOcEm+WWBxs6oCRr10pTpkjXX5/ol4+gzgn6VsuDDR0Q21DrqK/vDx8r85ePzW1zgs1ygZQ1N0et59GU+OVjc9sqQN8qkLJiZ5jL+OUjqHOCvlUgZWMFcZm/fAR1TtC3CqSsUOtIkhoayv7lY2ZijrBZLpCioS/bypVRN8i0aVF4J/AlJKgB4ESl1Dqi6wMAAkdQA8iPnE7PpesDQD7keOlDWtQ4ITltuKCS5Xh6Li1qlCzHDRdUshwvfUiLGiXLccMFlSzH03MJapQsxw0XVLIcT88lqFGyHDdcUMlyPD2XoEbJctxwQaVrbZW6u6NV67q7cxHSUsyTiWbWLel9Sf2Sjoy2FB+qQ4ozZwGotFEfl7r7nsRKgorCuiJAeuj6AIDAxQ1ql7TezDrNbFmSBQJQ4ZgNVXZxuz4+6e5vm9lHJP3czF519w3HHjAY4MskaRqn/4HqxGyoRJS8Z6KZ3SVpn7t/a7Rj2DMRqFKj7SPY1BSNwsCoxrVnopmdYmaTh36WtEjSK+UtIoBcYDZUIuL0UZ8h6b/NbLOkTZL+3d1/lmyxAFQkZkMlomgftbvvkDQ7hbIAqHR33318H7XEbKgyYHgegPLJ8TTuLLHMKYDyYjZU2dGiBoDAEdQAEDiCGsgjZgfmCn3UQN4wOzB3aFEDecNeablDUAN5w+zA3CGogbxhdmDuENRA3rBXWu4Q1EDeMDswdwhqIBTLl0u1tVG41tZGt09UTjd5rVYMzwNCsHy59MADR2/39x+9ff/92ZQJwaBFDYSgra20+1FVCGogaXFmCfb3F/7d0e5HVaHrA0jC2rXSihVSX9/x9482S7CmpnAo19QkV0ZUDFrUQLmtXSt95St/GNJDCs0SHArvkUa7H1WFFjVQbitXSocOjX3MyFmCQycM29qilnVNTRTSnEiECGqg/OJM1S40S/D++wlmFBRM1werMqIixPmgFpuqzSxBlCiIoB5alXHnTsn96PkWwhpBiftBvftuaeLEws/R0MAsQZQsiKBmVUYEZbRWc9wPamurtGZNFMpDGhqk9nZpzx5CGiUzdy/7k7a0tHhHR0fs4ydMiBooI5lFM2CB1IxcdF+Kuira2qQvfpEPKhJjZp3u3lLosSBa1KzKiGCM1Wrmg4qMxA5qM6sxs1+Z2U/LXQhWZUQwxlp0nw8qMlJKi3qFpK1JFIJVGRGMsVrNfFCRkVh91GY2VdL3JN0t6W/c/aqxji+1jxoIxlh91AQyElSOPup/lPS3kkY9Y2Jmy8ysw8w6ent7Sy8lEAJazQhQ0aA2s6sk7Xb3zrGOc/c2d29x95bGxsayFRAoKMkZUiy6j8DEmUL+SUmfNbMrJU2S9CEza3f365MtGjCKkd0To61IB+RESeOozWyBpNvoo0ammpujcB6pqSlqAQMVKPhx1EBJxhpCB+RQSUHt7s8Wa00DiWPiCaoMLWpkr9QTg0w8QZUhqJGtE1k6kSF0qDJBLMqEKsaJQUASJxMRMk4MAkUR1MgWJwaBoghqZIsTg0BRBDWyxYlBoCiCGqUZGkpnJtXWRtfjXWuDtTWAMcVZ6wOIjFxjo78/umatDSBRtKgRX6FtqoawGzGQGIIa8RUbMseQOiARBDXiKzZkjiF1QCIIasRXaCjdEIbUAYkhqBHfsUPpJKmmJrpmSB2QKEZ9oDStrQQykDJa1HmS5D6CADJDizov2EcQyC1a1HlRaIwzY5uBXCCo84LlQoHcIqjzguVCgdwiqPOC5UKB3CKo84LlQoHcIqizVO7hdCwXCuRS0eF5ZjZJ0gZJJw0ev87d70y6YLnHcDoAMcVpUf+fpMvcfbakOZIWm9n8REtVDRhOByCmoi1qd3dJ+wZv1g1ePMlCVQWG0wGIKVYftZnVmFmXpN2Sfu7uLxY4ZpmZdZhZR29vb5mLmUMMpwMQU6ygdvd+d58jaaqki8zs/ALHtLl7i7u3NDY2lrmYOcRwOgAxlTTqw91/J+lZSYuTKExVYTgdgJjijPpolHTY3X9nZidLWijpnsRLVg1YMhRADHFWzztT0vfMrEZRC/xf3f2nyRYLADAkzqiPlyVdkEJZAAAFMDMRAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGgACF3ZQl3urKgCoQHHW+sgGW1UBgKSQWtQjW88rVrBVFQAolBZ1odbzaNiqCkCVCaNFXWij19GwVRWAKhNGUMdtJbNVFYAqFEZQj9ZKbmhgqyoAVS+MoB5to9d775W6u6WBgeiakAZQhcIIajZ6BYBRhTHqQ2KjVwAYRRgtagDAqAhqAAgcQQ0AgSOoASBwBDUABM7cvfxPatYraYwFO/7AFEl7yl6QykM9RKgH6mBINdVDk7s3FnogkaAulZl1uHtL1uXIGvUQoR6ogyHUQ4SuDwAIHEENAIELJajbsi5AIKiHCPVAHQyhHhRIHzUAYHShtKgBAKMgqAEgcKkFtZktNrPXzOx1M7ujwONmZvcNPv6ymc1Nq2xpilEPrYPv/2Uze97MZmdRzqQVq4djjrvQzPrN7PNpli8tcerBzBaYWZeZ/cbM/ivtMqYhxvfiNDP7NzPbPFgPX86inJlx98Qvkmok/VbS2ZImStosacaIY66U9B+STNJ8SS+mUbY0LzHr4ROSTh/8+TPVWg/HHPe0pCclfT7rcmf0efiwpC2Spg3e/kjW5c6oHv5O0j2DPzdKelfSxKzLntYlrRb1RZJed/cd7n5I0qOSPjfimM9J+hePvCDpw2Z2ZkrlS0vRenD35939vcGbL0iamnIZ0xDn8yBJt0p6XNLuNAuXojj18OeSfuTub0qSu+exLuLUg0uabGYm6VRFQX0k3WJmJ62g/pikt4653TN4X6nHVLpS3+NfKPpfRt4UrQcz+5ikayQ9mGK50hbn83COpNPN7Fkz6zSzG1IrXXri1MM/SZou6W1Jv5a0wt0H0ile9tLa4cUK3DdyXGCcYypd7PdoZpcqCuo/S7RE2YhTD/8o6XZ3748aUbkUpx5qJc2TdLmkkyVtNLMX3H1b0oVLUZx6uEJSl6TLJP2xpJ+b2XPu/vuEyxaEtIK6R9IfHXN7qqK/jKUeU+livUczmyXpYUmfcfe+lMqWpjj10CLp0cGQniLpSjM74u5PpFLCdMT9Xuxx9/2S9pvZBkmzJeUpqOPUw5cl/YNHndSvm9kbkv5E0qZ0ipixlE4W1EraIeksHT1ZcN6IY5bo+JOJm7LuwM+oHqZJel3SJ7Iub5b1MOL4R5TPk4lxPg/TJT01eGy9pFcknZ912TOohwck3TX48xmSdkmaknXZ07qk0qJ29yNmdouk/1R0hneNu//GzG4efPxBRWf2r1QUUh8o+guaKzHr4e8lNUi6f7A1ecRztnpYzHrIvTj14O5bzexnkl6WNCDpYXd/JbtSl1/Mz8MqSY+Y2a8VNeZud/dqWf6UKeQAEDpmJgJA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAELj/ByZw6bCS3nICAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "y_ = linear_model(x_train)\n",
- "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
- "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "从上面的例子可以看到,更新之后红色的线跑到了蓝色的线下面,没有特别好的拟合蓝色的真实值,所以我们需要在进行几次更新"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "epoch: 19, loss: 21.218688263809952\n",
- "epoch: 39, loss: 19.55484974487415\n",
- "epoch: 59, loss: 18.824963796393106\n",
- "epoch: 79, loss: 18.50477882805245\n",
- "epoch: 99, loss: 18.364321569910107\n"
- ]
- }
- ],
- "source": [
- "for e in range(100): # 进行 100 次更新\n",
- " y_ = linear_model(x_train)\n",
- " loss = get_loss(y_, y_train)\n",
- " \n",
- " w.grad.zero_() # 注意:归零梯度\n",
- " b.grad.zero_() # 注意:归零梯度\n",
- " loss.backward()\n",
- " \n",
- " w.data = w.data - 1e-2 * w.grad.data # 更新 w\n",
- " b.data = b.data - 1e-2 * b.grad.data # 更新 b \n",
- " if (e + 1) % 20 == 0:\n",
- " print('epoch: {}, loss: {}'.format(e, loss.item()))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac548be0>"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWwklEQVR4nO3df2zc9X3H8dfbTiA4SSlNPARkiWESjEB+EBuWtmsIkIbQoAKi0spMId2QgQzGJrUDFmllSilDmpqSrvxwEUWrTdEIBXUTa1n5MZAITW1qKCUlocEOCWxxDEohTpbEee+Pr+045i7+Xvy97/dz33s+JOt8d9/cfe5zl5c/9/l+fpi7CwAQrpqsCwAAODKCGgACR1ADQOAIagAIHEENAIGbUI4HnT59ujc0NJTjoQEglzo7O3e6e32h+8oS1A0NDero6CjHQwNALplZT7H76PoAgMAR1AAQOIIaAAJXlj7qQvbv369t27Zp7969aT1l7k2aNEkzZszQxIkTsy4KgDJKLai3bdumqVOnqqGhQWaW1tPmlrurr69P27Zt06mnnpp1cQCUUWpdH3v37tW0adMI6YSYmaZNm8Y3lBxpb5caGqSamuiyvT3rEiEUqbWoJRHSCaM+86O9XWppkfr7o+s9PdF1SWpuzq5cCAMnE4EArFp1KKSH9PdHtwMEdUwNDQ3auXNn1sVATm3dWtrtqC7BBnU5++vcXQcPHkzuAYFxmjmztNtRXYIM6qH+up4eyf1Qf914wrq7u1tnnnmmVq5cqQULFmj16tU699xzNXfuXH3jG98YPu7yyy9XY2OjzjrrLLW2tibwaoCx3XmnVFd3+G11ddHtQJBBXa7+ujfffFPXXHON7r77bm3fvl0bNmxQV1eXOjs79cILL0iSHnroIXV2dqqjo0Nr165VX1/f+J4UiKG5WWptlWbNksyiy9ZWTiQikuqoj7jK1V83a9YsLVy4UF/72tf09NNP65xzzpEkffTRR9q8ebMWLVqktWvX6oknnpAkvfPOO9q8ebOmTZs2vicGYmhuJphRWJBBPXNm1N1R6PbxmDx5sqSoj/r222/X9ddff9j9zz//vH7+859r/fr1qqur0+LFixmnDCBzQXZ9lLu/7uKLL9ZDDz2kjz76SJK0fft27dixQ7t27dIJJ5yguro6/fa3v9XLL7+czBMCwDgE2aIe+vq3alXU3TFzZhTSSX0tXLp0qTZu3KhPf/rTkqQpU6aora1Ny5Yt0/3336+5c+fqjDPO0MKFC5N5QgAYB3P3xB+0qanJR28csHHjRp155pmJP1e1o16BfDCzTndvKnRfkF0fAIBDCGoACBxBDQCBI6iBHGLJ1HSVu76DHPUB4OixZGq60qhvWtRAzrBkaora23X+tQ36sN+0XxM0INPbatBl/e2J1jdBXcDDDz+sd999d/j6ddddpzfeeGPcj9vd3a1HHnmk5H+3YsUKrVu3btzPj+rAkqkpGWxKzxjoUY2kCRpQjaQG9ej7atFne5Lr/wg3qDPsZBsd1A8++KBmz5497sc92qAGSsGSqSkp9NVl0GT16+7a5JrUYQZ1OdY5ldTW1qbzzjtP8+fP1/XXX6+BgQGtWLFCZ599tubMmaM1a9Zo3bp16ujoUHNzs+bPn689e/Zo8eLFGprAM2XKFN16661qbGzUkiVLtGHDBi1evFinnXaafvKTn0iKAvlzn/ucFixYoAULFuill16SJN1222168cUXNX/+fK1Zs0YDAwP6+te/Przc6gMPPCApWovkpptu0uzZs7V8+XLt2LFjXK8b1YUlU1MyxleUUwYS/Arj7on/NDY2+mhvvPHGx24ratYs9yiiD/+ZNSv+YxR4/ksvvdT37dvn7u433nij33HHHb5kyZLhYz744AN3dz///PP9l7/85fDtI69L8qeeesrd3S+//HL//Oc/7/v27fOuri6fN2+eu7vv3r3b9+zZ4+7umzZt8qH6eO6553z58uXDj/vAAw/46tWr3d1979693tjY6Fu2bPHHH3/clyxZ4gcOHPDt27f78ccf74899ljR1wWM1tYW/Xcxiy7b2rIuUQ4Vy6mjzCtJHV4kU8Mc9VGGTrZnnnlGnZ2dOvfccyVJe/bs0bJly7RlyxbdfPPNWr58uZYuXTrm4xxzzDFatmyZJGnOnDk69thjNXHiRM2ZM0fd3d2SpP379+umm25SV1eXamtrtWnTpoKP9fTTT+u1114b7n/etWuXNm/erBdeeEFXXXWVamtrdfLJJ+vCCy886teN6sSSqSm4887Dh3uMlPBXmDC7PsrQyebuuvbaa9XV1aWuri69+eabuueee/Tqq69q8eLF+t73vqfrrrtuzMeZOHHi8O7fNTU1OvbYY4d/P3DggCRpzZo1OvHEE/Xqq6+qo6ND+/btK1qm7373u8Nlevvtt4f/WLDDeOVjLHPOjdztQZJqa6PLMuz6ECuozexvzew3Zva6mf3IzCYlVoJCytDJdtFFF2ndunXD/b3vv/++enp6dPDgQV155ZVavXq1XnnlFUnS1KlT9eGHHx71c+3atUsnnXSSampq9MMf/lADAwMFH/fiiy/Wfffdp/3790uSNm3apN27d2vRokV69NFHNTAwoPfee0/PPffcUZcF2Sh0muXqq6Xp0wnsXGlulrq7ozf5wIHosrs78a8zY3Z9mNkpkv5a0mx332Nm/ybpy5IeTrQkI5VhndPZs2frm9/8ppYuXaqDBw9q4sSJ+va3v60rrrhieKPbu+66S1I0HO6GG27Qcccdp/Xr15f8XCtXrtSVV16pxx57TBdccMHwhgVz587VhAkTNG/ePK1YsUK33HKLuru7tWDBArm76uvr9eSTT+qKK67Qs88+qzlz5uj000/X+eeff9SvG9koNiCgr4/JJyjdmMucDgb1y5LmSfq9pCclrXX3p4v9G5Y5TQ/1GqaamqhxVcysWVHDa6T29vKtwY7wjWuZU3ffLumfJW2V9J6kXYVC2sxazKzDzDp6e3vHW2agoo11OmX0efEyjUhFTowZ1GZ2gqTLJJ0q6WRJk83s6tHHuXuruze5e1N9fX3yJQUqSKHTLCONDnKmfSckp2dw45xMXCLpbXfvdff9kn4s6TNH82RjdbOgNNRnuIYGBBTawL7QeXGmfScgx19L4gT1VkkLzazOojFjF0naWOoTTZo0SX19fYRLQtxdfX19mjSpvANw0pDTRpCam6WdO6W2tqhP2qz4yC2mfScgx19LYu2ZaGb/KOnPJB2Q9CtJ17n7/xU7vtDJxP3792vbtm3au3fv+EqMYZMmTdKMGTM0ceLErIty1EYvESlFLc6Eh6EGj3pIQLEzuGbS4MiukB3pZGJqm9sChTQ0RN9QRys0KiLvGPUxThX+YWJzWwSLvtlDhuZOHDxYljkT+Zfj1agIamSKvlkkZuSU7iOdEKhABDUyleNGELKQ068lBDUyleNGEJCYMJc5RVVhSU7gyGhRA0DgCGoAycrrDKYM0fUBIDmjZ+4MTeOW6N8aB1rUAJKT42ncWSKoASSHGUxlQVADSA4zmMqCoAaQHGYwlQVBDSA5zGAqC0Z9AEgWM5gSR4saFYmhuuHivUkeQY2KUwk7LlVrWFXCe1OJ2DgAFSf09eGrebeW0N+bkLHDC3Il9B2XqjmsQn9vQsYOL8iV0IfqBjHnI6O+l9Dfm0pFUKPihD5UN/OwyrCjOPT3plIR1Kg4oQ/VzTysMlxvI/T3plLRRw2UQaY7itNRXJGO1EfNhBegDDKd8zFzZuGzmXQUVyy6PoC8ybzvBUkjqIG8oaM4dwhqIBQrV0oTJkThOmFCdP1oNTdHg7YPHowuCemKRh81EIKVK6X77jt0fWDg0PV7782mTAgGLWogBK2tpd2OqkJQAyEYGCjtdlQVghool6Fp3EN9zmbFp3PX1hZ+jGK3o6oQ1EA5rFwpfeUrh8YzD7WMi03nbmkp/DjFbkdVIaiBpLW3S/ffX3h2oFR4Ove990o33nioBV1bG13nRCLEFHIgecXWOR2J6dwYhWVOgaTEWT40znqmTOdGCQhqIK64y4eOFcJM50aJgg7qat13LmRV/Z7EXT600FobQ5jOjaPh7kf8kXSGpK4RP7+X9DdH+jeNjY0+Xm1t7nV17lHTJfqpq4tuRzaq5j1pa3OfNcvdLLoceoFmh7/4oR+z+I8BFCGpw4tkakknE82sVtJ2SX/i7kXPliRxMrGa950LVVW8J0famXbVqiqoAGQlyZOJF0n63ZFCOilB7DuHw1TFe3Kk7g2WD0VGSg3qL0v6UaE7zKzFzDrMrKO3t3fcBct83zl8TFW8J0f6a8TyochI7KA2s2MkfVHSY4Xud/dWd29y96b6+vpxF4zGS3iq4j0Z668Ry4ciA6W0qC+R9Iq7/2+5CjMSjZfwVMV7UhV/jVBpYp9MNLNHJf3M3X8w1rHMTERFy3RnWlSrI51MjBXUZlYn6R1Jp7n7rrGOJ6gBoDTjHvXh7v3uPi1OSAOpqOqZNwhFWh9DtuJC5Rk91nloKrdEFwVSk+bHkNXzUHmqYuYNQpf0x5DV85AvVTHzBqFL82NIUKPyVMXMG4QuzY8hQY3slXpGhrHOCECaH0OCGtmKu8bzSFUx8wahS/NjyMlEZIsTg4AkTiYiZJwYRILyOryeoEa2ODGIhBxNL1qlIKiRLU4MIiFxd0qrRAQ1ssWJQSQkz71oBDWOTpKdgazxjATkuReNoEZp2tul6dOlq6/OZ2cgKlaee9EIasQ3dLamr+/j9+WlMxAVK8+9aIyjRnzFxjwPMYu6LwCUjHHUSMZYZ2Xy0BkIBIigRnxHCuK8dAYCASKoEV+hszWSNG1afjoDgQAR1Iiv0NmatjZp505CGigjtuJCaZqbCWUgZbSoASBwBHWe5HXpsIBQxcgCXR95wc7cZUcVIytMeMkLFuAvO6oY5cSEl2qQ56XDAkEVIysEdV7keemwQFDFyApBnaFET0zleemwQFDFyApBnZHEtw3K89JhgaCKkRVOJmaEE1MARuJkYoC2bpWuUrveVoMGVKO31aCr1M6JKQAfwzjqjNz0qXbd1deiyYoG5TaoR99Xi6Z/SpL4Lg3gEFrUGfmWVg2H9JDJ6te3xC4pAA5HUGdkyvuF+ziK3Q6gehHUWWFQLoCYCOqsMCgXQEwEdVYYlAsgplijPszsk5IelHS2JJf0F+6+vozlqg4swg8ghrjD8+6R9FN3/5KZHSOpwMZ5AIByGDOozewTkhZJWiFJ7r5P0r7yFgsAMCROH/Vpknol/cDMfmVmD5rZ5NEHmVmLmXWYWUdvb2/iBQWAahUnqCdIWiDpPnc/R9JuSbeNPsjdW929yd2b6uvrEy4mAFSvOEG9TdI2d//F4PV1ioIbAJCCMYPa3f9H0jtmdsbgTRdJeqOspQIADIs76uNmSe2DIz62SPpq+YoEABgpVlC7e5ekguukAgDKi5mJABA4ghoAAkdQA0DgCGoACFzYQd3eHu0CW1MTXR71Ft1AfHzsEJpw90xsb5daWqT+we2qenqi6xIrzqFs+NghROG0qEc3Y2655dD/liH9/dIq9hRE+axaxccO4QmjRV2oGVPMVvYURPkU+3jxsUOWwmhRF2rGFMOegigjtrJEiMII6rjNFfYURJmxlSVCFEZQF2uuTJvGnoIlYLTC+LGVJUJk7p74gzY1NXlHR0f8fzC6j1qKmjH8D4mNKgQqm5l1unvBNZXCaFHTjBk3RisA+RVGixrjVlMjFXorzaSDB9MvD4DShN+ixrgxWgHIL4I6JxitAOQXQZ0TdPMD+RXGzEQkormZYAbyiBY1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAEjqDGUWGTAiA9TCFHyQrtRdzSEv3OFHYgebSoUTI2KQDSRVCjZMX2Io67RzGA0hDUKBmbFADpIqhRMjYpANJFUKNkbFIApItRHzgqbFIApIcWNQAELlaL2sy6JX0oaUDSgWJbmgMAkldK18cF7r6zbCUBABRE1wcABC5uULukp82s08xaCh1gZi1m1mFmHb29vcmVEACqXNyg/qy7L5B0iaS/MrNFow9w91Z3b3L3pvr6+kQLCQDVLFZQu/u7g5c7JD0h6bxyFgoAcMiYQW1mk81s6tDvkpZKer3cBQMAROKM+jhR0hNmNnT8I+7+07KWCgAwbMygdvctkualUBYAQAEMzwOAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIEjqAEgcAQ1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBCyao29ulhgappia6bG/PukQAEIYJWRdAikK5pUXq74+u9/RE1yWpuTm7cgFACIJoUa9adSikh/T3R7cDQLULIqi3bi3tdgCoJkEE9cyZpd0OANUkiKC+806pru7w2+rqotsBoNrFDmozqzWzX5nZfyRdiOZmqbVVmjVLMosuW1s5kQgAUmmjPm6RtFHSJ8pRkOZmghkAConVojazGZKWS3qwvMUBAIwWt+vjO5L+TtLBYgeYWYuZdZhZR29vbxJlAwAoRlCb2aWSdrh755GOc/dWd29y96b6+vrECggA1S5Oi/qzkr5oZt2SHpV0oZm1lbVUAIBhYwa1u9/u7jPcvUHSlyU96+5Xl71kAABJZVrro7Ozc6eZ9ZTwT6ZL2lmOslQY6iFCPVAHQ6qpHmYVu8PcPc2CFC6EWYe7N2VdjqxRDxHqgToYQj1EgpiZCAAojqAGgMCFEtStWRcgENRDhHqgDoZQDwqkjxoAUFwoLWoAQBEENQAELrWgNrNlZvammb1lZrcVuN/MbO3g/a+Z2YK0ypamGPXQPPj6XzOzl8xsXhblLLex6mHEceea2YCZfSnN8qUlTj2Y2WIz6zKz35jZf6ddxjTE+H9xvJn9u5m9OlgPX82inJlx97L/SKqV9DtJp0k6RtKrkmaPOuYLkv5TkklaKOkXaZQtzZ+Y9fAZSScM/n5JtdbDiOOelfSUpC9lXe6MPg+flPSGpJmD1/8g63JnVA9/L+nuwd/rJb0v6Zisy57WT1ot6vMkveXuW9x9n6I1Qy4bdcxlkv7VIy9L+qSZnZRS+dIyZj24+0vu/sHg1ZclzUi5jGmI83mQpJslPS5pR5qFS1GcevhzST92962S5O55rIs49eCSppqZSZqiKKgPpFvM7KQV1KdIemfE9W2Dt5V6TKUr9TX+paJvGXkzZj2Y2SmSrpB0f4rlSlucz8Ppkk4ws+fNrNPMrkmtdOmJUw//IulMSe9K+rWkW9y96LLLeVOWtT4KsAK3jR4XGOeYShf7NZrZBYqC+k/LWqJsxKmH70i61d0HokZULsWphwmSGiVdJOk4SevN7GV331TuwqUoTj1cLKlL0oWS/kjSf5nZi+7++zKXLQhpBfU2SX844voMRX8ZSz2m0sV6jWY2V9FuOpe4e19KZUtTnHpokvToYEhPl/QFMzvg7k+mUsJ0xP1/sdPdd0vabWYvSJonKU9BHacevirpnzzqpH7LzN6W9MeSNqRTxIyldLJggqQtkk7VoZMFZ406ZrkOP5m4IesO/IzqYaaktyR9JuvyZlkPo45/WPk8mRjn83CmpGcGj62T9Lqks7Muewb1cJ+kOwZ/P1HSdknTsy57Wj+ptKjd/YCZ3STpZ4rO8D7k7r8xsxsG779f0Zn9LygKqX5Ff0FzJWY9/IOkaZLuHWxNHvCcrR4Wsx5yL049uPtGM/uppNcUbYX3oLu/nl2pkxfz87Ba0sNm9mtFjblb3b1alj9lCjkAhI6ZiQAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABO7/ARH9usRSKFMjAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "y_ = linear_model(x_train)\n",
- "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
- "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "经过 100 次更新,可以发现红色的预测结果已经比较好的拟合了蓝色的真实值。"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 4. 多项式回归模型"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "下面更进一步尝试一下多项式回归,下面是关于 x 的多项式:\n",
- "\n",
- "$$\n",
- "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 \n",
- "$$\n",
- "\n",
- "这样就能够拟合更加复杂的模型,这里使用了 $x$ 的更高次,同理还有多元回归模型,形式也是一样的,只是除了使用 $x$,还是更多的变量,比如 $y$、$z$ 等等,同时他们的 $loss$ 函数和简单的线性回归模型是一致的。"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "首先定义一个需要拟合的目标函数,这个函数是个三次的多项式"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n"
- ]
- }
- ],
- "source": [
- "# 定义一个多变量函数\n",
- "\n",
- "w_target = np.array([0.5, 3, 2.4]) # 定义参数\n",
- "b_target = np.array([0.9]) # 定义参数\n",
- "\n",
- "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n",
- " b_target[0], w_target[0], w_target[1], w_target[2]) # 打印出函数的式子\n",
- "\n",
- "print(f_des)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "多项式的的曲线绘制"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac4f25b0>"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlSklEQVR4nO3deXxU9b3/8dcnC4RAIIBhDRCQNaAIBBC1aIuite7WpValdYGqbbWbdemv1rbe2ta2t96rt8UFUajCddeqBa1KXZBdEcIma8KSsIYl68zn90cGbrRhSTLJmZm8n48Hj5k558z5fs4A73zznXO+x9wdERFJTElBFyAiIo1HIS8iksAU8iIiCUwhLyKSwBTyIiIJTCEvIpLAjjnkzexxMysys09rLOtgZrPNbHXksX2NdXea2RozW2lmZ0e7cBERObq69OSfAM75wrI7gLfcvR/wVuQ1ZpYLXAkMjrznYTNLbnC1IiJSJynHuqG7zzGznC8svhA4I/J8KvAO8NPI8mfcvRxYZ2ZrgFHAh0dq47jjjvOcnC82ISIiR7Jw4cLt7p5V27pjDvnD6OzuWwDcfYuZdYos7w7MrbFdQWTZEeXk5LBgwYIGliQi0ryY2YbDrWusL16tlmW1zp9gZhPNbIGZLSguLm6kckREmqeGhvw2M+sKEHksiiwvAHrU2C4b2FzbDtx9srvnuXteVlatv22IiEg9NTTkXwYmRJ5PAF6qsfxKM2tpZr2BfsC8BrYlIiJ1dMxj8mb2NNVfsh5nZgXAPcD9wEwzux7YCFwG4O7LzGwmsByoAm5x91B9CqysrKSgoICysrL6vF1qkZaWRnZ2NqmpqUGXIiKNzGJpquG8vDz/4hev69atIyMjg44dO2JW21C/1IW7s2PHDvbu3Uvv3r2DLkdEosDMFrp7Xm3rYv6K17KyMgV8FJkZHTt21G9GIs1EzIc8oICPMn2eIs1HXIS8iEgim/L+OmYv39Yo+1bIN4GcnBy2b98edBkiEoP2lFbyuzdWMnv51kbZv0K+DtydcDgcdBkxU4eINNyzCwsorQxx7ZicRtm/Qv4o1q9fz6BBg7j55psZPnw4mzZt4ve//z0jR47kxBNP5J577jm07UUXXcSIESMYPHgwkydPPuq+33jjDYYPH87QoUMZN24cAL/4xS944IEHDm0zZMgQ1q9f/291/OpXv+L2228/tN0TTzzB9773PQCmTZvGqFGjOOmkk5g0aRKhUL3OXhWRRhYOO099uJ4RvdozpHu7RmmjoXPXNKl7X1nG8s0lUd1nbre23HP+4CNus3LlSqZMmcLDDz/MrFmzWL16NfPmzcPdueCCC5gzZw5jx47l8ccfp0OHDpSWljJy5EguvfRSOnbsWOs+i4uLufHGG5kzZw69e/dm586dR621Zh3FxcWMGTOG3/3udwDMmDGDu+++m/z8fGbMmMH7779PamoqN998M9OnT+faa6+t+4cjIo3qX2u2s37HAX5wVv9GayOuQj4ovXr14uSTTwZg1qxZzJo1i2HDhgGwb98+Vq9ezdixY3nwwQd54YUXANi0aROrV68+bMjPnTuXsWPHHjpXvUOHDnWqIysriz59+jB37lz69evHypUrOfXUU3nooYdYuHAhI0eOBKC0tJROnTodabciEpAnP1jPcW1a8tUhXRutjbgK+aP1uBtL69atDz13d+68804mTZr0uW3eeecd3nzzTT788EPS09M544wzjnguurvXeipjSkrK58bba+6jZh0AV1xxBTNnzmTgwIFcfPHFmBnuzoQJE/jNb35T5+MUkaazcccB/rmyiO99uS8tUhpv5Fxj8nV09tln8/jjj7Nv3z4ACgsLKSoqYs+ePbRv35709HRWrFjB3Llzj7ifMWPG8O6777Ju3TqAQ8M1OTk5LFq0CIBFixYdWl+bSy65hBdffJGnn36aK664AoBx48bx7LPPUlRUdGi/GzYcdhZSEQnItI82kGTGVaN7NWo7cdWTjwXjx48nPz+fMWPGANCmTRumTZvGOeecw1/+8hdOPPFEBgwYcGhY5XCysrKYPHkyl1xyCeFwmE6dOjF79mwuvfRSnnzySU466SRGjhxJ//6HH6tr3749ubm5LF++nFGjRgGQm5vLr3/9a8aPH084HCY1NZWHHnqIXr0a9x+SiBy70ooQM+Zv4uzBnenSLq1R24r5uWvy8/MZNGhQQBUlLn2uIsGZOX8Ttz/3Cc9MPJmT+9T+vV1dxPXcNSIiicTdeeKD9QzonMHo3kc/4aKhFPIiIk1o0cZdLN9SwjVjejXJPFJxEfKxNKSUCPR5igRn6gcbyGiZwsXDjnrb66iI+ZBPS0tjx44dCqYoOTiffFpa437ZIyL/buueMl5buoWv52XTumXTnPcS82fXZGdnU1BQgG7yHT0H7wwlIk1r6ofrCbvz7VOa7oY9MR/yqampuoORiMS9/eVVTJ+7gbMHd6Fnx/Qmazfmh2tERBLBswsLKCmr4oYv9WnSdhXyIiKNLBR2HntvHcN6ZjKiV/smbVshLyLSyGYv38bGnQe4sYl78RClkDezH5jZMjP71MyeNrM0M+tgZrPNbHXksWl/fImIxIjH3ltLdvtWjM/t3ORtNzjkzaw78H0gz92HAMnAlcAdwFvu3g94K/JaRKRZWbJpN/PX7+K6U3uTktz0gyfRajEFaGVmKUA6sBm4EJgaWT8VuChKbYmIxI1H/7WWjLQULh/ZI5D2Gxzy7l4IPABsBLYAe9x9FtDZ3bdEttkC1HrnCjObaGYLzGyBzoUXkURSsOsAr3+6latG9aRNE1389EXRGK5pT3WvvTfQDWhtZlcf6/vdfbK757l7XlZWVkPLERGJGU+8vx4DJpySE1gN0RiuORNY5+7F7l4JPA+cAmwzs64AkceiKLQlIhIX9pRW8sz8TZx7Qle6ZbYKrI5ohPxG4GQzS7fqKdXGAfnAy8CEyDYTgJei0JaISFyYNncD+8qr+M7pxwdaR4MHidz9IzN7FlgEVAGLgclAG2CmmV1P9Q+CyxralohIPCitCPHYe+v48oAscru1DbSWqHwT4O73APd8YXE51b16EZFm5Zn5G9m5v4Jbvtw36FJ0xauISDRVVIWZPGcto3p3IC+n8e/8dDQKeRGRKHpxcSFb9pTFRC8eFPIiIlETCjv/8+5nDOnelrH9jgu6HEAhLyISNa9/uoV12/dzyxl9m+T+rcdCIS8iEgXuzkNvf8bxWa05e3CXoMs5RCEvIhIF76wsJn9LCd85/XiSkmKjFw8KeRGRBqvuxa+he2YrLhrWPehyPkchLyLSQB98toMFG3YxcWwfUgOYTvhIYqsaEZE44+78cfYqurRN44qAphM+EoW8iEgDzFm9nYUbdnHLV/qSlpocdDn/RiEvIlJPB3vx3TNbcXledtDl1EohLyJST2+vLOLjTbv53lf60jIl9nrxoJAXEamXg734Hh1acemI2OzFg0JeRKReZi/fxqeFJXz/K/1i7oyammK3MhGRGBUOO396czU5HdO5OMbOi/8ihbyISB39Y9lW8reUcOuZ/UiJ4V48KORFROqkuhe/ij5ZrblgaGz34kEhLyJSJy8uKWTVtn3cdmZ/kmNojprDUciLiByjssoQf5i1ihO6t+O8E7oGXc4xUciLiByjpz7cQOHuUu746sCYmmnySBTyIiLHYM+BSv777TWM7Z/FqX1j465PxyIqIW9mmWb2rJmtMLN8MxtjZh3MbLaZrY48to9GWyIiQXj43TWUlFVyxzkDgy6lTqLVk/8z8Ia7DwSGAvnAHcBb7t4PeCvyWkQk7mzeXcqU99dz8bDu5HZrG3Q5ddLgkDeztsBY4DEAd69w993AhcDUyGZTgYsa2paISBD+OHsVAD8aPyDgSuouGj35PkAxMMXMFpvZo2bWGujs7lsAIo+danuzmU00swVmtqC4uDgK5YiIRM+KrSU8t6iAb52SQ/fMVkGXU2fRCPkUYDjwP+4+DNhPHYZm3H2yu+e5e15WVlYUyhERiZ7fvr6CjJYp3HzG8UGXUi/RCPkCoMDdP4q8fpbq0N9mZl0BIo9FUWhLRKTJvL9mO2+vLOaWL/clM71F0OXUS4ND3t23ApvM7OBg1ThgOfAyMCGybALwUkPbEhFpKlWhMPe+soyeHdKZcEpO0OXUW0qU9vM9YLqZtQDWAt+m+gfITDO7HtgIXBaltkREGt20uRtYtW0ff71mREze1u9YRSXk3X0JkFfLqnHR2L+ISFPaub+CP85exWl9j2N8buegy2kQXfEqIvIFD8xayf6KEPecn4tZfExfcDgKeRGRGpZt3sPT8zZy7Zhe9OucEXQ5DaaQFxGJcHfufXk57dNbcNuZ/YMuJyoU8iIiEa9+soV563fy4/EDaNcqNehyokIhLyIClFaE+M1r+Qzu1pYrRvYIupyoUciLiAAP/nM1m/eUcc/5g+Pijk/HSiEvIs3eiq0lPDJnLZeNyGZU7w5BlxNVCnkRadbCYeeu55fStlUqd507KOhyok4hLyLN2t/mbWTRxt3cfe4g2reOz/lpjkQhLyLNVlFJGb99YwWnHN+RS4Z3D7qcRqGQF5Fm65evLqe8KsyvLxoS91e2Ho5CXkSapbdXFvHqJ1v47pf70ierTdDlNBqFvIg0O6UVIf7fi59yfFZrJp3eJ+hyGlW0phoWEYkbv//HSgp2lTJj4sm0TInfaYSPhXryItKsfLR2B1M+WMe1Y3oxuk/HoMtpdAp5EWk29pdX8ZNnP6FH+3Tu+OrAoMtpEhquEZFm4/7XV7Bp1wFmTBxDeovmEX/qyYtIs/D+mu08NXcD153aO+GmLjgShbyIJLy9ZZXc/uwn9MlqzU/OHhB0OU2qefy+IiLN2n1/z2fLnlKevemUuL4pd32oJy8iCe3tFUU8M38TE8cez/Ce7YMup8lFLeTNLNnMFpvZq5HXHcxstpmtjjw2v09XRAJVVFLGj//3YwZ2yeAHZ/ULupxARLMnfyuQX+P1HcBb7t4PeCvyWkSkSYTDzg9nfsz+iir++6phCX/R0+FEJeTNLBv4GvBojcUXAlMjz6cCF0WjLRGRY/HXOWt5b812fnH+YPp2ygi6nMBEqyf/n8DtQLjGss7uvgUg8tiptjea2UQzW2BmC4qLi6NUjog0Z4s37uIPs1bytRO6JtT9WuujwSFvZucBRe6+sD7vd/fJ7p7n7nlZWVkNLUdEmrmSskq+/8xiOrdN4z8uOSFhpxA+VtE4hfJU4AIzOxdIA9qa2TRgm5l1dfctZtYVKIpCWyIih+Xu3P3Cp2zeXcbMSSfTrlVq0CUFrsE9eXe/092z3T0HuBL4p7tfDbwMTIhsNgF4qaFtiYgcyYz5m3jl48384Mx+jOjVfK5qPZLGPE/+fuAsM1sNnBV5LSLSKD7etJufv7yM0/oex01n9A26nJgR1Ste3f0d4J3I8x3AuGjuX0SkNtv3lXPTtIVktWnJg98YRnJS8x6Hr0nTGohIXKsKhfne3xazY38Fz910Ch1atwi6pJiikBeRuPa7f6zkw7U7eOCyoQzp3i7ocmKO5q4Rkbj16iebmTxnLdeO6cXXR2QHXU5MUsiLSFxauXUvtz/7CXm92vOzr+UGXU7MUsiLSNwp3lvO9VPn07plCg9/czgtUhRlh6NPRkTiSmlFiBueXMCOfRU8NiGPTm3Tgi4ppumLVxGJG+Gw84MZS/ikYDd/vXoEJ2ZnBl1SzFNPXkTixm/fWMEby7bys6/lMn5wl6DLiQsKeRGJC9M/2sBfI2fSXHdqTtDlxA2FvIjEvHdWFvHzl5bx5QFZ/Py83GY/s2RdKORFJKYt3LCTm6Yton/nDP7rquGkJCu26kKflojErGWb9/CtKfPp3LYlT143ijYtda5IXSnkRSQmrS3ex7WPzSOjZQrTbhhNVkbLoEuKSwp5EYk5hbtLufrRjwB46obRZLdPD7ii+KXffUQkphTvLeeaRz9ib3kVT994MsdntQm6pLimnryIxIziveVc/ehHbN5TypRvjdSsklGgnryIxIRtJWVc9chcCneX8tiEkeTl6PZ90aCQF5HAbd5dylWPzKV4bzlTvz2K0X06Bl1SwlDIi0igNu08wDcemcueA5U8ef1oRvRqH3RJCUUhLyKBWb99P1c9Mpf9FSGm3zhaE441ggZ/8WpmPczsbTPLN7NlZnZrZHkHM5ttZqsjj/rxLCKHLC3Yw9f/8iFlVWGevvFkBXwjicbZNVXAj9x9EHAycIuZ5QJ3AG+5ez/grchrERHeXlHEFZM/pGVKEjMnjSG3W9ugS0pYDQ55d9/i7osiz/cC+UB34EJgamSzqcBFDW1LROLfM/M2csOTC+iT1ZoXbjmFvp10HnxjiuqYvJnlAMOAj4DO7r4Fqn8QmFmnaLYlIvHF3fnT7FU8+M81nN4/i4e+OVxz0TSBqH3CZtYGeA64zd1LjnUqUDObCEwE6NmzZ7TKEZEYUlYZ4q4XlvL8okKuyOvBry8eQqpmk2wSUfmUzSyV6oCf7u7PRxZvM7OukfVdgaLa3uvuk909z93zsrKyolGOiMSQzbtLufyvH/L8okJ+eFZ/7r/0BAV8E2pwT96qu+yPAfnu/scaq14GJgD3Rx5famhbIhJf5q7dwS3TF1FeFWbyNSN0y74ARGO45lTgGmCpmS2JLLuL6nCfaWbXAxuBy6LQlojEAXdnyvvrue+1fHI6pvPXa/L0BWtAGhzy7v4ecLgB+HEN3b+IxJd95VX87IWlvLhkM+NzO/OHy4eSkZYadFnNlr7aFpGoWbxxF7fNWMKmnQf40Vn9ueXLfUlK0v1Yg6SQF5EGC4Wd/3lnDX96czVd2qYxY9IYRmoWyZigkBeRBtm8u5TbZixh3rqdnD+0G7++aAjtWml4JlYo5EWkXsJhZ8aCTfzHa/mEw84fLhvKJcO7c6zXyEjTUMiLSJ19VryPO59fyrx1Ozm5Twd+e+mJ9OrYOuiypBYKeRE5ZhVVYSbP+YwH/7mGtJQkfnvpCVye10O99ximkBeRY/LBmu3c+8pyVm7by9dO6Mo9F+TSKSMt6LLkKBTyInJEa4v38R+v5fNmfhHZ7VvxyLV5nJXbOeiy5Bgp5EWkVrsPVPDnt1bz1IcbSEtN5vZzBnDdqb1JS00OujSpA4W8iHzO3rJKpn6wnkf+tY69ZZVcMbInPzyrP1kZLYMuTepBIS8iAJSUVfLE++t57L117CmtZNzATvzknAEM7KK7NsUzhbxIM7dzfwVPfbiBx95bS0lZFWcO6syt4/pxQna7oEuTKFDIizRTK7aWMOW99by4pJDyqjBn5VaH+5DuCvdEopAXaUaqQmH+uaKIKe+v58O1O0hLTeKS4dl8+9Qc+nfOCLo8aQQKeZFmYNW2vTy3sIDnFxdSvLecbu3S+Ok5A/nGqB5kprcIujxpRAp5kQRVtLeM15du5blFBXxSsIeUJOOMAZ34+ohszhzUiRTdgq9ZUMiLJJCNOw7wj2Vb+ceyrSzcuAt3yO3alp+fl8sFJ3XjuDY6DbK5UciLxLGyyhCLNuzivTXbeXtlMflbSgAY1LUtt43rzzlDujCgi8bamzOFvEgcKa8KsWxzCfPW7eT9NduZt24n5VVhkpOM4T0zufvcQZw9uAs9O6YHXarECIW8SIwKhZ0NO/bz6eYSFm/cxeKNu1m+uYSKUBiAAZ0z+OboXpzWryOjenekTUv9d5Z/p38VIgGrqApTsOsAG3Ye4LOifazYupeVW/eyumgvZZXVgd4qNZkTs9tx3Wm9GdYzk2E9MzUDpByTRg95MzsH+DOQDDzq7vc3dpsiscDdKa0MsetAJdv3lrOtpIxte8spKilj654yCneXsmHHAbbsKSXs//e+rIyWDOySwdWjezGgSwaDurZlYJcMnQ0j9dKoIW9mycBDwFlAATDfzF529+WN2a40nbLKEMV7y9m+r5zdByopKaukpLSSkrIqSkor2VdeRWlliLLKEGWVYUorQpRXhagKO1UhpyocpirkhLw65dw/v//kJMMMks1ITjKSzEhNrn6ekpxESuQxNclITU4iNaX6eUpy5HVkm4PLk5OSSEk2UpIi+0gykpIMMyPJwKh+dCDsjnt1WDvVPe7KkFMZCkeehymtDLG/PMSBiioOVFQ/lpRWsetABbtLK6moCv/bZ5Zk1UHeLbMVI3Pa07NjNr06pNOrYzq9j2tNR50BI1HU2D35UcAad18LYGbPABcCCvk44O5s2VPGxp0HKNhVSsGu/3ssKimneG85e8urDvv+tNQkWrdIoVWLZNJSk2kV+ZPeIiUStAdDujpwD95b6OBdhtydsEPInXDYCYWdsPvnfkBUVIXZXxGiKlQdulUhp6LG88pQmKpw9WNlyA9ba12ZQYvkJFq1SKZ1ixTSWyST3jKF9NRkco5LZ1h6Ju3SU8ls1YL26al0bNOSLm3T6Ny2JR3btCQ5SXdSkqbR2CHfHdhU43UBMLqR25R6KCopY2nhHlZt28fqor18VrSPNUX72F8ROrSNGXRpm0b3zFYM6taWsW1akpXRkqw2LTkuowXt01vQtlUqbdNSadsqhZYpsTXv+MEfGlXhMKHw//2wOLj84GPYnaRIzx6DJKv+AZSakkSLyG8HCmmJF40d8rX9T/hcd8rMJgITAXr27NnI5QjAvvIqFm3YxcebdvNJ4R6WFuxha0nZofWd27akX6cMLsvrQd9Obcjp2Jrs9q3ompkWc8FdF2ZGskFyUvweg0hdNXbIFwA9arzOBjbX3MDdJwOTAfLy8qL3+7Qcsq+8igXrdzJ37U7mrt3B0sI9hCLf9PXJas3JfTpwQnYmJ2a3Y0CXDNqmpQZcsYhES2OH/Hygn5n1BgqBK4GrGrlNAdZv38+b+dt4K7+I+et3UhV2UpONodmZ3HT68Yzu04GhPTIV6CIJrlFD3t2rzOy7wD+oPoXycXdf1phtNlfuzqeFJbz6yWZm529jbfF+APp1asP1X+rNl/pmMbxXJuktdGmESHPS6P/j3f014LXGbqe52rBjPy8u3sxLHxeytng/qcnG6N4duebkXowb2FmXt4s0c+rWxaHSihAvf1zI0/M2sWTTbgBG9+7AjV/qw1eHdNH84CJyiEI+jnxWvI9pczfw7MIC9pZVMaBzBnd+dSDnD+1Gt8xWQZcnIjFIIR/j3J13Vhbz6HtreX/NDlKTjXNP6MrVJ/cir1f7QxcOiYjURiEfo0Jh57WlW3j4nc/I31JCt3Zp/OTsAVye14OsDF32LiLHRiEfYyqqwrywuIC/vLuWddv3c3xWax64bCgXntSNVE1QJSJ1pJCPEeGw8+rSLTzwj5Vs3HmAE7q34y9XD2d8bheSdAm9iNSTQj4GfLBmO795fQVLC/cwsEsGU741kjMGZGm8XUQaTCEfoFXb9nLf3/N5d1Ux3dql8YfLhnLRsO6a/EpEokYhH4ADFVX8+a3VPPavdaS3SOaucwdy7Zgc0lI1cZaIRJdCvonNWraVe19ZTuHuUi4bkc2d5w6iQ2tdvCQijUMh30QKd5dyz0vLeDN/G/07t+F/vzOGkTkdgi5LRBKcQr6RuTvPLizg3leWEwo7d351INed1lunQ4pIk1DIN6Lt+8q56/mlzFq+jVE5HfjD5UPp0UETholI01HIN5LZy7dx5/OfUFJaxd3nDuK603rrrBkRaXIK+Sgrqwxx7yvLeHreJnK7tmX6DScxoEtG0GWJSDOlkI+ijTsOcNP0hSzbXMJNZxzPD87sT4sUjb2LSHAU8lEye/k2fjhzCUlmPP6tPL4ysHPQJYmIKOQbqioU5oFZq/jLu59xQvd2PPzN4fpyVURihkK+AfYcqOSm6Qv54LMdXDW6Jz8/L1dXrYpITFHI19OGHfv59hPz2bTzAL//+olcltcj6JJERP6NQr4e5q/fycQnF+DAtOtHM7pPx6BLEhGpVYNO/TCz35vZCjP7xMxeMLPMGuvuNLM1ZrbSzM5ucKUx4sXFhXzzkY9on96CF24+VQEvIjGtoef3zQaGuPuJwCrgTgAzywWuBAYD5wAPm1lcD1a7O39+czW3zVjC8F6ZPH/zKfQ+rnXQZYmIHFGDQt7dZ7l7VeTlXCA78vxC4Bl3L3f3dcAaYFRD2gpSOOzc+8py/vTmKi4Z3p0nrxtNZrpmjhSR2BfNK3WuA16PPO8ObKqxriCy7N+Y2UQzW2BmC4qLi6NYTnRUhcLc/twnPPHBeq4/rTd/uGyoLnASkbhx1C9ezexNoEstq+5295ci29wNVAHTD76tlu29tv27+2RgMkBeXl6t2wSlvCrEbc8s4fVPt3Lbmf24dVw/3ZJPROLKUUPe3c880nozmwCcB4xz94MhXQDUPKcwG9hc3yKDUFoRYtK0hcxZVcz/Oy+X60/rHXRJIiJ11tCza84Bfgpc4O4Haqx6GbjSzFqaWW+gHzCvIW01pQMVVUyYMo/3Vhfzu0tPVMCLSNxq6Hny/w20BGZHhjHmuvt33H2Zmc0EllM9jHOLu4ca2FaTKKsMccPUBSxYv5M/XzmM84d2C7okEZF6a1DIu3vfI6y7D7ivIftvauVVISY9tZAP1+7gT5efpIAXkbin00QiKkNhbpm+mHdXFXP/JSdw0bBaTwYSEYkrCnmqT5O87ZklvJm/jV9dOJgrRvYMuiQRkaho9iEfDju3P/sJf1+6hZ99bRDXjMkJuiQRkahp9iF//xsreH5xIT86qz83fKlP0OWIiERVsw75Ke+vY/KctVw7phff/cphv0MWEYlbzTbkX1+6hV++upzxuZ255/zBupJVRBJSswz5Bet3cuuMJQzrkcmD3xhGcpICXkQSU7ML+TVF+7h+6gKyM1vx6ISRul2fiCS0ZhXyO/aV860p80hNNp749ig6tNZ0wSKS2JrN7f8qQ2Fumr6I4r3lzJw0hp4d04MuSUSk0TWbkP/lK8uZt24nf77yJIb2yAy6HBGRJtEshmv+9tFGnpq7gUmn9+HCkzRdgYg0Hwkf8vPW7eTnL33K6f2zuP3sgUGXIyLSpBI65At3l3LTtIX06JCuUyVFpFlK2JAvqwwx6akFlFeFeeTaEbRrlRp0SSIiTS5hv3i995XlfFpYwqPX5tG3U0bQ5YiIBCIhe/IvLSnk6XkbmXR6H87M7Rx0OSIigUm4kF9bvI+7nl/KiF7t+fH4AUGXIyISqIQK+bLKELf8bTGpKUn81zeGkZqcUIcnIlJnCTUm/6tXl5O/pYTHv5VHt8xWQZcjIhK4hOnqvvLxZqZ/tJFJY/vwlYEahxcRgSiFvJn92MzczI6rsexOM1tjZivN7OxotHM467bv587nlzK8ZyY/Plvj8CIiBzV4uMbMegBnARtrLMsFrgQGA92AN82sv7uHGtpebVKSjGE9M7n/0hM1Di8iUkM0EvFPwO2A11h2IfCMu5e7+zpgDTAqCm3VqkeHdJ66fjTdNQ4vIvI5DQp5M7sAKHT3j7+wqjuwqcbrgsiy2vYx0cwWmNmC4uLihpQjIiJfcNThGjN7E+hSy6q7gbuA8bW9rZZlXssy3H0yMBkgLy+v1m1ERKR+jhry7n5mbcvN7ASgN/Bx5CbY2cAiMxtFdc+9R43Ns4HNDa5WRETqpN7DNe6+1N07uXuOu+dQHezD3X0r8DJwpZm1NLPeQD9gXlQqFhGRY9YoF0O5+zIzmwksB6qAWxrrzBoRETm8qIV8pDdf8/V9wH3R2r+IiNSdTioXEUlgCnkRkQRm7rFz1qKZFQMbGrCL44DtUSonSIlyHKBjiUWJchygYzmol7tn1bYipkK+ocxsgbvnBV1HQyXKcYCOJRYlynGAjuVYaLhGRCSBKeRFRBJYooX85KALiJJEOQ7QscSiRDkO0LEcVUKNyYuIyOclWk9eRERqSKiQN7NfmdknZrbEzGaZWbega6ovM/u9ma2IHM8LZpYZdE31ZWaXmdkyMwubWdydCWFm50TucLbGzO4Iup76MrPHzazIzD4NupaGMrMeZva2meVH/m3dGnRN9WFmaWY2z8w+jhzHvVFvI5GGa8ysrbuXRJ5/H8h19+8EXFa9mNl44J/uXmVmvwVw958GXFa9mNkgIAz8Ffixuy8IuKRjZmbJwCqq735WAMwHvuHuywMtrB7MbCywD3jS3YcEXU9DmFlXoKu7LzKzDGAhcFG8/b1Y9RS+rd19n5mlAu8Bt7r73Gi1kVA9+YMBH9Gaw8xhHw/cfZa7V0VezqV6uua45O757r4y6DrqaRSwxt3XunsF8AzVdz6LO+4+B9gZdB3R4O5b3H1R5PleIJ/D3Jgolnm1fZGXqZE/Uc2thAp5ADO7z8w2Ad8Efh50PVFyHfB60EU0U8d8lzMJhpnlAMOAjwIupV7MLNnMlgBFwGx3j+pxxF3Im9mbZvZpLX8uBHD3u929BzAd+G6w1R7Z0Y4lss3dVE/XPD24So/uWI4lTh3zXc6k6ZlZG+A54LYv/CYfN9w95O4nUf3b+igzi+pQWqPMJ9+YDnenqlr8Dfg7cE8jltMgRzsWM5sAnAeM8xj/8qQOfy/xRnc5i1GRMezngOnu/nzQ9TSUu+82s3eAc4CofTkedz35IzGzfjVeXgCsCKqWhjKzc4CfAhe4+4Gg62nG5gP9zKy3mbUArqT6zmcSoMgXlo8B+e7+x6DrqS8zyzp45pyZtQLOJMq5lWhn1zwHDKD6TI4NwHfcvTDYqurHzNYALYEdkUVz4/hMoYuB/wKygN3AEnc/O9Ci6sDMzgX+E0gGHo/cECfumNnTwBlUz3a4DbjH3R8LtKh6MrPTgH8BS6n+/w5wl7u/FlxVdWdmJwJTqf63lQTMdPdfRrWNRAp5ERH5vIQarhERkc9TyIuIJDCFvIhIAlPIi4gkMIW8iEgCU8iLiCQwhbyISAJTyIuIJLD/D4NyVqoKGt6uAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# 画出这个函数的曲线\n",
- "x_sample = np.arange(-3, 3.1, 0.1)\n",
- "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n",
- "\n",
- "plt.plot(x_sample, y_sample, label='real curve')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "接着构建数据集,需要 x 和 y,同时是一个三次多项式,所以取 $x,\\ x^2, x^3$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# 构建数据 x 和 y\n",
- "# x 是一个如下矩阵 [x, x^2, x^3]\n",
- "# y 是函数的结果 [y]\n",
- "\n",
- "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n",
- "x_train = torch.from_numpy(x_train).float() # 转换成 float tensor\n",
- "\n",
- "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # 转化成 float tensor "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "torch.Size([61, 3])\n"
- ]
- }
- ],
- "source": [
- "print(x_train.size())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "接着定义需要优化的参数,就是前面这个函数里面的 $w_i$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# 定义参数\n",
- "w = torch.randn((3, 1), dtype=torch.float, requires_grad=True)\n",
- "b = torch.zeros((1), dtype=torch.float, requires_grad=True)\n",
- "\n",
- "# 定义模型\n",
- "def multi_linear(x):\n",
- " return torch.mm(x, w) + b\n",
- "\n",
- "def get_loss(y_, y):\n",
- " return torch.mean((y_ - y) ** 2)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "我们可以画出没有更新之前的模型和真实的模型之间的对比"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac41e220>"
- ]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqEklEQVR4nO3deXxU1f3/8deHEAib7CCCLLaKIMpi4Ata0YoCVQRUrLjiilatYFtBxJ/aKopCXWhFi4poRZZiFdoKgiDuQANCBSKCyiZbBFEwBEhyfn+cGRIwgZCZyZ2ZvJ+Px3nMdmfuZ7J85sy5536OOecQEZHkVCHoAEREJHaU5EVEkpiSvIhIElOSFxFJYkryIiJJTEleRCSJlTjJm9l4M9tmZssL3VfHzOaY2erQZe1Cjw0zszVmtsrMekQ7cBERObKj6clPAHoect89wFzn3InA3NBtzKw10B84JfScsWaWEnG0IiJyVCqWdEPn3Ptm1vyQu/sA54SuvwzMB4aG7p/snNsLfG1ma4BOwCeH20e9evVc8+aH7kJERA5n8eLF3zrn6hf1WImTfDEaOuc2AzjnNptZg9D9jYEFhbbbGLrvsJo3b05GRkaEIYmIlC9mtq64x2J14NWKuK/I+glmNtDMMswsIysrK0bhiIiUT5Em+a1m1gggdLktdP9G4PhC2zUBNhX1As65cc65dOdcev36RX7bEBGRUoo0yc8ABoSuDwCmF7q/v5lVNrMWwInAogj3JSIiR6nEY/JmNgl/kLWemW0EHgBGAlPN7EZgPXAZgHNuhZlNBVYCucDtzrm80gS4f/9+Nm7cSE5OTmmeLhFKS0ujSZMmpKamBh2KiJSCxVOp4fT0dHfogdevv/6aGjVqULduXcyKGuqXWHHOsX37dnbt2kWLFi2CDkdEimFmi51z6UU9FvdnvObk5CjBB8TMqFu3rr5FiSSwuE/ygBJ8gPSzF0lsCZHkRUSS2ZgxMGNGbF5bSb4ExowZQ6tWrbjqqquYMWMGI0eOBODNN99k5cqVB7abMGECmzYVzBS96aabDnpcRORQO3fCsGEwffoRNy2VSM94LRfGjh3LzJkzDxx87N27N+CTfK9evWjdujXgk3ybNm047rjjAHjhhReCCbiQvLw8UlJUNkgkXk2YANnZcPvtsXl99eSP4NZbb+Wrr76id+/ePPnkk0yYMIE77riDjz/+mBkzZnD33XfTrl07HnvsMTIyMrjqqqto164de/bs4ZxzzjlQpqF69eoMHz6ctm3b0rlzZ7Zu3QrAl19+SefOnenYsSP3338/1atXLzKOV155hdNOO422bdtyzTXXAHDdddcxbdq0A9uEnzt//nx++ctfcuWVV3LqqacydOhQxo4de2C7Bx98kD//+c8AjBo1io4dO3LaaafxwAMPRP8HKCLFys+HZ56BM86ADh1is4/E6skPHgxLl0b3Ndu1g6eeKvbh5557jlmzZvHuu+9Sr149JkyYAMAZZ5xB79696dWrF/369QNg5syZjB49mvT0n85k+vHHH+ncuTMjRoxgyJAhPP/889x3330MGjSIQYMGccUVV/Dcc88VGcOKFSsYMWIEH330EfXq1WPHjh1HfFuLFi1i+fLltGjRgk8//ZTBgwdz2223ATB16lRmzZrF7NmzWb16NYsWLcI5R+/evXn//ffp2rXrEV9fRCI3Zw6sWQN//GPs9qGefBmpVKkSvXr1AuD0009n7dq1AHzyySdcdtllAFx55ZVFPnfevHn069ePevXqAVCnTp0j7q9Tp04Hhpfat2/Ptm3b2LRpE8uWLaN27do0bdqU2bNnM3v2bNq3b0+HDh34/PPPWb16daRvVURK6K9/hYYNIdRPjInE6skfpscd71JTUw9MR0xJSSE3N7fEz3XOFTmVsWLFiuTn5x/YZt++fQceq1at2kHb9uvXj2nTprFlyxb69+9/4DnDhg3jlltuOer3IyKR+eor+M9/4L77oFKl2O1HPfkI1KhRg127dhV7uyQ6d+7M66+/DsDkyZOL3KZbt25MnTqV7du3AxwYrmnevDmLFy8GYPr06ezfv7/Y/fTv35/Jkyczbdq0A8NLPXr0YPz48ezevRuAb775hm3bthX7GiISPc8+CxUqQKz7WEryEejfvz+jRo2iffv2fPnll1x33XXceuutBw68lsRTTz3FE088QadOndi8eTM1a9b8yTannHIKw4cP5+yzz6Zt27b87ne/A+Dmm2/mvffeo1OnTixcuPAnvfdDX2PXrl00btyYRo0aAdC9e3euvPJKunTpwqmnnkq/fv2O+kNKRI5edja8+CJcfDE0PuJKG5GJ+9o1mZmZtGrVKqCIYi87O5sqVapgZkyePJlJkyYxPVYTZksp2X8HImVt/Hi48UaYPx/OPjvy1ztc7ZrEGpNPQosXL+aOO+7AOUetWrUYP3580CGJSAw5B3/5C7RpA2UxkU1JPmBnnXUWy5YtCzoMESkjn3ziZ4I/+yyURWkojcmLiJShv/4VjjkGrr66bPanJC8iUka++Qb+8Q+4/noo5uT2qFOSFxEpI3/9qy9lcOedZbdPJXkRkTKwezc895yfNnnCCWW3XyX5MtC8eXO+/fbboMMQkQBNmODLCv/+92W7XyX5o+CcO1BGQHGISEnl5cGTT0LnztClS9nuW0n+CNauXUurVq247bbb6NChAxs2bCi2PG/fvn05/fTTOeWUUxg3btwRX3vWrFl06NCBtm3b0q1bN8CXAR49evSBbdq0acPatWt/EsdDDz3EkCFDDmw3YcIEfvvb3wLw6quv0qlTJ9q1a8ctt9xCXl5etH4cIlIKM2b4WjVl3YuHKM2TN7O7gJsAB3wGXA9UBaYAzYG1wK+dc99Fsp8AKg0DsGrVKl566SXGjh172PK848ePp06dOuzZs4eOHTty6aWXUrdu3SJfMysri5tvvpn333+fFi1alKh8cOE4srKy6NKlC48//jgAU6ZMYfjw4WRmZjJlyhQ++ugjUlNTue2225g4cSLXXnvtUf5kRCRanngCmjeHvn3Lft8RJ3kzawzcCbR2zu0xs6lAf6A1MNc5N9LM7gHuAYZGur8gNGvWjM6dOwMcVJ4XYPfu3axevZquXbsyZswY3njjDQA2bNjA6tWri03yCxYsoGvXrgfKAZekfHDhOOrXr88JJ5zAggULOPHEE1m1ahVnnnkmzzzzDIsXL6Zjx44A7NmzhwYNGkT2AxCRUlu0CD780HcmKwZw+mm0dlkRqGJm+/E9+E3AMOCc0OMvA/OJMMkHVWm4cOGv4srzzp8/n3feeYdPPvmEqlWrcs4555CTk1Psa5akfDBw0GscWoDs8ssvZ+rUqZx88slcfPHFmBnOOQYMGMCjjz561O9TRKLviSegZk244YZg9h/xmLxz7htgNLAe2Ax875ybDTR0zm0ObbMZKLI7aWYDzSzDzDKysrIiDSfmiivP+/3331O7dm2qVq3K559/zoIFCw77Ol26dOG9997j66+/Bg4uH7xkyRIAlixZcuDxolxyySW8+eabTJo0icsvvxzwZYmnTZt2oGTwjh07WLduXWRvWkRKZd06mDYNBg6EGjWCiSEawzW1gT5AC2An8A8zK/EJu865ccA48FUoI40n1rp3705mZiZdQofIq1evzquvvkrPnj157rnnOO2002jZsuWBYZXi1K9fn3HjxnHJJZeQn59PgwYNmDNnDpdeeimvvPIK7dq1o2PHjpx00knFvkbt2rVp3bo1K1eupFOnTgC0bt2ahx9+mO7du5Ofn09qairPPPMMzZo1i94PQURKZMwYX58mNCciEBGXGjazy4CezrkbQ7evBToD3YBznHObzawRMN851/Jwr1UeSw0nAv0ORI7ezp3QtCn06gWvvRbbfR2u1HA0plCuBzqbWVXzg8zdgExgBjAgtM0AIL6KpIuIxNDYsbBrFwwNeLpJxMM1zrmFZjYNWALkAp/ih1+qA1PN7Eb8B8Flke5LRCQRZGf7k58uuADatg02lqjMrnHOPQA8cMjde/G9+mi8fpEzUST24mnlMJFE8cIL8O23cO+9QUeSAGe8pqWlsX37diWbADjn2L59O2lpaUGHIpIw9u2DUaP8qk9nnhl0NAmwMlSTJk3YuHEjiTC9MhmlpaXRpEmToMMQSRivvgobN/refDyI+ySfmpp64KxQEZF4lpcHI0dChw7QvXvQ0Xhxn+RFRBLF66/D6tX+BKh4OYwY92PyIiKJwDl45BE4+WS/MEi8UE9eRCQKZs6EZcvgpZegQhx1n+MoFBGRxBTuxTdtClddFXQ0B1NPXkQkQvPmwUcfwV/+AqmpQUdzMPXkRUQi4Bzcfz80bgw33RR0ND+lnryISARmz4aPP/a1auLxvEH15EVESinci2/aNLhFQY5EPXkRkVJ66y2/vN/zz0PlykFHUzT15EVESiHci2/RAgYMOPL2QVFPXkSkFGbMgCVL/Lz4eJtRU5h68iIiRyk/Hx54AH7+c7i6xIudBkM9eRGRo/TGG/7s1r//HSrGeRZVT15E5CiEe/EtW8IVVwQdzZHF+WeQiEh8mTgRVqyASZMgJSXoaI5MPXkRkRLKyYH77oPTT4df/zroaEpGPXkRkRJ65hlYvx7Gj4+vSpOHkyBhiogE67vvYMQI6NEDunULOpqSi0qSN7NaZjbNzD43s0wz62JmdcxsjpmtDl3Wjsa+RESCMHIk7NwJjz0WdCRHJ1o9+aeBWc65k4G2QCZwDzDXOXciMDd0W0Qk4WzYAE8/DddcA23bBh3N0Yk4yZvZMUBX4EUA59w+59xOoA/wcmizl4G+ke5LRCQI99/vLx96KNg4SiMaPfkTgCzgJTP71MxeMLNqQEPn3GaA0GWDop5sZgPNLMPMMrKysqIQjohI9Hz2Gbz8Mvz2t77aZKKJRpKvCHQAnnXOtQd+5CiGZpxz45xz6c659Pr160chHBGR6LnnHqhZE4YNCzqS0olGkt8IbHTOLQzdnoZP+lvNrBFA6HJbFPYlIlJm5s715YTvvRfq1Ak6mtKJOMk757YAG8ysZeiubsBKYAYQLsA5AJge6b5ERMpKbi4MGgQnnOCHahJVtE6G+i0w0cwqAV8B1+M/QKaa2Y3AeuCyKO1LRCTmnn3Wly944434XNavpKKS5J1zS4H0Ih5KoFMGRES8b7/1M2rOOw/69Ak6msjojFcRkUPcdx/s2uXnxpsFHU1klORFRApZuhTGjYM77oDWrYOOJnJK8iIiIc7BnXdC3brw4INBRxMdqkIpIhIydSp88AH87W9Qq1bQ0USHevIiIkB2Ntx9N7RvDzfeGHQ00aOevIgIvi7Nhg1+5adEWPGppNSTF5Fy77PPYPRouP56OOusoKOJLiV5ESnX8vPhllv8GPyoUUFHE30arhGRcm3cOPjkE19psm7doKOJPvXkRaTc2rzZV5k891y/IEgyUpIXkXLrrrsgJ8fXqUn0M1uLoyQvIuXSzJkwZQoMHw4nnRR0NLGjJC8i5U52Ntx2G5x8MgwZEnQ0saUDryJS7gwfDmvXwnvvQeXKQUcTW+rJi0i58v77vrrk7bdD165BRxN7SvIiUm7s3u1PeGrRAh57LOhoyoaGa0Sk3Bg6FL7+2g/TVKsWdDRlQz15ESkX5s6FsWNh8ODkK11wOEryIpL0fvgBbrgBWraEESOCjqZsabhGRJLe738PGzfCRx9BlSpBR1O21JMXkaT21lvwwgu+VnznzkFHU/ailuTNLMXMPjWzf4du1zGzOWa2OnRZO1r7EhEpic2b4brr4NRT4Y9/DDqaYESzJz8IyCx0+x5grnPuRGBu6LaISJnIz4drr/XTJqdMSf6TnooTlSRvZk2AC4EXCt3dB3g5dP1loG809iUiUhKjRsE778CYMdCqVdDRBCdaPfmngCFAfqH7GjrnNgOELhsU9UQzG2hmGWaWkZWVFaVwRKQ8W7gQ7rsPLrssudZrLY2Ik7yZ9QK2OecWl+b5zrlxzrl051x6/fr1Iw1HRMq577+HK66Axo39giDJWkK4pKIxhfJMoLeZXQCkAceY2avAVjNr5JzbbGaNgG1R2JeISLGcg1tvhfXrfY2aWrWCjih4EffknXPDnHNNnHPNgf7APOfc1cAMYEBoswHA9Ej3JSJyOC++CJMn+5k0Z5wRdDTxIZbz5EcC55vZauD80G0RkZj473/hjjvgvPP8kn7iRfWMV+fcfGB+6Pp2oFs0X19EpCjbtsGll8Kxx8KkSZCSEnRE8UNlDUQkoeXmQv/+kJXlyxbUqxd0RPFFSV5EEtqwYfDuuzBhAnToEHQ08Ue1a0QkYU2dCqNH+1WeBgw48vblkZK8iCSk5ct9+eAzz4Qnngg6mvilJC8iCWfrVrjoIqhRA/7xD6hUKeiI4pfG5EUkoWRnQ+/efkbNe+9Bo0ZBRxTflORFJGHk58M11/g58W+8AenpQUcU/5TkRSRh3HMP/POf8OST0KdP0NEkBo3Ji0hC+NvffPng22+HQYOCjiZxKMmLSNybNcsn9wsugKeeUmXJo6EkLyJx7eOPfcmCNm188bGKGmQ+KkryIhK3li71vffjjoO33/ZTJuXoKMmLSFz64gvo3h2OOcYv49ewYdARJSYleRGJO+vX+5LBAHPmQLNmwcaTyDS6JSJxZetWOP98+OEHX3isZcugI0psSvIiEje2bvU9+A0bfA++ffugI0p8SvIiEhc2bYJu3WDdOvjXv3zhMYmckryIBG7DBjj3XNiyxc+J79o16IiSh5K8iATq6699gt+xA2bPhi5dgo4ouSjJi0hg1qzxCX73bpg7VwXHYiHiKZRmdryZvWtmmWa2wswGhe6vY2ZzzGx16LJ25OGKSLJYvBh+8QvYs8fPolGCj41ozJPPBX7vnGsFdAZuN7PWwD3AXOfcicDc0G0REd56C84+G9LS4IMPoG3boCNKXhEneefcZufcktD1XUAm0BjoA7wc2uxloG+k+xKRxPfCC37Rj5YtYcECOPnkoCNKblE949XMmgPtgYVAQ+fcZvAfBECDaO5LRBKLc3D//XDzzf5kp/nz4dhjg44q+UUtyZtZdeB1YLBz7oejeN5AM8sws4ysrKxohSMicSQnB667Dh56CG68EWbMULGxshKVJG9mqfgEP9E598/Q3VvNrFHo8UbAtqKe65wb55xLd86l169fPxrhiEgc2bABzjoLXnkF/vQneP55SE0NOqryIxqzawx4Ech0zj1R6KEZwIDQ9QHA9Ej3JSKJ5b334PTTYdUqePNN+H//Twt+lLVo9OTPBK4BzjWzpaF2ATASON/MVgPnh26LSDngHDz9tC9TULcuLFqkNVmDEvHJUM65D4HiPpu7Rfr6IpJYdu2C3/wGJk6Evn3h5Zd9TXgJhurJi0jULFzoK0dOmuQPsr7+uhJ80JTkRSRieXkwYoSvHLl/vx+Lv+8+qKAMEzjVrhGRiGzYAFdfDe+/D/37w7PPQq1aQUclYfqcFZFSyc/30yFPPRWWLPFj76+9pgQfb5TkReSorVoFv/wlDBzox+CXLoVrr9X0yHikJC8iJbZvnx97b9sW/vc/X4dm3jz42c+CjkyKozF5ESmRefNg0CBYvhwuuwzGjFHtmUSgnryIHNYXX/iqkd26+cU9pk+HqVOV4BOFkryIFGnHDhg8GE45xVeMfPRRyMz0CV8Sh4ZrROQgP/wAf/kL/PnP8P33cNNNvrBYw4ZBRyaloSQvIoBP6GPGwJNPwnffQa9e8MgjfoqkJC4leZFy7ttvYexYn9x37vTDMfff76tHSuJTkhcppz77zFeKnDjRL+rRp49P7h06BB2ZRJOSvEg5kpsL//mPT+7vvgtVqviTmO680x9gleSjJC9SDqxY4csO/P3vsGULHH88jBzp11utUyfo6CSWlORFktSWLTBtmk/uGRlQsSJccIFfa/Wii/xtSX76NYskka++gjfe8O3jj/0KTe3awVNPwRVXQIMGQUcoZU1JXiSB5eT4ZP7OO/DWW7Bsmb+/bVt48EG45BJo0ybQECVgSvIiCWTvXvj0U1+7/Z134IMPfKJPSYEuXWD0aLj4YjjhhKAjlXihJC8Sp/Ly4Msvfa32BQt8+/RTXwkSfA/91lvhvPOga1eoUSPYeCU+KcmLBGzfPli71if0zEw/f/2zz2DlStizx29TtSp07Ah33QWdO/umAmFSEjFP8mbWE3gaSAFecM6NjPU+ReKBc5CdDdu3w9atsGlTQfvmG1i3zif2DRv8Kkthxx7rSwn85jf+sm1bf6nZMFIaMf2zMbMU4BngfGAj8F8zm+GcWxnL/UrZ2bPHJ7CtW30y27nz4LZrl0902dl+2+xsP4a8f78/MWf/ft/y8vzrOXfw66ek+MWgU1IKWmqqT3iFL1NToVIl3wrfPvR64eeFW4UKBc3MXzrnE2/hy337Ctrevf4yO9uX3929G3780V/u3Ol/Fjt2+O0OVaGCT+RNm8IvfuEX3Ai3k06C+vVj+iuTcibWfYNOwBrn3FcAZjYZ6ANEN8nv3OmnFjRp4s/yaNzY/1dLRJyDjRv9tLy1aw9umzb5edg//FD886tUgerVoVo1f71qVd+qVStIvuGkm5JSsHRc+DKcYPPyDm7hD4fcXJ9Ed+/2t/ftK7gMXw+38O1oMXNUruSompZP9ar5VK+SR/UqeVRLy+PEurl0bpFLnRr7qXPMfuoek0uD2vtpXH8vx9XbT4M6uaRUNP9Gw59c4U+yrBTYmVrwAwq3ypV9S0mJ3puQciHWSb4xsKHQ7Y3A/0V9L5mZcNVVB9/XsKFP+I0aHdyOO84/1qCBb1WrRj2cRLR5Myxe7M+MXLnS/0gzM30CDTPzn5/Nmvm518ce63+U4ct69fwizuFWuXIUAtu/v+ArwKFfCQpfHtpycn5y6bL3kL93P7l79pObk0tuTi77c/Jw+/aTvy/3oFZhXw4VyMNwVCAfw1GJfQdaisuHvfj2fRTeZ0lVrOh/sGlp/pMz/OkZvqxRw3+yhluNGnDMMVCzZkGrVQtq1/anutas6T9gJGnFOskXtazvQV/IzWwgMBCgadOmpdtLhw4+I23Y4NvGjQXX163z0xKysop+brVq/vtxvXr+j75uXd/q1PH/CIWzVs2a/h+mRg3f0tIScuXiXbvgk09g0SJ/JmRGhh8jDjvuOGjdGm64AVq1gp//HJo395+ZlSvju9PFJdVvsmFNEUn40AR9aCvu/tzc0r3JSpV84gsnw7Q0rHJlUqpUIaVyZSrXqQxpNQt6yJUr++dUrlzQcy5qzKfwOE9KSsFl4d54eNyncAP/1STcoOBrSuHLwmNYhb+GhMeIwi388y78s/vxRz9OFB4/Co8hHY6Z/9uuU8f/D9Sv71uDBv7y2GMPbnXqJOTffHlm7tBB0Gi+uFkX4EHnXI/Q7WEAzrlHi9o+PT3dZWRkxCaYffv8wPHmzbBtm29ZWf4yPKC8Y4e/DA8uH0lKik/21ar5Fh6PCLdwjystrSCRHJo4wmMV4VZ4kDg8QFxUoghfL5wgCo9t5OYeaLt+rMBHXzdi/lfNmL+uBRlbm5Dn/Nf+lsdsIr3Wl6Qf8wXp1TI5tdIqaubtKEgk4cvCrbSJt0KFn47dhFuVKj/9GR66XeGea+HHw/eHW1qahjXC8vL8p/r33xe0nTt9wfgdO3z77jv/N5+VdXAranwrNdX3Apo0KWiNG/sDDM2a+d5A3br6IChjZrbYOZde5GMxTvIVgS+AbsA3wH+BK51zK4raPqZJ/mjl5voB5/A/ReEjiYe2H38suhd6aILcu7egdxZja/gZ/+Ii/sVFfMBZ5JJKKvvoVCGDc1I/5uy0hXSqtoKaaXsP7sEW/kAKXw9/UIUTaOHrhRProUMH4cfCiT01Vf/8icI5/7e/das/+BJumzf7r30bNxa0nJyDn1u1qk/4J5xw8FHln/0MWrSI0jieFHa4JB/T4RrnXK6Z3QG8jZ9COb64BB93Klb0X01jUaLPuZ9OLSncA8/LO/irfXh6R+Ejk+HroeEBVyGFJcsrMeVfVZjxdhqr1viebOtW+fzuQji/B3TpUolq1c4Azoj+e5LkEh7GqVULWrYsfjvn/LeB9ev90Ojatf5y3Tp/xH7+/IOHjCpU8Im+ZcuC1qqVr3Nct25s31M5FdOe/NGKq558gvjyS7/ow2uvwapVvrN89tm+ymCvXjq9XQLmnB8S/fJL3774wv+hrlrlrxf+FtCwoU/2p5ziTwxo186f1lulSmDhJ4rAevISG9nZMGkSPP88LFzo7zv7bPj97+HSS1UfXOKImU/eDRvCGYd8g8zP95MjMjP9tK4VK2D5chg/vqD3X6GCP3mgXTvf0tP9RIvatcv6nSQs9eQTyKpV8OyzMGGCHy5t08av6tO/v5/5IpIU8vP9sM/Spb6s5rJl/vq6dQXb/OxnPuF37OhrPHToUK57/IEdeD1aSvI/5RzMnAlPPAFz5/rhmMsu86e8n3mmjmNKObJ9u6/WFp73m5HhjwWA/8do184n/DPO8KcSN2kSaLhlSUk+AeXl+VV9Hn3Ud2SOP94n9htu8N98RQQ/42fhwoIynYsW+fFM8NM5zzrLt65d/bBPkvaKlOQTyL59fh3Oxx6D1avh5JPhnnvgyit9Z0VEDiM3F/73P19o/4MPfOH98ImQjRrBuecWtObNAw01mpTkE0B+PkydCsOH+5lnp58O994LffvqrHORUnPOz+J57z14912YN8/P9gE/lbNHD9/OPdefzZ6glOTj3Lx5MGSIrx1z2ml+iOZXv0rab5YiwXHOF2eaN88vrTVvni//ULGiX1qrZ0+48EL/j5hA/4BK8nFqxQr4wx9g1iw/5v7ww77Oms7IFykj+/b5RXLfftu3Tz/19x9/vD/R5KKL4Je/9Gd0xzEl+Tjz44/wpz/5GTPVq/shmjvuiPu/I5Hkt2UL/Oc/8O9/w+zZ/iBu1aq+h3/JJb6XX6tW0FH+hJJ8HJk+He6808/8uv56ePxxX/xPROJMTo4vyzBjBrz5pq/bk5oK3br5hN+3b9ys8HK4JK9DemVk/Xro08f/XdSo4Q/8jx+vBC8St9LSfA9+7FhfiO3jj2HwYD/tbeBAP1unRw946SVfyTNOKcnHmHP+DNU2bfxxnscf98N+v/hF0JGJSIlVqOAPzD7+uE/yS5f62RJr1hScvHLRRTB5csHq63FCST6Gtm3z3+quv96fjLd8Odx9t+a7iyQ0M7+6+iOP+CS/aJEfg126FK64wif8G27wUzYLr9AeECX5GJkxwxfSe+stGD3a/75btAg6KhGJKjNfP2f0aF9bZ9486NfPn64ePuHqvvv8yS8BUZKPsj17/HBdnz5+AZ3Fi311SE2LFElyFSr46Zbjx/tZOpMm+XHaRx/1BdXOPdfXBS/j4Rwl+Sj66itfNOz5530pgoUL/e9YRMqZqlV9edi33vI9/Ice8pU1r77a9/7uvNOflFUGlOSjZMYMX+107Vo/xfbRR/2qeiJSzjVp4ods1qzxpWQvuAD+9je/OMo558CUKf6krBhRko9Qbq7vtffpAz//ua+EeuGFQUclInGnQoWCIZuNG2HkSD+3OrwgxKhRsdltTF61nPjuOz9N9rHH4JZb4MMPk6qwnYjESv36MHSo792/9Rb83//FbK69lv8rpS+/9D32r77y50Jcd13QEYlIwqlQwVcj/NWv/Ek1MaAkXwoffujPXHXOn+DUtWvQEYlIwotR1cuIhmvMbJSZfW5m/zOzN8ysVqHHhpnZGjNbZWY9Io40Tkyc6EtX1K3rF6JRgheReBbpmPwcoI1z7jTgC2AYgJm1BvoDpwA9gbFmltAzxZ3zlSOvvtovIfnJJ3DiiUFHJSJyeBEleefcbOdcbujmAiC8cm4fYLJzbq9z7mtgDdApkn0FKT/f1yV64AG49lpfdrpOnaCjEhE5smjOrrkBmBm63hjYUOixjaH7fsLMBppZhpllZIXXYowjublw440wZgzcdZcvNqb57yKSKI6Y5M3sHTNbXkTrU2ib4UAuMDF8VxEvVeShY+fcOOdcunMuvX6c1GYO27vXT2GdMAEefBD+/OeEWhFMROTIs2ucc+cd7nEzGwD0Arq5ghVINgLHF9qsCbCptEEGITvbV5B8+2148kk/XCMikmginV3TExgK9HbOZRd6aAbQ38wqm1kL4ERgUST7Kks//uinrc6ZAy++qAQvIokr0nnyfwUqA3PMj2MscM7d6pxbYWZTgZX4YZzbnXN5Ee6rTOzZA717+7nwr70Gl18edEQiIqUXUZJ3zv38MI+NAEZE8vplbe9eP0Tz7rvw978rwYtI4tMZryH798Ovfw2zZsELL8BVVwUdkYhI5FSgDD9N8qqrfLngZ57xUyZFRJJBuU/y+fl+OcZ//MNPkbzttqAjEhGJnnKf5IcO9ePvDz0Ev/td0NGIiERXuU7yY8b49Xdvvx2GDw86GhGR6Cu3Sf711/3897594emndSariCSncpnkP/rIH2jt3NnPhU9J6PqYIiLFK3dJ/vPP4aKLoFkzP5umSpWgIxIRiZ1yleSzsny5gtRUmDkT6tULOiIRkdgqNydD7d8P/frBli3w/vtwwglBRyQiEnvlJskPHuyT+8SJ0LFj0NGIiJSNcjFcM24cjB0LQ4bAlVcGHY2ISNlJ+iT/wQd+HnzPnvDII0FHIyJStpI6ya9fD5deCi1awKRJmiopIuVP0ib5PXvg4oshJwemT4datYKOSESk7CXtgdfBg2HJEj8XvlWroKMREQlGUvbkJ03yB1uHDPEnPomIlFdJl+S/+AIGDoQzzoCHHw46GhGRYCVVks/J8as7VaoEkyf7M1tFRMqzpBqTv+suWLYM/v1vOP74oKMREQle0vTkp0yB556Du++GCy8MOhoRkfgQlSRvZn8wM2dm9QrdN8zM1pjZKjPrEY39FGf1arj5ZujSBUaMiOWeREQSS8TDNWZ2PHA+sL7Qfa2B/sApwHHAO2Z2knMuL9L9FaViRZ/gn39e4/AiIoVFoyf/JDAEcIXu6wNMds7tdc59DawBOkVhX0Vq0QLefhuaNo3VHkREElNESd7MegPfOOeWHfJQY2BDodsbQ/cV9RoDzSzDzDKysrIiCUdERA5xxOEaM3sHOLaIh4YD9wLdi3paEfe5Iu7DOTcOGAeQnp5e5DYiIlI6R0zyzrnzirrfzE4FWgDLzK+C3QRYYmad8D33wpMYmwCbIo5WRESOSqmHa5xznznnGjjnmjvnmuMTewfn3BZgBtDfzCqbWQvgRGBRVCIWEZESi8nJUM65FWY2FVgJ5AK3x2pmjYiIFC9qST7Umy98ewSgWesiIgFKmjNeRUTkp5TkRUSSmDkXP7MWzSwLWBfBS9QDvo1SOEFKlvcBei/xKFneB+i9hDVzztUv6oG4SvKRMrMM51x60HFEKlneB+i9xKNkeR+g91ISGq4REUliSvIiIkks2ZL8uKADiJJkeR+g9xKPkuV9gN7LESXVmLyIiBws2XryIiJSSFIleTN7yMz+Z2ZLzWy2mR0XdEylZWajzOzz0Pt5w8xqBR1TaZnZZWa2wszyzSzhZkKYWc/QCmdrzOyeoOMpLTMbb2bbzGx50LFEysyON7N3zSwz9Lc1KOiYSsPM0sxskZktC72PP0Z9H8k0XGNmxzjnfghdvxNo7Zy7NeCwSsXMugPznHO5ZvYYgHNuaMBhlYqZtQLygb8Bf3DOZQQcUomZWQrwBX71s43Af4ErnHMrAw2sFMysK7AbeMU51yboeCJhZo2ARs65JWZWA1gM9E2034v5Er7VnHO7zSwV+BAY5JxbEK19JFVPPpzgQ6pRTA37ROCcm+2cyw3dXIAv15yQnHOZzrlVQcdRSp2ANc65r5xz+4DJ+JXPEo5z7n1gR9BxRINzbrNzbkno+i4gk2IWJopnztsdupkaalHNW0mV5AHMbISZbQCuAu4POp4ouQGYGXQQ5VSJVzmTYJhZc6A9sDDgUErFzFLMbCmwDZjjnIvq+0i4JG9m75jZ8iJaHwDn3HDn3PHAROCOYKM9vCO9l9A2w/HlmicGF+mRleS9JKgSr3ImZc/MqgOvA4MP+SafMJxzec65dvhv653MLKpDaTGpJx9Lxa1UVYTXgP8AD8QwnIgc6b2Y2QCgF9DNxfnBk6P4vSQarXIWp0Jj2K8DE51z/ww6nkg553aa2XygJxC1g+MJ15M/HDM7sdDN3sDnQcUSKTPrCQwFejvnsoOOpxz7L3CimbUws0pAf/zKZxKg0AHLF4FM59wTQcdTWmZWPzxzzsyqAOcR5byVbLNrXgda4mdyrANudc59E2xUpWNma4DKwPbQXQsSeKbQxcBfgPrATmCpc65HoEEdBTO7AHgKSAHGhxbESThmNgk4B1/tcCvwgHPuxUCDKiUz+wXwAfAZ/v8d4F7n3FvBRXX0zOw04GX831YFYKpz7k9R3UcyJXkRETlYUg3XiIjIwZTkRUSSmJK8iEgSU5IXEUliSvIiIklMSV5EJIkpyYuIJDEleRGRJPb/AT43r/9qplA4AAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# 画出更新之前的模型\n",
- "y_pred = multi_linear(x_train)\n",
- "\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "可以发现,这两条曲线之间存在差异,计算一下他们之间的误差"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(1144.2654, grad_fn=<MeanBackward0>)\n"
- ]
- }
- ],
- "source": [
- "# 计算误差,这里的误差和一元的线性模型的误差是相同的,前面已经定义过了 get_loss\n",
- "loss = get_loss(y_pred, y_train)\n",
- "print(loss)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# 自动求导\n",
- "loss.backward()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([[ -94.7455],\n",
- " [-139.1247],\n",
- " [-629.8584]])\n",
- "tensor([-25.7413])\n"
- ]
- }
- ],
- "source": [
- "# 查看一下 w 和 b 的梯度\n",
- "print(w.grad)\n",
- "print(b.grad)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# 更新一下参数\n",
- "w.data = w.data - 0.001 * w.grad.data\n",
- "b.data = b.data - 0.001 * b.grad.data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac392850>"
- ]
- },
- "execution_count": 25,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAApTklEQVR4nO3deXxU1f3/8deHEHZlCyD78i0iyC7wBRfUougXKeBWccVdiwvYVpDiz6WWVsWVKlpUBBVBilXSKsgmLggioFTZBNllCyAIhC3J+f1xJiRgAklmJndm8n4+HvdxZ+7cmfuZLJ85c+65n2POOUREJDGVCjoAERGJHiV5EZEEpiQvIpLAlORFRBKYkryISAJTkhcRSWAFTvJmNtrMtpnZd7m2VTOz6Wa2MrSumuuxIWa2ysxWmNlFkQ5cREROrDAt+THAxcdsewCY6ZxrCswM3cfMWgB9gdNDzxlpZklhRysiIoVSuqA7Ouc+NbNGx2zuDZwXuj0WmA0MDm2f4Jw7CKwxs1VAJ2Du8Y6RkpLiGjU69hAiInI8Cxcu3O6cq5HXYwVO8vmo5ZzbDOCc22xmNUPb6wLzcu23MbTtuBo1asSCBQvCDElEpGQxs3X5PRatE6+Wx7Y86yeY2e1mtsDMFqSlpUUpHBGRkincJL/VzGoDhNbbQts3AvVz7VcP2JTXCzjnRjnnOjjnOtSokee3DRERKaJwk3wq0C90ux8wOdf2vmZW1swaA02B+WEeS0RECqnAffJmNh5/kjXFzDYCDwOPAxPN7BZgPXAlgHNuiZlNBJYCGcBdzrnMogR4+PBhNm7cyIEDB4rydAlTuXLlqFevHsnJyUGHIiJFYLFUarhDhw7u2BOva9as4aSTTqJ69eqY5dXVL9HinGPHjh3s2bOHxo0bBx2OiOTDzBY65zrk9VjMX/F64MABJfiAmBnVq1fXtyiROBbzSR5Qgg+QfvYi8S0ukryISCIbMQJSU6Pz2kryBTBixAiaN2/OtddeS2pqKo8//jgA77//PkuXLj2y35gxY9i0KWek6K233nrU4yIix9q1C4YMgcmTT7hrkYR7xWuJMHLkSKZMmXLk5GOvXr0An+R79uxJixYtAJ/kW7ZsSZ06dQB49dVXgwk4l8zMTJKSVDZIJFaNGQPp6XDXXdF5fbXkT+DOO+9k9erV9OrVi2effZYxY8Zw991388UXX5Camsr9999P27ZteeKJJ1iwYAHXXnstbdu2Zf/+/Zx33nlHyjRUqlSJoUOH0qZNGzp37szWrVsB+OGHH+jcuTMdO3bkoYceolKlSnnG8cYbb9C6dWvatGnD9ddfD8CNN97IpEmTjuyT/dzZs2dz/vnnc80119CqVSsGDx7MyJEjj+z3yCOP8PTTTwMwfPhwOnbsSOvWrXn44Ycj/wMUkXxlZcGLL8KZZ0L79tE5Rny15AcOhG++iexrtm0Lzz2X78Mvv/wyU6dO5eOPPyYlJYUxY8YAcOaZZ9KrVy969uzJFVdcAcCUKVN46qmn6NDhlyOZ9u3bR+fOnRk2bBiDBg3ilVde4cEHH2TAgAEMGDCAq6++mpdffjnPGJYsWcKwYcOYM2cOKSkp7Ny584Rva/78+Xz33Xc0btyYr7/+moEDB9K/f38AJk6cyNSpU5k2bRorV65k/vz5OOfo1asXn376KV27dj3h64tI+KZPh1Wr4NFHo3cMteSLSZkyZejZsycAZ5xxBmvXrgVg7ty5XHnllQBcc801eT531qxZXHHFFaSkpABQrVq1Ex6vU6dOR7qX2rVrx7Zt29i0aROLFy+matWqNGjQgGnTpjFt2jTatWtH+/btWb58OStXrgz3rYpIAb3wAtSqBaF2YlTEV0v+OC3uWJecnHxkOGJSUhIZGRkFfq5zLs+hjKVLlyYrK+vIPocOHTryWMWKFY/a94orrmDSpEls2bKFvn37HnnOkCFDuOOOOwr9fkQkPKtXwwcfwIMPQpky0TuOWvJhOOmkk9izZ0++9wuic+fOvPvuuwBMmDAhz326devGxIkT2bFjB8CR7ppGjRqxcOFCACZPnszhw4fzPU7fvn2ZMGECkyZNOtK9dNFFFzF69Gj27t0LwI8//si2bdvyfQ0RiZyXXoJSpSDabSwl+TD07duX4cOH065dO3744QduvPFG7rzzziMnXgviueee45lnnqFTp05s3ryZypUr/2Kf008/naFDh3LuuefSpk0bfv/73wNw22238cknn9CpUye+/PLLX7Tej32NPXv2ULduXWrXrg1A9+7dueaaa+jSpQutWrXiiiuuKPSHlIgUXno6vPYaXHop1D3hTBvhifnaNcuWLaN58+YBRRR96enplC9fHjNjwoQJjB8/nsnRGjBbRIn+OxApbqNHwy23wOzZcO654b/e8WrXxFeffAJauHAhd999N845qlSpwujRo4MOSUSiyDn4+9+hZUsojoFsSvIBO+ecc1i8eHHQYYhIMZk7148Ef+klKI7SUOqTFxEpRi+8ACefDNddVzzHU5IXESkmP/4I//wn3HQT5HNxe8QpyYuIFJMXXvClDO69t/iOqSQvIlIM9u6Fl1/2wyabNCm+4yrJF4NGjRqxffv2oMMQkQCNGePLCv/hD8V7XCX5QnDOHSkjoDhEpKAyM+HZZ6FzZ+jSpXiPrSR/AmvXrqV58+b079+f9u3bs2HDhnzL8/bp04czzjiD008/nVGjRp3wtadOnUr79u1p06YN3bp1A3wZ4KeeeurIPi1btmTt2rW/iOOxxx5j0KBBR/YbM2YM99xzDwBvvfUWnTp1om3bttxxxx1kZmZG6schIkWQmupr1RR3Kx4iNE7ezO4DbgUc8C1wE1ABeAdoBKwFfuuc+ymc4wRQaRiAFStW8PrrrzNy5MjjlucdPXo01apVY//+/XTs2JHLL7+c6tWr5/maaWlp3HbbbXz66ac0bty4QOWDc8eRlpZGly5dePLJJwF45513GDp0KMuWLeOdd95hzpw5JCcn079/f8aNG8cNN9xQyJ+MiETKM89Ao0bQp0/xHzvsJG9mdYF7gRbOuf1mNhHoC7QAZjrnHjezB4AHgMHhHi8IDRs2pHPnzgBHlecF2Lt3LytXrqRr166MGDGC9957D4ANGzawcuXKfJP8vHnz6Nq165FywAUpH5w7jho1atCkSRPmzZtH06ZNWbFiBWeddRYvvvgiCxcupGPHjgDs37+fmjVrhvcDEJEimz8fPv/cNyZLB3D5aaQOWRoob2aH8S34TcAQ4LzQ42OB2YSZ5IOqNJy78Fd+5Xlnz57NjBkzmDt3LhUqVOC8887jwIED+b5mQcoHA0e9xrEFyK666iomTpzIaaedxqWXXoqZ4ZyjX79+/O1vfyv0+xSRyHvmGahcGW6+OZjjh90n75z7EXgKWA9sBnY756YBtZxzm0P7bAbybE6a2e1mtsDMFqSlpYUbTtTlV5539+7dVK1alQoVKrB8+XLmzZt33Nfp0qULn3zyCWvWrAGOLh+8aNEiABYtWnTk8bxcdtllvP/++4wfP56rrroK8GWJJ02adKRk8M6dO1m3bl14b1pEimTdOpg0CW6/HU46KZgYItFdUxXoDTQGdgH/NLMCX7DrnBsFjAJfhTLceKKte/fuLFu2jC6hU+SVKlXirbfe4uKLL+bll1+mdevWNGvW7Ei3Sn5q1KjBqFGjuOyyy8jKyqJmzZpMnz6dyy+/nDfeeIO2bdvSsWNHTj311Hxfo2rVqrRo0YKlS5fSqVMnAFq0aMFf/vIXunfvTlZWFsnJybz44os0bNgwcj8EESmQESN8fZrQmIhAhF1q2MyuBC52zt0Sun8D0BnoBpznnNtsZrWB2c65Zsd7rZJYajge6HcgUni7dkGDBtCzJ7z9dnSPdbxSw5EYQrke6GxmFcx3MncDlgGpQL/QPv2A2CqSLiISRSNHwp49MDjg4SZhd9c45740s0nAIiAD+Brf/VIJmGhmt+A/CK4M91giIvEgPd1f/NSjB7RpE2wsERld45x7GHj4mM0H8a36SLx+niNRJPpiaeYwkXjx6quwfTv86U9BRxIHV7yWK1eOHTt2KNkEwDnHjh07KFeuXNChiMSNQ4dg+HA/69NZZwUdTRzMDFWvXj02btxIPAyvTETlypWjXr16QYchEjfeegs2bvSt+VgQ80k+OTn5yFWhIiKxLDMTHn8c2reH7t2DjsaL+SQvIhIv3n0XVq70F0DFymnEmO+TFxGJB87BX/8Kp53mJwaJFWrJi4hEwJQpsHgxvP46lIqh5nMMhSIiEp+yW/ENGsC11wYdzdHUkhcRCdOsWTBnDvz975CcHHQ0R1NLXkQkDM7BQw9B3bpw661BR/NLasmLiIRh2jT44gtfqyYWrxtUS15EpIiyW/ENGgQ3KciJqCUvIlJEH37op/d75RUoWzboaPKmlryISBFkt+IbN4Z+/U68f1DUkhcRKYLUVFi0yI+Lj7URNbmpJS8iUkhZWfDww/CrX8F1BZ7sNBhqyYuIFNJ77/mrW998E0rHeBZVS15EpBCyW/HNmsHVVwcdzYnF+GeQiEhsGTcOliyB8eMhKSnoaE5MLXkRkQI6cAAefBDOOAN++9ugoykYteRFRAroxRdh/XoYPTq2Kk0eT5yEKSISrJ9+gmHD4KKLoFu3oKMpuIgkeTOrYmaTzGy5mS0zsy5mVs3MppvZytC6aiSOJSIShMcfh1274Ikngo6kcCLVkn8emOqcOw1oAywDHgBmOueaAjND90VE4s6GDfD883D99dCmTdDRFE7YSd7MTga6Aq8BOOcOOed2Ab2BsaHdxgJ9wj2WiEgQHnrIrx97LNg4iiISLfkmQBrwupl9bWavmllFoJZzbjNAaF0zryeb2e1mtsDMFqSlpUUgHBGRyPn2Wxg7Fu65x1ebjDeRSPKlgfbAS865dsA+CtE145wb5Zzr4JzrUKNGjQiEIyISOQ88AJUrw5AhQUdSNJFI8huBjc65L0P3J+GT/lYzqw0QWm+LwLFERIrNzJm+nPCf/gTVqgUdTdGEneSdc1uADWbWLLSpG7AUSAWyC3D2AyaHeywRkeKSkQEDBkCTJr6rJl5F6mKoe4BxZlYGWA3chP8AmWhmtwDrgSsjdCwRkah76SVfvuC992JzWr+CikiSd859A3TI46E4umRARMTbvt2PqLngAujdO+howqMrXkVEjvHgg7Bnjx8bbxZ0NOFRkhcRyeWbb2DUKLj7bmjRIuhowqckLyIS4hzcey9Urw6PPBJ0NJGhKpQiIiETJ8Jnn8E//gFVqgQdTWSoJS8iAqSnw/33Q7t2cMstQUcTOWrJi4jg69Js2OBnfoqHGZ8KSi15ESnxvv0WnnoKbroJzjkn6GgiS0leREq0rCy44w7fBz98eNDRRJ66a0SkRBs1CubO9ZUmq1cPOprIU0teREqszZt9lclf/9pPCJKIlORFpMS67z44cMDXqYn3K1vzoyQvIiXSlCnwzjswdCicemrQ0USPkryIlDjp6dC/P5x2GgwaFHQ00aUTryJS4gwdCmvXwiefQNmyQUcTXWrJi0iJ8umnvrrkXXdB165BRxN9SvIiUmLs3esveGrcGJ54Iuhoioe6a0SkxBg8GNas8d00FSsGHU3xUEteREqEmTNh5EgYODDxShccj5K8iCS8n3+Gm2+GZs1g2LCgoyle6q4RkYT3hz/Axo0wZw6ULx90NMVLLXkRSWgffgivvuprxXfuHHQ0xS9iSd7MkszsazP7T+h+NTObbmYrQ+uqkTqWiEhBbN4MN94IrVrBo48GHU0wItmSHwAsy3X/AWCmc64pMDN0X0SkWGRlwQ03+GGT77yT+Bc95SciSd7M6gGXAK/m2twbGBu6PRboE4ljiYgUxPDhMGMGjBgBzZsHHU1wItWSfw4YBGTl2lbLObcZILSumdcTzex2M1tgZgvS0tIiFI6IlGRffgkPPghXXplY87UWRdhJ3sx6AtuccwuL8nzn3CjnXAfnXIcaNWqEG46IlHC7d8PVV0Pdun5CkEQtIVxQkRhCeRbQy8x6AOWAk83sLWCrmdV2zm02s9rAtggcS0QkX87BnXfC+vW+Rk2VKkFHFLywW/LOuSHOuXrOuUZAX2CWc+46IBXoF9qtHzA53GOJiBzPa6/BhAl+JM2ZZwYdTWyI5jj5x4ELzWwlcGHovohIVHz1Fdx9N1xwgZ/ST7yIXvHqnJsNzA7d3gF0i+Tri4jkZds2uPxyOOUUGD8ekpKCjih2qKyBiMS1jAzo2xfS0nzZgpSUoCOKLUryIhLXhgyBjz+GMWOgffugo4k9ql0jInFr4kR46ik/y1O/fifevyRSkheRuPTdd7588FlnwTPPBB1N7FKSF5G4s3Ur/OY3cNJJ8M9/QpkyQUcUu9QnLyJxJT0devXyI2o++QRq1w46otimJC8icSMrC66/3o+Jf+896NAh6Ihin5K8iMSNBx6Af/0Lnn0WevcOOpr4oD55EYkL//iHLx98110wYEDQ0cQPJXkRiXlTp/rk3qMHPPecKksWhpK8iMS0L77wJQtatvTFx0qrk7lQlORFJGZ9841vvdepAx995IdMSuEoyYtITPr+e+jeHU4+2U/jV6tW0BHFJyV5EYk569f7ksEA06dDw4bBxhPP1LslIjFl61a48EL4+WdfeKxZs6Ajim9K8iISM7Zu9S34DRt8C75du6Ajin9K8iISEzZtgm7dYN06+Pe/feExCZ+SvIgEbsMG+PWvYcsWPya+a9egI0ocSvIiEqg1a3yC37kTpk2DLl2CjiixKMmLSGBWrfIJfu9emDlTBceiIewhlGZW38w+NrNlZrbEzAaEtlczs+lmtjK0rhp+uCKSKBYuhLPPhv37/SgaJfjoiMQ4+QzgD8655kBn4C4zawE8AMx0zjUFZobui4jw4Ydw7rlQrhx89hm0aRN0RIkr7CTvnNvsnFsUur0HWAbUBXoDY0O7jQX6hHssEYl/r77qJ/1o1gzmzYPTTgs6osQW0StezawR0A74EqjlnNsM/oMAqBnJY4lIfHEOHnoIbrvNX+w0ezacckrQUSW+iCV5M6sEvAsMdM79XIjn3W5mC8xsQVpaWqTCEZEYcuAA3HgjPPYY3HILpKaq2FhxiUiSN7NkfIIf55z7V2jzVjOrHXq8NrAtr+c650Y55zo45zrUqFEjEuGISAzZsAHOOQfeeAP+/Gd45RVITg46qpIjEqNrDHgNWOaceybXQ6lAv9DtfsDkcI8lIvHlk0/gjDNgxQp4/334f/9PE34Ut0i05M8Crgd+bWbfhJYewOPAhWa2ErgwdF9ESgDn4PnnfZmC6tVh/nzNyRqUsC+Gcs59DuT32dwt3NcXkfiyZw/87ncwbhz06QNjx/qa8BIM1ZMXkYj58ktfOXL8eH+S9d13leCDpiQvImHLzIRhw3zlyMOHfV/8gw9CKWWYwKl2jYiEZcMGuO46+PRT6NsXXnoJqlQJOirJps9ZESmSrCw/HLJVK1i0yPe9v/22EnysUZIXkUJbsQLOPx9uv933wX/zDdxwg4ZHxiIleREpsEOHfN97mzbw3//6OjSzZsH//E/QkUl+1CcvIgUyaxYMGADffQdXXgkjRqj2TDxQS15Ejuv7733VyG7d/OQekyfDxIlK8PFCSV5E8rRzJwwcCKef7itG/u1vsGyZT/gSP9RdIyJH+fln+Pvf4emnYfduuPVWX1isVq2gI5OiUJIXEcAn9BEj4Nln4aefoGdP+Otf/RBJiV9K8iIl3PbtMHKkT+67dvnumIce8tUjJf4pyYuUUN9+6ytFjhvnJ/Xo3dsn9/btg45MIklJXqQEyciADz7wyf3jj6F8eX8R0733+hOskniU5EVKgCVLfNmBN9+ELVugfn14/HE/32q1akFHJ9GkJC+SoLZsgUmTfHJfsABKl4YePfxcq7/5jb8viU+/ZpEEsno1vPeeX774ws/Q1LYtPPccXH011KwZdIRS3JTkReLYgQM+mc+YAR9+CIsX++1t2sAjj8Bll0HLloGGKAFTkheJIwcPwtdf+9rtM2bAZ5/5RJ+UBF26wFNPwaWXQpMmQUcqsUJJXiRGZWbCDz/4Wu3z5vnl6699JUjwLfQ774QLLoCuXeGkk4KNVwpp717YvBk2bfLrWrV8/eYIU5IXCdihQ7B2rU/oy5b58evffgtLl8L+/X6fChWgY0e47z7o3NkvKhAWY7Ky/KXCaWlHL1u3/nLZvNkn+dwuvzw+k7yZXQw8DyQBrzrnHo/2MUVigXOQng47dvj/602bcpYff4R163xi37DB54dsp5ziSwn87nd+3aaNX2s0TDHIyPDFe3bvzlnv3u2T908/+UuCs2/v3Ol/udnLTz8d/YvMrWpV31KvVcvPstKjB9SpA7Vr56zr1o3KW4rqn42ZJQEvAhcCG4GvzCzVObc0mseV4rN/f07jZMcO/z+Qe9mzxye69HS/b3q670M+fNj/Px0+7JfMTP96zh39+klJfjLopKScJTnZJ7zc6+RkKFPGL7nvH3s79/Oyl1KlchYzv3bO/7/mXh86lLMcPOjX6em+QbZ3L+zb59e7dvmfxc6dfr9jlSrlE3mDBnD22X7Cjezl1FOhRo2o/sriU2am/2Hmtezf7/+oDhzIuZ39R5d7yf4F7duXc3vPnqOX7K9Ox3PyyX6Ow+rV/UUG9evn3K5RI2dJSfHrmjX9H2BAot026ASscs6tBjCzCUBvQEk+DjgHGzf6YXlr1x69bNrkx2H//HP+zy9fHipVgooV/e0KFfxSsWJO8s1OuklJOVPHZa+zE2xm5tFL9odDRob/H9+7198/dChnnX07e8m+HylmjrJlHBXKZlKpfCaVymVQqVwGFctm0LTSITqfcpBqFQ5SreIBqlc4QM1K6dStvJc6J++lZsV9JJHp31z2sjMLtmfBvDw+XU60ZP+wjr1/7Pp4+xZkyR1Pdty5b2dm/nKd+5d27O2MjKM/6Y/95M/9qZpfC7mgypb1f3jZf5DZ6/r1/cmM3Evlyj6RV66cs1Sp4lvjlSvH3VeqaEdbF9iQ6/5G4H+jfEwpgs2bYeFCf2Xk0qW+b3jZsqO7Dc38N8qGDf3Y61NO8d8+s9cpKf5/IXspWzYKgWZlHd0C+/lnv85umeVuUh/TknP70snaf5CM/YfJSD9Exv7DHD6QiTtwkKxDGWQdPOyXQ4cpRRalyMJwR9ZlOHRkSXJZcBC/HOeDrliZHf1JeeynZl6P53U/ryX7a0727dxfe8yO/sqV/Xj2V6/SpY++nZzsP/Wztx/7iZ+9lC2b8/Useylb9pdL+fJQrlzOuly5nBZFhQp+e1JS8f4uYki0k3xe0/oe9YXczG4Hbgdo0KBBlMMR8Dlx7lyYP99fCblgge8jzlanDrRoATffDM2bw69+BY0a+UZPRBJ3Zqbvv9y+Pac/c+fOnL7O3Et2n+ju3Tn9PwWVnHzUP7uVL09ShQoklStH2ZTyUL5aTlLIThi5b+dOLsf2/+TVZ5SdtHIntdzJLXcSzP7qkjsp5pVEC7KIHEe0k/xGoH6u+/WATbl3cM6NAkYBdOjQ4ZgeWYmEPXtgzhw/u8/s2T6pZ/eBN2sG550HHTr4pVUr/4200Pbt818Htm71/TjZ623bckYZZN/+6adfdr5nM8v5apz9laBp06O/Op988i+/Yleq9Muv4snJRXgjIokl2kn+K6CpmTUGfgT6AtdE+ZgCrFoF//63Xz77zHd1JidDp07wwANw7rn+9gkTelaWT8wbN/phINnr3ENFNm3Ku3PezPfh1KzpT0C1bu1vp6T4E1W5l2rV/HLyyb4VKyIREdUk75zLMLO7gY/wQyhHO+eWRPOYJZVz/qKZd96B1FRYscJvb9ECfv97uPBCf0VkxYrHPDEry7e4V6/OOcO6bl3Osn59ztU32cqU8X06der4K3K6d88ZBpa7oz4lpUT3hYrEgqifJnbOfQh8GO3jlFQ//OAnfXj7bZ/Yk5N9K71/fz99W5Mm+L6ZdetgzipYudIvq1b5pL5mjR9yltspp/izq+3b+2vkGzTwHfL16vl1Sopa2yJxIr7GAgngB4uMHw+vvAJffum3nXsu/OGuA1x+2hKqbV7iM/79y2H5cp/Uc48frFjRn01t3hwuucR/EjRu7JeGDf3JRxFJCErycWTFCnjpJRgzxrF7t9Gy7k882XUOfUtNpP7qT+De9Tk7ly7tE/lpp/ni4aee6u83bepb6hqVIVIiKMnHOLdpM1Ne2cgzb6Yw84fGJHOIK5nE7xjJWT/OwbaX9Yn87LP9/G0tWvgWepMmGl0iIkryMcM5P1j9q69g4UIyF37DpLl1+dvu37GYjtRnPX+tOpybO31HrU4Noc190Gq0T+ZxdgWeiBQfZYeg7Nrlr0b68kuf2L/6CrZs4RDJvGn9eCL5BVYeasRpNXcy5pYVXPOH2iRXvz/oqEUkzijJFwfn/AnQOXN8UfC5c33NAOd833izZmRd0J2JSVczdPp5rN5UjjNawbt/gj59qlGqlGZaFpGiUZKPhkOH/KD1zz6Dzz/3yX3HDv9YtWq+GPjVV/t1x47MWliZQYN87ZjWreGDV+D//k/nRkUkfErykXDokK8VkF03YM4cP84R/GiWXr3gnHPgrLP8/VD2XrIE/tgXpk71w8/HjoVrr9X1QyISOUryRZGV5WdMnjHDL59/npPUW7eGW2/187Gdfba/+vMY+/bBn/8Mzzzjy6wMHw53363h6SISeUryBbVxo29yT58Os2b5CoqQU67x/PN9Yk9JOe7LTJ4M997rqwXcdBM8+eQJnyIiUmRK8vk5eNC30KdO9ct33/ntder4qbsuuAC6dfP3C2D9erjnHl9X5vTTfXf92WdHMX4REZTkj7Z1K3zwAfznPzBtmu9XKVPGt9BvvBEuvti33AtxRtQ539d+772+hMyTT8LAgbpOSUSKR8lO8s75s5/vv+9r8s6f77fXrw833OBb7Oefn0fpxoLZtg3uuMO//Dnn+GTfuHHEohcROaGSl+SzsvxY9ffe89l31Sq//X//F/7yF1+6sXXrsMcvpqbCbbf5a56eesq33jVqRkSKW8lI8pmZvhP8n/+Ef/3L109PToZf/xr++Ec/xLF27Ygcav9+GDDAV4hs2xZmzvQl10VEgpC4ST47sU+c6BP71q1+Qt8ePeDyy/26SPPc5W/1arjiCvj6az/70qOP+i59EZGgJFaSd85fNvr2236KpE2bfGK/5BK48kq/LmL/+omkpvpu/FKl/HnbSy6JymFERAolMZL8pk3wj3/4mTRWrvRdMT16+NIBPXtGLbGDnzv1wQfhiSfgjDNg0iRo1ChqhxMRKZTESPLbt8Njj/mRMIMHw2WXQdWqUT/sTz/57plZs/womuee01WrIhJbEiPJt2oFmzfnWUIgWn74wXfJrF4Nr7/uh9GLiMSaxEjyZsWa4D//HPr08acAZszw10qJiMSiUuE82cyGm9lyM/uvmb1nZlVyPTbEzFaZ2QozuyjsSGPEuHG+mkH16n64vRK8iMSysJI8MB1o6ZxrDXwPDAEwsxZAX+B04GJgpJnF9aVAzvnKkdddB2ee6ef9aNo06KhERI4vrCTvnJvmnMsI3Z0H1Avd7g1McM4ddM6tAVYBncI5VpCysvwVqw8/7IdJfvSRn/tDRCTWhduSz+1mYErodl1gQ67HNoa2/YKZ3W5mC8xsQVpaWgTDiYyMDLjlFhgxAu67D8aM0QVOIhI/TpjkzWyGmX2Xx9I71z5DgQxgXPamPF7K5fX6zrlRzrkOzrkONWrUKMp7iJqDB6FvX5/YH3kEnn5aU/KJSHw54ega59wFx3vczPoBPYFuzrnsRL4RqJ9rt3rApqIGGYT0dD/c/qOP4NlnfXeNiEi8CXd0zcXAYKCXcy4910OpQF8zK2tmjYGmwPxwjlWc9u3zE2lPnw6vvaYELyLxK9xx8i8AZYHp5vsx5jnn7nTOLTGzicBSfDfOXc65zDCPVSz27/dFKT//3JfAueqqoCMSESm6sJK8c+5Xx3lsGDAsnNcvbgcP+i6ajz+GN99UgheR+JcYV7xGwOHD8Nvf+ulcX30Vrr026IhERMIXySGUcSsjwyf11FR48UU/ZFJEJBGU+CSflQU33+wnjXr6aejfP+iIREQip8Qn+cGDff/7Y4/B738fdDQiIpFVopP8iBF+ku277oKhQ4OORkQk8kpskn/3XT/+vU8feP55XckqIompRCb5OXP8idbOnf1Y+KS4ro8pIpK/Epfkly+H3/wGGjb0o2nKlw86IhGR6ClRST4tzZcrSE6GKVMgJSXoiEREoqvEXAx1+LCfdHvLFvj0U2jSJOiIRESir8Qk+YEDfXIfNw46dgw6GhGR4lEiumtGjYKRI2HQILjmmqCjEREpPgmf5D/7zI+Dv/hi+Otfg45GRKR4JXSSX78eLr8cGjeG8eM1VFJESp6ETfL798Oll8KBAzB5MlSpEnREIiLFL2FPvA4cCIsW+bHwzZsHHY2ISDASsiU/frw/2TpokL/wSUSkpEq4JP/993D77XDmmfCXvwQdjYhIsBIqyR844Gd3KlMGJkzwV7aKiJRkCdUnf999sHgx/Oc/UL9+0NGIiAQvYVry77wDL78M998Pl1wSdDQiIrEhIknezP5oZs7MUnJtG2Jmq8xshZldFInj5GflSrjtNujSBYYNi+aRRETiS9jdNWZWH7gQWJ9rWwugL3A6UAeYYWanOucywz1eXkqX9gn+lVfUDy8iklskWvLPAoMAl2tbb2CCc+6gc24NsAroFIFj5alxY/joI2jQIFpHEBGJT2EleTPrBfzonFt8zEN1gQ257m8MbcvrNW43swVmtiAtLS2ccERE5Bgn7K4xsxnAKXk8NBT4E9A9r6flsc3lsQ3n3ChgFECHDh3y3EdERIrmhEneOXdBXtvNrBXQGFhsfhbsesAiM+uEb7nnHsRYD9gUdrQiIlIoRe6ucc5965yr6Zxr5JxrhE/s7Z1zW4BUoK+ZlTWzxkBTYH5EIhYRkQKLysVQzrklZjYRWApkAHdFa2SNiIjkL2JJPtSaz31/GKBR6yIiAUqYK15FROSXlORFRBKYORc7oxbNLA1YF8ZLpADbIxROkBLlfYDeSyxKlPcBei/ZGjrnauT1QEwl+XCZ2QLnXIeg4whXorwP0HuJRYnyPkDvpSDUXSMiksCU5EVEEliiJflRQQcQIYnyPkDvJRYlyvsAvZcTSqg+eREROVqiteRFRCSXhEryZvaYmf3XzL4xs2lmVifomIrKzIab2fLQ+3nPzKoEHVNRmdmVZrbEzLLMLO5GQpjZxaEZzlaZ2QNBx1NUZjbazLaZ2XdBxxIuM6tvZh+b2bLQ39aAoGMqCjMrZ2bzzWxx6H08GvFjJFJ3jZmd7Jz7OXT7XqCFc+7OgMMqEjPrDsxyzmWY2RMAzrnBAYdVJGbWHMgC/gH80Tm3IOCQCszMkoDv8bOfbQS+Aq52zi0NNLAiMLOuwF7gDedcy6DjCYeZ1QZqO+cWmdlJwEKgT7z9XsyX8K3onNtrZsnA58AA59y8SB0joVry2Qk+pCL51LCPB865ac65jNDdefhyzXHJObfMObci6DiKqBOwyjm32jl3CJiAn/ks7jjnPgV2Bh1HJDjnNjvnFoVu7wGWkc/ERLHMeXtDd5NDS0TzVkIleQAzG2ZmG4BrgYeCjidCbgamBB1ECVXgWc4kGGbWCGgHfBlwKEViZklm9g2wDZjunIvo+4i7JG9mM8zsuzyW3gDOuaHOufrAOODuYKM9vhO9l9A+Q/HlmscFF+mJFeS9xKkCz3Imxc/MKgHvAgOP+SYfN5xzmc65tvhv653MLKJdaVGpJx9N+c1UlYe3gQ+Ah6MYTlhO9F7MrB/QE+jmYvzkSSF+L/FGs5zFqFAf9rvAOOfcv4KOJ1zOuV1mNhu4GIjYyfG4a8kfj5k1zXW3F7A8qFjCZWYXA4OBXs659KDjKcG+ApqaWWMzKwP0xc98JgEKnbB8DVjmnHsm6HiKysxqZI+cM7PywAVEOG8l2uiad4Fm+JEc64A7nXM/BhtV0ZjZKqAssCO0aV4cjxS6FPg7UAPYBXzjnLso0KAKwcx6AM8BScDo0IQ4ccfMxgPn4asdbgUeds69FmhQRWRmZwOfAd/i/98B/uSc+zC4qArPzFoDY/F/W6WAic65P0f0GImU5EVE5GgJ1V0jIiJHU5IXEUlgSvIiIglMSV5EJIEpyYuIJDAleRGRBKYkLyKSwJTkRUQS2P8H6dIMFIfK5rgAAAAASUVORK5CYII=\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# 画出更新一次之后的模型\n",
- "y_pred = multi_linear(x_train)\n",
- "\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "因为只更新了一次,所以两条曲线之间的差异仍然存在,下面进行 100 次迭代"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "epoch 20, Loss: 65.56586\n",
- "epoch 40, Loss: 15.41177\n",
- "epoch 60, Loss: 3.70702\n",
- "epoch 80, Loss: 0.97122\n",
- "epoch 100, Loss: 0.32874\n"
- ]
- }
- ],
- "source": [
- "# 进行 100 次参数更新\n",
- "for e in range(100):\n",
- " y_pred = multi_linear(x_train)\n",
- " loss = get_loss(y_pred, y_train)\n",
- " \n",
- " w.grad.data.zero_()\n",
- " b.grad.data.zero_()\n",
- " loss.backward()\n",
- " \n",
- " # 更新参数\n",
- " w.data = w.data - 0.001 * w.grad.data\n",
- " b.data = b.data - 0.001 * b.grad.data\n",
- " if (e + 1) % 20 == 0:\n",
- " print('epoch {}, Loss: {:.5f}'.format(e+1, loss.data.item()))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "可以看到更新完成之后 loss 已经非常小了,我们画出更新之后的曲线对比"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac305e80>"
- ]
- },
- "execution_count": 27,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD7CAYAAACPDORaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqjElEQVR4nO3dd3gU5drH8e+dRghFSgLSAhEB6cUQwYIogg0RKYpYEEVsKOpRFDmiR+UIBysqKCICihRBAQu8KFUR1ICoQEB6LyG0UELK3u8fu2DUACG7m9nd3J/r2mt3Z2bnuSfoL0+enXlGVBVjjDGhKczpAowxxviPhbwxxoQwC3ljjAlhFvLGGBPCLOSNMSaEWcgbY0wIy3fIi8hoEdkjIityLSsnIt+IyFrPc9lc6/qLyDoRWSMiV/u6cGOMMWd2Nj35McA1f1v2NDBHVWsBczzvEZF6QDegvuczw0Uk3OtqjTHGnJWI/G6oqgtFpMbfFt8ItPa8HgvMB57yLJ+oqseBjSKyDkgCFp+ujdjYWK1R4+9NGGOMOZ2lS5fuVdW4vNblO+RPoaKq7gRQ1Z0iUsGzvAqwJNd22zzLTqtGjRokJyd7WZIxxhQtIrL5VOv89cWr5LEsz/kTRKS3iCSLSHJqaqqfyjHGmKLJ25DfLSKVADzPezzLtwHVcm1XFdiR1w5UdaSqJqpqYlxcnn9tGGOMKSBvQ34G0MPzugcwPdfybiJSTEQSgFrAT162ZYwx5izle0xeRCbg/pI1VkS2Ac8Bg4HJInIPsAXoCqCqK0VkMrAKyAYeUtWcghSYlZXFtm3byMjIKMjHjZeio6OpWrUqkZGRTpdijCkACaSphhMTE/XvX7xu3LiRUqVKUb58eUTyGuo3/qKqpKWlkZ6eTkJCgtPlGGNOQUSWqmpiXusC/orXjIwMC3iHiAjly5e3v6KMCWIBH/KABbyD7GdvTHALipA3xphQNmwYzJjhn31byOfDsGHDqFu3LrfddhszZsxg8ODBAEybNo1Vq1ad3G7MmDHs2PHnmaK9evX6y3pjjPm7Awegf3+YPv2MmxaIt1e8FgnDhw9n5syZJ7987NChA+AO+fbt21OvXj3AHfINGjSgcuXKAIwaNcqZgnPJyckhPNymDTImUI0ZA0ePwkP15gFX+Hz/1pM/g/vvv58NGzbQoUMHXn/9dcaMGUOfPn344YcfmDFjBk8++SRNmjRhyJAhJCcnc9ttt9GkSROOHTtG69atT07TULJkSQYMGEDjxo1p0aIFu3fvBmD9+vW0aNGC5s2bM3DgQEqWLJlnHePGjaNRo0Y0btyYO+64A4C77rqLKVOmnNzmxGfnz5/PFVdcQffu3WnYsCFPPfUUw4cPP7nd888/z6uvvgrA0KFDad68OY0aNeK5557z/Q/QGHNKLhe881YOF4ctplnKeL+0EVw9+UcfheXLfbvPJk3gjTdOufrdd99l1qxZzJs3j9jYWMaMGQPAxRdfTIcOHWjfvj1dunQBYObMmbzyyiskJv7zTKYjR47QokULBg0aRL9+/Xj//ff597//Td++fenbty+33nor7777bp41rFy5kkGDBrFo0SJiY2PZt2/fGQ/rp59+YsWKFSQkJPDLL7/w6KOP8uCDDwIwefJkZs2axezZs1m7di0//fQTqkqHDh1YuHAhrVq1OuP+jTHemz0b1m0I5wWGwSP9/dKG9eQLSVRUFO3btwfgwgsvZNOmTQAsXryYrl27AtC9e/c8Pzt37ly6dOlCbGwsAOXKlTtje0lJSSeHl5o2bcqePXvYsWMHv/76K2XLliU+Pp7Zs2cze/ZsmjZtSrNmzVi9ejVr16719lCNMfn09lsuKobtoXOrvdCokV/aCK6e/Gl63IEuMjLy5OmI4eHhZGdn5/uzqprnqYwRERG4XK6T22RmZp5cV6JEib9s26VLF6ZMmcKuXbvo1q3byc/079+f++6776yPxxjjnfXr4euZwrM6gqjHHvJbO9aT90KpUqVIT08/5fv8aNGiBVOnTgVg4sSJeW7Tpk0bJk+eTFpaGsDJ4ZoaNWqwdOlSAKZPn05WVtYp2+nWrRsTJ05kypQpJ4eXrr76akaPHs3hw4cB2L59O3v27DnlPowxvjNiBISTw31Vv4YbbvBbOxbyXujWrRtDhw6ladOmrF+/nrvuuov777//5Bev+fHGG2/w2muvkZSUxM6dOznnnHP+sU39+vUZMGAAl19+OY0bN+bxxx8H4N5772XBggUkJSXx448//qP3/vd9pKenU6VKFSpVqgRAu3bt6N69Oy1btqRhw4Z06dLlrH9JGWPO3tGj8MHIbDrpVCo/ejP48Qy4gJ+7JiUlhbp16zpUkf8dPXqU4sWLIyJMnDiRCRMmMN1fJ8wWUKj/GxhT2EaNgnvvhYXF2nLZzslQtuyZP3Qap5u7JrjG5EPQ0qVL6dOnD6pKmTJlGD16tNMlGWP8SBXefiObRrKKS3vW8jrgz8RC3mGXXXYZv/76q9NlGGMKyaJF8OvKCEbyFvLI435vz0LeGGMK0dvDXJSRQ3S/YjcUwjCoffFqjDGFZOtWmDIVeupoSjxeOKcuW8gbY0wheestUJfySI0v4NprC6VNG64xxphCkJ4OI0dk04Wp1PhXZwgrnD629eQLQY0aNdi7d6/TZRhjHPTBB3DwcAT/Kj0KevYstHYt5M+Cqp6cRsDqMMbkV3Y2vPlKJpfyHUmPXgynuXDR1yzkz2DTpk3UrVuXBx98kGbNmrF169ZTTs/bsWNHLrzwQurXr8/IkSPPuO9Zs2bRrFkzGjduTJs2bQD3NMCvvPLKyW0aNGjApk2b/lHHiy++SL9+/U5uN2bMGB5++GEAPv74Y5KSkmjSpAn33XcfOTk5vvpxGGMK4PPPYdP2KB6PfBv69CnUtn0yJi8ijwG9AAV+B3oCMcAkoAawCbhZVfd7044DMw0DsGbNGj788EOGDx9+2ul5R48eTbly5Th27BjNmzenc+fOlC9fPs99pqamcu+997Jw4UISEhLyNX1w7jpSU1Np2bIl//vf/wCYNGkSAwYMICUlhUmTJrFo0SIiIyN58MEHGT9+PHfeeedZ/mSMMb6gCq8OzqQmW+lwdyzExRVq+16HvIhUAR4B6qnqMRGZDHQD6gFzVHWwiDwNPA085W17TqhevTotWrQA+Mv0vACHDx9m7dq1tGrVimHDhvH5558DsHXrVtauXXvKkF+yZAmtWrU6OR1wfqYPzl1HXFwc5513HkuWLKFWrVqsWbOGSy65hHfeeYelS5fSvHlzAI4dO0aFChW8+wEYYwps8WL4cVkUb/M64U88Vujt++rsmgiguIhk4e7B7wD6A60968cC8/Ey5J2aaTj3xF+nmp53/vz5fPvttyxevJiYmBhat25NRkbGKfeZn+mDgb/s4+8TkN1yyy1MnjyZCy64gJtuugkRQVXp0aMHL7/88lkfpzHG914dkkVZOcxdHfbD+ecXevtej8mr6nbgFWALsBM4qKqzgYqqutOzzU4gz+6kiPQWkWQRSU5NTfW2HL871fS8Bw8epGzZssTExLB69WqWLFly2v20bNmSBQsWsHHjRuCv0wcvW7YMgGXLlp1cn5dOnToxbdo0JkyYwC233AK4pyWeMmXKySmD9+3bx+bNm707aGNMgaxfD59/EcEDOpwS/R9xpAZfDNeUBW4EEoADwKcicnt+P6+qI4GR4J6F0tt6/K1du3akpKTQsmVLwH1f1Y8//phrrrmGd999l0aNGlGnTp2TwyqnEhcXx8iRI+nUqRMul4sKFSrwzTff0LlzZ8aNG0eTJk1o3rw5tWvXPuU+ypYtS7169Vi1ahVJSUkA1KtXj5deeol27drhcrmIjIzknXfeoXr16r77IRhj8uXN13OI0BweumgpXDTAkRq8nmpYRLoC16jqPZ73dwItgDZAa1XdKSKVgPmqWud0+yqKUw0HA/s3MObspaVB9SpZdDk+njFfxsH11/utrdNNNeyLUyi3AC1EJEbcg8xtgBRgBtDDs00PILAmSTfGGD8a9qZy5HgkT573WaFNYZAXr4drVPVHEZkCLAOygV9wD7+UBCaLyD24fxF09bYtY4wJBunpMOy1LG7iS+o/26nQpjDIi0/OrlHV54Dn/rb4OO5evS/2n+eZKMb/AunOYcYEi3dHKAeORNG/0li4bYqjtQT8Fa/R0dGkpaVZ2DhAVUlLSyM6OtrpUowJGhkZ7ouf2jKb5s9fD5GRjtYT8LNQVq1alW3bthEMp1eGoujoaKpWrep0GcYEjQ9HK7v3F+OZ2Pehx8dOlxP4IR8ZGXnyqlBjjAlkWVkw5IXjtGQZlw+8HIoVc7qkwA95Y4wJFhMmwObd0bxdZgTS68yTFBYGC3ljjPEBlwteHniURqzl+gFNoHhxp0sCLOSNMcYnpk2D1ZtjmFDqbeSBN5wu5yQLeWOM8ZIq/PffRzifHXR9qmah3hTkTAL+FEpjjAl0X30FS1NK8FTxtwh/+EGny/kL68kbY4wXVGHgk0dJYBc9/hULpUs7XdJfWMgbY4wXpk+HX1bH8GHMq0Q+8V+ny/kHC3ljjCkglwuee+IItdjO7QOqwznnOF3SP1jIG2NMAU2dovy2vgQfl36DiL5DnS4nTxbyxhhTADk58Hy/I9RlC91eqBdQZ9TkZiFvjDEFMHmSsmpzSSaVf5vw+193upxTspA3xpizlJ0Nzz95hIZsoMvLFwbEHDWnYiFvjDFn6ZOPXfyxoySfVXqPsJ5vOl3OaVnIG2PMWcjKgheePkJT1tLx1csgIrBjNLCrM8aYADPqvRzW7y7FFzU+RG4J7F48WMgbY0y+pafD888c5zKSuf7Ndo7euzW/LOSNMSafXh2UwZ70GGY0nYDcMNzpcvIl8H8NGWNMANi1C155TejCp1z0fi8QcbqkfPFJyItIGRGZIiKrRSRFRFqKSDkR+UZE1nqey/qiLWOMccILT6ZzPCuM/3b4ES680Oly8s1XPfk3gVmqegHQGEgBngbmqGotYI7nvTHGBJ01a2Dk+BjuC/+AWm894nQ5Z8XrkBeR0kAr4AMAVc1U1QPAjcBYz2ZjgY7etmWMMU545oH9FNejDHwoDeLjnS7nrPiiJ38ekAp8KCK/iMgoESkBVFTVnQCe5wp5fVhEeotIsogkp6am+qAcY4zxncU/KJ/NK0u/mHeo8OLDTpdz1nwR8hFAM2CEqjYFjnAWQzOqOlJVE1U1MS4uzgflGGOMb6jCk732cS47eXxQ+YC7IUh++CLktwHbVPVHz/spuEN/t4hUAvA87/FBW8YYU2g+m5zNopTyPF/xXUr06el0OQXidcir6i5gq4jU8SxqA6wCZgA9PMt6ANO9bcsYYwrLsWPwrweO0JDfuOe9pICfvuBUfFX1w8B4EYkCNgA9cf8CmSwi9wBbgK4+assYY/zulefS2bz/HOZd9AkRHV52upwC80nIq+pyIDGPVW18sX9jjClMW7bAy69F0TVsCq0/uidoLnzKi13xaowxf9OvZyqa42LoAxuhVi2ny/GKhbwxxuSyYG42k+bG8XTpEVQf8qDT5XgtOL9JMMYYP8jOhkfuOEA8R3hyeELA3rf1bFhP3hhjPN5/LZ3fdsTySsNxxHTv6HQ5PmEhb4wxwL598O+BwuUsoMvELkH9ZWtuFvLGGAM8ddduDh6PZliPpUi9uk6X4zMW8saYIu+7uVmM+qIij5UcRaO37nW6HJ+yL16NMUXa8ePQ+5aDVOcwz39YHUqVcrokn7KevDGmSPvfk6ms3hvL8FaTKNHlWqfL8TkLeWNMkfXHaheD3i7NzVGfc92nwTkB2ZlYyBtjiiRVuP/GnUTrMd4YkgkV8rzlRdCzkDfGFEkfDdvHvD+qMLjOGCr1vdnpcvzGQt4YU+TsTVUe7xdBy7Al9P7ihpA5Jz4vFvLGmCLn4Y5bOJQZzXuPriasVk2ny/ErC3ljTJEyZeQ+Jv5QnYFVRtNwyO1Ol+N3FvLGmCJjz27lgT7hJMpSnv6/K4L2bk9nw0LeGFMkqML912wiPasYY/uvJqJ+nTN/KARYyBtjioRPXtvF58sTePH8cdR78Vanyyk0FvLGmJC3Y2sOfZ6K4eLwJTz+7XUQVnSir+gcqTGmSFKFe9tu5HhOBGOG7CG8elWnSypUPgt5EQkXkV9E5EvP+3Ii8o2IrPU8l/VVW8YYk1+jntvK12vOZ3DjidR6/Aanyyl0vuzJ9wVScr1/GpijqrWAOZ73xhhTaFYlH6HvS7FcFbWQPrM7hPRFT6fik5AXkarA9cCoXItvBMZ6Xo8FOvqiLWOMyY9jR5VubdMoqemM+ySCsAqxTpfkCF/15N8A+gGuXMsqqupOAM9zaM7+Y4wJSE/csJrfD8Qz7rbZVOp8sdPlOMbrkBeR9sAeVV1awM/3FpFkEUlOTU31thxjjOHzYVsZPrcu/6o2mWvGFp3TJfMiqurdDkReBu4AsoFooDTwGdAcaK2qO0WkEjBfVU979UFiYqImJyd7VY8xpmjbsuYYTepnUlM2sGjduURVr+R0SX4nIktVNTGvdV735FW1v6pWVdUaQDdgrqreDswAeng26wFM97YtY4w5nexsuL31VrJywpjw/pEiEfBn4s/z5AcDbUVkLdDW894YY/zm+S4r+G5XbUZ0mMn5d13qdDkBwaez86jqfGC+53Ua0MaX+zfGmFOZ9up6Bk1vwD3nfsntUzs5XU7AsCtejTFBb/WiNO58sgLNo5bz9o9JRWJ2yfyykDfGBLVDaVnc1O4w0XqMqZ+FER1vZ2vnZiFvjAlaLhfclbSKtUerMPmZX6l2fSOnSwo4FvLGmKA15OZkPt/QmKGtv6L1oLZOlxOQLOSNMUFp1hurGTC1Gd0qzuPR2dc7XU7AspA3xgSdFTO3cvPjVWgYtYZRPzVCIu2L1lOxkDfGBJWdK/dxXYdwSnGYL/8vihLx5Z0uKaBZyBtjgsaRtAxuaLGHfdml+eL9XVRrXdPpkgKehbwxJijkZLm4rdHv/HK4FhP7/UKze5o6XVJQsJA3xgSFJ1ouYvqO5rzZcT7th1zmdDlBw0LeGBPwhnVZyBtLL6Nvo3n0+exKp8sJKhbyxpiANq73d/Sd2oqO5y7h1Z9bFclb+HnDQt4YE7A+f2IRd7/fkjblljEhpQnhUeFOlxR0LOSNMQHpmxeX0O3VRJqXWs20VXWILhPtdElByULeGBNwfhiWTMeBDbmg+Ga+/j2ekhVLOF1S0LKQN8YElOVjlnNd3/OpHLWX2cviKFu9tNMlBTULeWNMwFg2Mpk2PeMpFXGMbxfFUPGCsk6XFPQs5I0xAeHHNxbT5r6alIzMYP7CcKonxjldUkiwkDfGOO77/y6k7WP1KRd1mIVLilGzpd34w1cs5I0xjpo7YA5XD7iQytH7WfhLaao3swnHfMlC3hjjmFl9Z3L9fy/mvBK7WbCiPFXqneN0SSHH65AXkWoiMk9EUkRkpYj09SwvJyLfiMhaz7N9g2KMcVNlXIcp3DDsKi4otYN5KZWoWLOk01WFJF/05LOBf6lqXaAF8JCI1AOeBuaoai1gjue9MaaI04zjvNRkCj2+6MLlldcyf0M8sdWKO11WyPI65FV1p6ou87xOB1KAKsCNwFjPZmOBjt62ZYwJbtmp++l93jc8+1tX7mjyG19vqMs5sZFOlxXSfDomLyI1gKbAj0BFVd0J7l8EgH1dbkwRdnjlZjok/M6one0ZcOMKxi5rRFQxm2zM33wW8iJSEpgKPKqqh87ic71FJFlEklNTU31VjjEmgGye8AOtmhxk9pGLee9ff/DStAY2mWQh8UnIi0gk7oAfr6qfeRbvFpFKnvWVgD15fVZVR6pqoqomxsXZxQ/GhBRV5tz/KRd2r80GVw1mvLuT3q/UdrqqIsUXZ9cI8AGQoqqv5Vo1A+jhed0DmO5tW8aY4KHph3mlyce0e68TFUsd4+fkMK67r5rTZRU5ET7YxyXAHcDvIrLcs+wZYDAwWUTuAbYAXX3QljEmCBxZvpZ7Ll/HpEN30KXhaj5cVIeSpWx8xgleh7yqfg+c6l+vjbf7N8YEEVVWvPwFtz5bk1WudgzutY5+Iy+w8XcH+aInb4wx6L79vNN2Gk8su5VzIo8yc2wa7W4/3+myijwLeWOM13Z/toi7b8vg64yeXFdnHaPnJlCxst2qLxDY3DXGmILLyGBm19E06nw+c45fyltPbubLlPMt4AOI9eSNMQWyd9r3PH7nXj5Kv5sGZbczZ1YODZKqO12W+RvryRtjzoqm7WP85SOpe1MdJqRfz4DuG/l5RxUaJMU4XZrJg4W8MSZ/VNk0bAbXVl7O7Qt7U7PyMZb9mM1L4xOIjna6OHMqNlxjjDmjIwuSGXrHb/xvazfCwuDNftt56L/xhNvQe8CzkDfGnJJr81Y+uvVrnlncnh0kcnPieoZOrkF8QhWnSzP5ZMM1xph/OnCAhXeOonlCKnctvo8qlZTvZx9l0s81iU+w7nswsZA3xvzpwAEW3f0B7eKWcflHvUgtHs/4N1NZsq0ql7S1L1aDkQ3XGGNg/36+e+wz/vNxTebk3EOFYgcY+vB2HnypCjGW7UHNQt6YIsy1fiOznviWV76ow7yce6hYbD+vPrqd+1+oQkxMGafLMz5gIW9MUaNK+qxFjHl6NW/91oq13Evl4vt4/aHt9P5PFWJiyjpdofEhC3ljior9+1n56ixGjcxhdOoNHOJSWlTewgvPpNG5d3ki7VarIclC3phQlpND2tT5TBi8mTHLG7NUbyWCLG5O2kzfocVIahXvdIXGzyzkjQk1OTkc+uZHZg5by6R5cXyZcRVZtKFp3Fbe6LGN7v2qEhdnUwAXFRbyxoSCzEz2TPuBGcO38fnic/k28zIyuZiKxfbzcPtN9HiuBo0S7dZ7RZGFvDHBSJXjK9byw8gVfDsrm283JJDsugwX4dSI2UOfqzZyU994WrYpS3i4fZFalFnIGxMMcnI49NNqfpq0kR8XHGPh6gp8l9GcY9QmnGxanLuJZy9fS8e+1WncogIiFZyu2AQIC3ljAo3Lxf5fNrHi6y38vugQy1ZEsWRnPKtcdVHqA1C/1BZ6t1zHVbfG0eqWSpQubWPsJm8W8sY4JPNwJpu/38r6JXtY/+sR1q9TUraVZMXBamzT84DzACgXfoAWVbdyS+JvtOhQgeYdKlGmbDxgZ8aYM/N7yIvINcCbQDgwSlUH+7tNE7iysyEjw/2clfXnc06Oe73qX7cPC4PwcPfjxOvISPcjIsL9HBYgMzCpuo/twH5l39YjpK0/wL4th0nbdozdW46zY7uLHXsi2XEwhu1Hy7IjpyIuagI1ASjOUWqV2EHr2jtp2GAHDVuVpeF11ahSswwiZRw9NhO8/BryIhIOvAO0BbYBP4vIDFVd5c92jX8dPw67dsHu3e7nE6/T0uDAgb8+0tPh6FH349gxJStLfF5PmCiRES6iPI/ICCUyXImKdBEZ7lkXqURGKBHh7nURuR5hooSJizBRBCUMRV2KK8eFK0fRHMWVo2RlwfFMyMwUMrOEzGzhaGYEhzOjOJwdzeGc4rgIBwQo6Xn8qSz7qByZSuWSh7ig+j7iK6+hZr1i1LywDDUvq8y5F5RBxIZdjG/5uyefBKxT1Q0AIjIRuBGwkA9w+/dDSor7sWEDbNr052PHjrw/U7pYBmUij1ImIp0yHKS67qNkzkFKZB8kJusQMTmHiOEoxThOJFlEkH3yOZwcBHc3/sSzIrgII4fwk885hJNFJNlEkEWk+6GRZGVFkpkVRRaRZBJFJlF/rvcsO/G5bCLI8DxnEXmyndyPE6/coe9+jjqxZ8miRFgOUeE5xERlU7JUNiWL51CyhIsSJYQy5cIoXzGCcpWjKVetBOUTShN3QXliEipCWLlC+fcz5gR/h3wVYGuu99uAi3JvICK9gd4A8fE2xljYcnJg9WpIToalS2HlSli1yt07PyE8zEV86YPUKLaTq3Uj1UumUPXwaiqyi3PZRUV2U4E9FMvMghJloXz5Px9lykCpUrkesVCiBBQvDsWKQXS0+xEV9ecYzIlHWBiI/POh+ue4jiq4XH8uc7n+fOS17iQFskAz3e38/XFiTOhEXZGREBPjrtuu/zdBxN8hn9ff5n8ZdVXVkcBIgMTERM1je+ND+/bBd9/BwoXw88+wbBkcOeJeVzI6iwZld3Bt+BrqnvMz9Q7+QF1SiHdtIeKgC6pVg/POg4QEqF4dKl8ElSv/+YiNxe4HZ0xg8XfIbwNyX2ZXFTjFH/vGHw4fhjlzYP589+PXX92d2uioHJrFbeWesktJzPk/mmcspHbGH4Tti4J69eCK+tCgFdR/AGrXdod6sWJOH44x5iz5O+R/BmqJSAKwHegGdPdzm0Xe5s3wxRfw5Zcwbx5kZrpD/ZKK63ihwje03j2R5pk/U2yPQuPG0CEJkvpD8+ZQp471xo0JIX4NeVXNFpE+wP/hPoVytKqu9GebRdWGDTBhAkyaBL//7l5WO3YfD8fN5vod73Nx5vcUSw2DSy6Bh6+FK/4HF15ovXNjQpzfz5NX1a+Br/3dTlG0Zw9MngyffAKLF7uXXVplA6/GTaJ96ofU3rvW3VPveQO0fQ4uushC3Zgixq54DTIuF8ydCyNGwPTp7rNjGlXcxeDYj7l17zDi9+yCK6+EG/rCDTeAnbFkTJFmIR8k9u+HMWPc4b52LZQvkcFjcVPosWswDfasgtatoftA6NwZytqsg8YYNwv5ALdpEwwdCh9+CMeOwcWVNjAw+mW6HPmI6PNqw7/ugm7doGpVp0s1xgQgC/kAtWoVDB4Mn3zivuz+ztiZPHzsGRqnrYauXeHBudCypfviIGOMOQUL+QCzfDn85z8wbRrERGXzSJnxPJ42gKrFImDIg9CzJ8TFOV2mMSZIWMgHiE2b4N//hvHjoUzxDAae8x4PH3yR2EqV4M0hcMst7kv9jTHmLFhqOGzvXhg0CIYPV8I0h6dKvcfT6QMo06Qe9P8Qrr8+cObSNcYEHQt5h2RlwbBh8MILyuHDcFepz/jPwb5UbVoFXv4MrrjCxtuNMV6zkHfAokXwwAPuK1OvLbOEoa5e1K+YDR+8CZ06WbgbY3zGxgEK0d69cM89cOmlcGBDGp/Tka+iO1P/vb7uOX47d7aAN8b4lIV8IVCFceOgTh1l3BgXT0YPY1VGTTo+UQv5Yw307m1fqhpj/MKSxc/27IH77nOfEnlJyd9413UbDZLKw/BFUL++0+UZY0Kc9eT9aPp0aNBA+fqLbIaG9WNBzLU0+Ohp98TuFvDGmEJgPXk/OHQI+vZ1zzXTpPgfzM3pTIPbmsCwFVDO7vFpjCk8FvI+9ssv0KWLsmmjMiBsCANLvE3Ux2+5z5oxxphCZiHvQx98AA89pMRqKgv1Ji656VwY8QtUqOB0acaYIsrG5H3g6FG4+27o1Qsucy1kWcRFXDL2PpgyxQLeGOMo68l7ad066NxZ+e034Vle5LlanxI+dSZccIHTpRljjIW8N+bNg043uQg7ks7XdOPanpXg7SUQE+N0acYYA9hwTYGNHg3t2rqofGQtyeEtuPbDW9wLLeCNMQHEQv4suVzw9NPu6Qmu0Ln8cG5nEn6aBHfd5XRpxhjzD16FvIgMFZHVIvKbiHwuImVyresvIutEZI2IXO11pQHg6FG4+WZlyBC4j3f5KukFzlk6Fxo1cro0Y4zJk7c9+W+ABqraCPgD6A8gIvWAbkB94BpguIiEe9mWo/btgytbu/hsqvIajzHi9h+InDfbzp4xxgQ0r0JeVWerarbn7RLgxN2kbwQmqupxVd0IrAOSvGnLSTt3wuWXZrM8OYvP6Mxjg89Fxo2F6GinSzPGmNPy5dk1dwOTPK+r4A79E7Z5lv2DiPQGegPEx8f7sBzf2LQJrmqdza4tmXwd0YkrJ/a2q1eNMUHjjCEvIt8C5+axaoCqTvdsMwDIBsaf+Fge22te+1fVkcBIgMTExDy3cUpKCrS9IoujqUeYU7wTF301EFq3drosY4zJtzOGvKpedbr1ItIDaA+0UdUTIb0NqJZrs6rAjoIW6YRly+DqNllEHNrHgnNupuG3r0OzZk6XZYwxZ8Xbs2uuAZ4COqjq0VyrZgDdRKSYiCQAtYCfvGmrMC1dCle2yqbEwR18d+7NNFzyvgW8MSYoeTsm/zZQDPhG3LetW6Kq96vqShGZDKzCPYzzkKrmeNlWofj1V2h3RSZlj+5gwfm9iJ/3CVTJ8+sEY4wJeF6FvKqef5p1g4BB3uy/sK1aBW1bZxKTvoe5dR4g/vuJEBvrdFnGGFNgNneNx9q10Oay44Qf2M+cWveT8N04C3hjTNCzkAc2boQrL8kge186C2r2pvb3oyEuzumyjDHGa0U+5HfuhCsvzuBI6lHmJfSm3qL37SpWY0zIKNIhn54O119xlNRdLuZVv4/GP4yAihWdLssYY3ymyIZ8VhZ0ue4Iv60pxhcVetH8hzfh3Lyu+TLGmOBVJENeFXrddozZ35fgg5J9ufa7Z6ByZafLMsYYnyuS88k/2y+DcZ8W5/nIQdw95zaoXdvpkowxxi+KXMi/93YWg16Jppd8wMAZiZAUtJNjGmPMGRWpkP/m/1w8+HA41/EVI8YUR64JiXuZGGPMKRWZkF+3Dm7pmEE9VjLxpfVE3Nnd6ZKMMcbvikTIHzoEN7Y+gGQcY3q3iZR65mGnSzLGmEIR8mfXuFxwx/X7WLO9NLMb9+O8cUNA8pru3hhjQk/I9+Sfe/QgM74vx+uxg7hyzgCIjHS6JGOMKTQhHfKffnSMl946h7sjP6LPgpuhfHmnSzLGmEIVsiG/8ncXd/UUWvIDw6dUQOrVdbokY4wpdCEZ8keOQNc2aZTKOcDU51dQrIOdKmmMKZpCMuQf6rSD1anlGX/FB1QaeK/T5RhjjGNCLuTHvL6fsbMrMzB2BG1m9LUzaYwxRVpIhfzK5Vk8+ERxrghbwLPzroSSJZ0uyRhjHBUyIX9iHL606wCfvLOf8Ab2Rasxxvgk5EXkCRFREYnNtay/iKwTkTUi4vdvPh9qv5nV+yow/sZPOff+jv5uzhhjgoLXV7yKSDWgLbAl17J6QDegPlAZ+FZEaqtqjrft5eWbMdsZO786z1V+nzaTevujCWOMCUq+6Mm/DvQDNNeyG4GJqnpcVTcC6wC/zel7VatMJjYcxLPfXw3FivmrGWOMCTpe9eRFpAOwXVV/lb+exVIFWJLr/TbPMr+Q8xK45bcB/tq9McYErTOGvIh8C+R189MBwDNAu7w+lscyzWMZItIb6A0QHx9/pnKMMcachTOGvKpelddyEWkIJAAnevFVgWUikoS7514t1+ZVgR2n2P9IYCRAYmJinr8IjDHGFEyBx+RV9XdVraCqNVS1Bu5gb6aqu4AZQDcRKSYiCUAt4CefVGyMMSbf/DKfvKquFJHJwCogG3jIX2fWGGOMOTWfhbynN5/7/SBgkK/2b4wx5uyFzBWvxhhj/slC3hhjQpiFvDHGhDBRDZyzFkUkFdjsxS5igb0+KsdJoXIcYMcSiELlOMCO5YTqqhqX14qACnlviUiyqiY6XYe3QuU4wI4lEIXKcYAdS37YcI0xxoQwC3ljjAlhoRbyI50uwEdC5TjAjiUQhcpxgB3LGYXUmLwxxpi/CrWevDHGmFxCKuRF5EUR+U1ElovIbBGp7HRNBSUiQ0Vkted4PheRMk7XVFAi0lVEVoqIS0SC7kwIEbnGcxvLdSLytNP1FJSIjBaRPSKywulavCUi1URknoikeP7b6ut0TQUhItEi8pOI/Oo5jv/4vI1QGq4RkdKqesjz+hGgnqre73BZBSIi7YC5qpotIkMAVPUph8sqEBGpC7iA94AnVDXZ4ZLyTUTCgT9w3+JyG/AzcKuqrnK0sAIQkVbAYWCcqjZwuh5viEgloJKqLhORUsBSoGOw/buIe572Eqp6WEQige+Bvqq65AwfzbeQ6smfCHiPEpziRiXBQFVnq2q25+0S3HPyByVVTVHVNU7XUUBJwDpV3aCqmcBE3Le3DDqquhDY53QdvqCqO1V1med1OpCCH+8+5y/qdtjzNtLz8GluhVTIA4jIIBHZCtwGDHS6Hh+5G5jpdBFFVBVga673fr2VpTl7IlIDaAr86HApBSIi4SKyHNgDfKOqPj2OoAt5EflWRFbk8bgRQFUHqGo1YDzQx9lqT+9Mx+LZZgDuOfnHO1fpmeXnWIJUvm9laQqfiJQEpgKP/u0v+aChqjmq2gT3X+tJIuLToTS/3DTEn051O8I8fAJ8BTznx3K8cqZjEZEeQHugjQb4lydn8e8SbPJ9K0tTuDxj2FOB8ar6mdP1eEtVD4jIfOAawGdfjgddT/50RKRWrrcdgNVO1eItEbkGeArooKpHna6nCPsZqCUiCSISBXTDfXtL4yDPF5YfACmq+prT9RSUiMSdOHNORIoDV+Hj3Aq1s2umAnVwn8mxGbhfVbc7W1XBiMg6oBiQ5lm0JIjPFLoJeAuIAw4Ay1X1akeLOgsich3wBhAOjPbc9SzoiMgEoDXu2Q53A8+p6geOFlVAInIp8B3wO+7/3wGeUdWvnavq7IlII2As7v+2woDJqvqCT9sIpZA3xhjzVyE1XGOMMeavLOSNMSaEWcgbY0wIs5A3xpgQZiFvjDEhzELeGGNCmIW8McaEMAt5Y4wJYf8PReMFzmKrczQAAAAASUVORK5CYII=\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# 画出更新之后的结果\n",
- "y_pred = multi_linear(x_train)\n",
- "\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "可以看到,经过 100 次更新之后,可以看到拟合的线和真实的线已经完全重合了"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": true
- },
- "source": [
- "## 练习题\n",
- "\n",
- "* 上面的例子是一个三次的多项式,尝试使用二次的多项式去拟合它,看看最后能做到多好\n"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.4"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|