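{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Least Squares\n", "\n", "## 1. Basic Principle of Least Squares\n", "\n", "The least squares method is a mathematical optimization technique that finds the best-fitting function for a set of data by minimizing the sum of squared errors. It is commonly used for curve fitting and model estimation. Many other optimization problems can also be cast in least squares form by minimizing an energy or maximizing an entropy.\n", "\n", "The general form of the least squares principle is:\n", "$$\n", "L = \\sum (V_{obv} - V_{target}(\\theta))^2\n", "$$\n", "where $V_{obv}$ denotes the observed sample values, $V_{target}$ is the output of the assumed fitting function, and $\\theta$ denotes the model parameters. $L$ is the objective function: if adjusting $\\theta$ drives $L$ to its minimum, the fitting function is as close as possible to the observations, which means we have found the optimal model.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1 Example\n", "\n", "Suppose we have the observations below and want to discover the underlying pattern." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "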
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import sklearn\n", "from sklearn import datasets\n", "\n", "# load data\n", "d = datasets.load_diabetes()\n", "\n", "X = d.data[:, 2]\n", "Y = d.target\n", "\n", "# draw original data\n", "plt.scatter(X, Y)\n", "plt.xlabel(\"X\")\n", "plt.ylabel(\"Y\")\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.2 Mathematical Derivation\n", "Suppose we have $N$ observations:\n", "$$\n", "\\mathbf{X} = \\{x_1, x_2, ..., x_N \\} \\\\\n", "\\mathbf{Y} = \\{y_1, y_2, ..., y_N \\}\n", "$$\n", "where $\\mathbf{X}$ is the independent variable and $\\mathbf{Y}$ is the dependent variable.\n", "\n", "We want a model that explains these data. Suppose we fit them with the simplest linear model:\n", "$$\n", "y = ax + b\n", "$$\n", "The problem then becomes finding the parameters $a$ and $b$ that make the model output deviate as little as possible from the observations.\n", "\n", "A key question is how to construct a function that measures the error between the model output and the observations. Here we use the sum of squared differences between the observations and the model outputs as the evaluation function (also called the loss function):\n", "$$\n", "L = \\sum_{i=1}^{N} (y_i - a x_i - b)^2 \\\\\n", "L = \\sum_{i=1}^{N} \\{y_i - (a x_i + b)\\}^2\n", "$$\n", "\n", "To minimize the loss, we take its partial derivatives with respect to the parameters:\n", "$$\n", "\\frac{\\partial L}{\\partial a} = -2 \\sum_{i=1}^{N} (y_i - a x_i - b) x_i \\\\\n", "\\frac{\\partial L}{\\partial b} = -2 \\sum_{i=1}^{N} (y_i - a x_i - b)\n", "$$\n", "The loss is at its minimum when both partial derivatives are zero, which gives:\n", "$$\n", "-2 \\sum_{i=1}^{N} (y_i - a x_i - b) x_i = 0 \\\\\n", "-2 \\sum_{i=1}^{N} (y_i - a x_i - b) = 0\n", "$$\n", "\n", "Rearranging these equations yields:\n", "$$\n", "a \\sum x_i^2 + b \\sum x_i = \\sum y_i x_i \\\\\n", "a \\sum x_i + b N = \\sum y_i\n", "$$\n", "Solving this system of two linear equations in two unknowns gives the optimal model parameters." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.3 Solver Code" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a = 949.435260, b = 152.133484\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
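" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "N = X.shape[0]\n", "\n", "S_X2 = np.sum(X*X)\n", "S_X = np.sum(X)\n", "S_XY = np.sum(X*Y)\n", "S_Y = np.sum(Y)\n", "\n", "A1 = np.array([[S_X2, S_X],\n", "               [S_X, N]])\n", "B1 = np.array([S_XY, S_Y])\n", "# numpy.linalg contains linear algebra functions: it can compute matrix inverses, eigenvalues, determinants, and solve linear systems.\n", "coeff = np.linalg.inv(A1).dot(B1)\n", "\n", "print('a = %f, b = %f' % (coeff[0], coeff[1]))\n", "\n", "x_min = np.min(X)\n", "x_max = np.max(X)\n", "y_min = coeff[0] * x_min + coeff[1]\n", "y_max = coeff[0] * x_max + coeff[1]\n", "\n", "plt.scatter(X, Y, label='original data')\n", "plt.plot([x_min, x_max], [y_min, y_max], 'r', label='model')\n", "plt.legend()\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Finding the Model Parameters Iteratively\n", "\n", "When there is a lot of data, or the model is complex, the parameters often cannot be obtained analytically. A more common approach is therefore to approach the parameters step by step through iteration.\n", "\n", "### 2.1 Gradient Descent\n", "In machine learning, many supervised learning models require constructing a loss function for the model and then optimizing it with an optimization algorithm to find the best parameters. Among the optimization algorithms used to fit machine learning parameters, those based on gradient descent (Gradient Descent, GD) are the most widely used.\n", "\n", "Gradient descent has many advantages. The main one is that it only requires the first derivative of the loss function, so the computational cost is relatively low, which makes it applicable to many large-scale datasets.\n", "\n", "The idea of gradient descent is to use the gradient direction at the current point to find the next iterate. It can be compared to walking down a mountain. Imagine the following scenario:\n", "* A person is stuck on a mountain and needs to get down (i.e., find the lowest point of the mountain, the valley).\n", "* Thick fog on the mountain makes visibility very low, so the path down cannot be determined in advance; the person must rely on information from the immediate surroundings to find the way down.\n", "* This is where gradient descent helps.\n", "    - Specifically, taking the current position as the reference, the person finds the steepest direction at that position and walks in the direction in which the height decreases.\n", "    - Likewise, if the goal were to climb to the summit, the person would walk upward along the steepest direction.\n", "    - After each stretch of walking, the same procedure is repeated, and eventually the person reaches the valley.\n", "\n", "\n", "We can also assume that the steepest direction cannot be seen at a glance and must be measured with a sophisticated instrument, which this person happens to have. Each measurement takes time, so to reach the bottom before sunset the person should measure as few times as possible. This is a trade-off: measuring frequently keeps the descent direction accurate but is very time-consuming, while measuring too rarely risks straying off course. The person therefore needs a measurement frequency that keeps the direction correct without costing too much time.\n", "\n", "\n", "![gradient_descent](images/gradient_descent.png)\n", "\n", "As shown in the figure above, the descent reaches a local optimum. The x and y axes represent $\\theta_0$ and $\\theta_1$, and the z axis represents the cost function. Clearly, different starting points may lead to different convergence points. If the surface is bowl-shaped (convex), however, the convergence point is the same regardless of the starting point.\n", "\n", "For a loss function\n", "$$\n", "L = \\sum_{i=1}^{N} (y_i - a x_i - b)^2\n", "$$\n", "\n", "the update rule is:\n", "$$\n", "\\theta^1 = \\theta^0 - \\alpha \\nabla L(\\theta^0)\n", "$$\n", "where $\\theta$ stands for the model parameters, for example $a$ and $b$.\n", "\n", "This formula means: $L$ is a function of $\\theta$, and our current position is the point $\\theta^0$. We want to move from this point to the minimum of $L$, the bottom of the mountain. We first determine the direction of travel, which is the negative gradient, then take a step of size $\\alpha$; after this step we arrive at the point $\\theta^1$.\n", "\n", "Applied to the linear model, with the gradient evaluated on an observation $(x, y)$, the update becomes:\n", "\n", "$$\n", "a^1 = a^0 + 2 \\alpha [ y - (a^0 x + b^0)] x \\\\\n", "b^1 = b^0 + 2 \\alpha [ y - (a^0 x + b^0)]\n", "$$\n", "\n", "Some common questions about this formula:\n", "\n", "* **What does $\\alpha$ mean?**\n", "In gradient descent $\\alpha$ is called the learning rate or step size. It controls how far each step goes, so that we neither overshoot the lowest point with too large a step nor move so slowly that the sun sets before we reach the bottom. Choosing $\\alpha$ is therefore often critical in gradient descent.\n", "![gd_stepsize](images/gd_stepsize.png)\n", "\n", "* **Why is the gradient multiplied by a negative sign?**\n", "The negative sign means moving in the direction opposite to the gradient. The gradient points in the direction in which the function increases fastest at the current point, whereas we want the direction of fastest decrease, which is the negative gradient.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Example Code" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0: loss = 2590736.867664, a = 18.689017, b = 148.815329\n", "epoch 1: loss = 2557255.189453, a = 36.831348, b = 148.828655\n", "epoch 2: loss = 2525130.800853, a = 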
54.614824, b = 148.822534\n", "epoch 3: loss = 2494255.916234, a = 72.046438, b = 148.816532\n", "epoch 4: loss = 2464581.753419, a = 89.133153, b = 148.810649\n", "epoch 5: loss = 2436061.445196, a = 105.881792, b = 148.804882\n", "epoch 6: loss = 2408649.957111, a = 122.299046, b = 148.799229\n", "epoch 7: loss = 2382304.015732, a = 138.391470, b = 148.793687\n", "epoch 8: loss = 2356982.039720, a = 154.165492, b = 148.788256\n", "epoch 9: loss = 2332644.073594, a = 169.627411, b = 148.782932\n", "epoch 10: loss = 2309251.724102, a = 184.783402, b = 148.777713\n", "epoch 11: loss = 2286768.099073, a = 199.639520, b = 148.772598\n", "epoch 12: loss = 2265157.748666, a = 214.201696, b = 148.767584\n", "epoch 13: loss = 2244386.608923, a = 228.475748, b = 148.762669\n", "epoch 14: loss = 2224421.947529, a = 242.467375, b = 148.757851\n", "epoch 15: loss = 2205232.311697, a = 256.182166, b = 148.753129\n", "epoch 16: loss = 2186787.478095, a = 269.625597, b = 148.748500\n", "epoch 17: loss = 2169058.404731, a = 282.803039, b = 148.743962\n", "epoch 18: loss = 2152017.184726, a = 295.719754, b = 148.739515\n", "epoch 19: loss = 2135637.001889, a = 308.380901, b = 148.735155\n", "epoch 20: loss = 2119892.088045, a = 320.791536, b = 148.730882\n", "epoch 21: loss = 2104757.682020, a = 332.956616, b = 148.726693\n", "epoch 22: loss = 2090209.990240, a = 344.881000, b = 148.722587\n", "epoch 23: loss = 2076226.148871, a = 356.569449, b = 148.718562\n", "epoch 24: loss = 2062784.187440, a = 368.026633, b = 148.714617\n", "epoch 25: loss = 2049862.993883, a = 379.257126, b = 148.710750\n", "epoch 26: loss = 2037442.280956, a = 390.265414, b = 148.706960\n", "epoch 27: loss = 2025502.553972, a = 401.055894, b = 148.703244\n", "epoch 28: loss = 2014025.079785, a = 411.632875, b = 148.699602\n", "epoch 29: loss = 2002991.857005, a = 422.000581, b = 148.696032\n", "epoch 30: loss = 1992385.587369, a = 432.163154, b = 148.692533\n", "epoch 31: loss = 1982189.648231, a = 
442.124651, b = 148.689103\n", "epoch 32: loss = 1972388.066142, a = 451.889051, b = 148.685741\n", "epoch 33: loss = 1962965.491447, a = 461.460255, b = 148.682445\n", "epoch 34: loss = 1953907.173892, a = 470.842084, b = 148.679215\n", "epoch 35: loss = 1945198.939172, a = 480.038285, b = 148.676048\n", "epoch 36: loss = 1936827.166411, a = 489.052532, b = 148.672944\n", "epoch 37: loss = 1928778.766513, a = 497.888424, b = 148.669902\n", "epoch 38: loss = 1921041.161364, a = 506.549490, b = 148.666919\n", "epoch 39: loss = 1913602.263848, a = 515.039190, b = 148.663996\n", "epoch 40: loss = 1906450.458648, a = 523.360914, b = 148.661131\n", "epoch 41: loss = 1899574.583794, a = 531.517985, b = 148.658322\n", "epoch 42: loss = 1892963.912935, a = 539.513662, b = 148.655569\n", "epoch 43: loss = 1886608.138306, a = 547.351138, b = 148.652870\n", "epoch 44: loss = 1880497.354362, a = 555.033542, b = 148.650225\n", "epoch 45: loss = 1874622.042047, a = 562.563943, b = 148.647632\n", "epoch 46: loss = 1868973.053689, a = 569.945349, b = 148.645090\n", "epoch 47: loss = 1863541.598476, a = 577.180708, b = 148.642599\n", "epoch 48: loss = 1858319.228507, a = 584.272908, b = 148.640157\n", "epoch 49: loss = 1853297.825385, a = 591.224784, b = 148.637763\n", "epoch 50: loss = 1848469.587337, a = 598.039110, b = 148.635417\n", "epoch 51: loss = 1843827.016840, a = 604.718609, b = 148.633117\n", "epoch 52: loss = 1839362.908723, a = 611.265949, b = 148.630862\n", "epoch 53: loss = 1835070.338746, a = 617.683744, b = 148.628653\n", "epoch 54: loss = 1830942.652617, a = 623.974558, b = 148.626486\n", "epoch 55: loss = 1826973.455442, a = 630.140902, b = 148.624363\n", "epoch 56: loss = 1823156.601589, a = 636.185240, b = 148.622282\n", "epoch 57: loss = 1819486.184948, a = 642.109985, b = 148.620242\n", "epoch 58: loss = 1815956.529568, a = 647.917504, b = 148.618242\n", "epoch 59: loss = 1812562.180670, a = 653.610117, b = 148.616282\n", "epoch 60: loss = 1809297.895998, a 
= 659.190096, b = 148.614361\n", "epoch 61: loss = 1806158.637522, a = 664.659671, b = 148.612477\n", "epoch 62: loss = 1803139.563458, a = 670.021025, b = 148.610631\n", "epoch 63: loss = 1800236.020598, a = 675.276301, b = 148.608822\n", "epoch 64: loss = 1797443.536950, a = 680.427596, b = 148.607048\n", "epoch 65: loss = 1794757.814654, a = 685.476969, b = 148.605309\n", "epoch 66: loss = 1792174.723185, a = 690.426435, b = 148.603605\n", "epoch 67: loss = 1789690.292815, a = 695.277972, b = 148.601934\n", "epoch 68: loss = 1787300.708335, a = 700.033517, b = 148.600297\n", "epoch 69: loss = 1785002.303019, a = 704.694970, b = 148.598692\n", "epoch 70: loss = 1782791.552830, a = 709.264191, b = 148.597119\n", "epoch 71: loss = 1780665.070844, a = 713.743007, b = 148.595576\n", "epoch 72: loss = 1778619.601901, a = 718.133206, b = 148.594065\n", "epoch 73: loss = 1776652.017457, a = 722.436540, b = 148.592583\n", "epoch 74: loss = 1774759.310646, a = 726.654730, b = 148.591130\n", "epoch 75: loss = 1772938.591527, a = 730.789459, b = 148.589707\n", "epoch 76: loss = 1771187.082519, a = 734.842379, b = 148.588311\n", "epoch 77: loss = 1769502.114022, a = 738.815108, b = 148.586943\n", "epoch 78: loss = 1767881.120197, a = 742.709234, b = 148.585602\n", "epoch 79: loss = 1766321.634920, a = 746.526311, b = 148.584288\n", "epoch 80: loss = 1764821.287889, a = 750.267864, b = 148.583000\n", "epoch 81: loss = 1763377.800890, a = 753.935387, b = 148.581737\n", "epoch 82: loss = 1761988.984199, a = 757.530345, b = 148.580499\n", "epoch 83: loss = 1760652.733134, a = 761.054173, b = 148.579286\n", "epoch 84: loss = 1759367.024737, a = 764.508280, b = 148.578096\n", "epoch 85: loss = 1758129.914584, a = 767.894044, b = 148.576931\n", "epoch 86: loss = 1756939.533728, a = 771.212818, b = 148.575788\n", "epoch 87: loss = 1755794.085750, a = 774.465927, b = 148.574668\n", "epoch 88: loss = 1754691.843935, a = 777.654671, b = 148.573570\n", "epoch 89: loss = 1753631.148552, 
a = 780.780322, b = 148.572493\n", "epoch 90: loss = 1752610.404243, a = 783.844130, b = 148.571438\n", "epoch 91: loss = 1751628.077518, a = 786.847318, b = 148.570404\n", "epoch 92: loss = 1750682.694337, a = 789.791085, b = 148.569391\n", "epoch 93: loss = 1749772.837797, a = 792.676607, b = 148.568397\n", "epoch 94: loss = 1748897.145906, a = 795.505037, b = 148.567423\n", "epoch 95: loss = 1748054.309442, a = 798.277504, b = 148.566469\n", "epoch 96: loss = 1747243.069899, a = 800.995115, b = 148.565533\n", "epoch 97: loss = 1746462.217508, a = 803.658956, b = 148.564616\n", "epoch 98: loss = 1745710.589343, a = 806.270090, b = 148.563716\n", "epoch 99: loss = 1744987.067493, a = 808.829561, b = 148.562835\n", "epoch 100: loss = 1744290.577312, a = 811.338391, b = 148.561971\n", "epoch 101: loss = 1743620.085732, a = 813.797581, b = 148.561125\n", "epoch 102: loss = 1742974.599648, a = 816.208114, b = 148.560295\n", "epoch 103: loss = 1742353.164357, a = 818.570953, b = 148.559481\n", "epoch 104: loss = 1741754.862068, a = 820.887041, b = 148.558683\n", "epoch 105: loss = 1741178.810465, a = 823.157303, b = 148.557902\n", "epoch 106: loss = 1740624.161324, a = 825.382646, b = 148.557135\n", "epoch 107: loss = 1740090.099189, a = 827.563959, b = 148.556384\n", "epoch 108: loss = 1739575.840097, a = 829.702112, b = 148.555648\n", "epoch 109: loss = 1739080.630352, a = 831.797961, b = 148.554926\n", "epoch 110: loss = 1738603.745351, a = 833.852341, b = 148.554219\n", "epoch 111: loss = 1738144.488447, a = 835.866073, b = 148.553526\n", "epoch 112: loss = 1737702.189871, a = 837.839962, b = 148.552846\n", "epoch 113: loss = 1737276.205677, a = 839.774796, b = 148.552180\n", "epoch 114: loss = 1736865.916748, a = 841.671348, b = 148.551527\n", "epoch 115: loss = 1736470.727823, a = 843.530375, b = 148.550887\n", "epoch 116: loss = 1736090.066576, a = 845.352619, b = 148.550259\n", "epoch 117: loss = 1735723.382723, a = 847.138809, b = 148.549644\n", "epoch 118: 
loss = 1735370.147167, a = 848.889657, b = 148.549041\n", "epoch 119: loss = 1735029.851171, a = 850.605863, b = 148.548450\n", "epoch 120: loss = 1734702.005574, a = 852.288113, b = 148.547871\n", "epoch 121: loss = 1734386.140028, a = 853.937078, b = 148.547303\n", "epoch 122: loss = 1734081.802266, a = 855.553417, b = 148.546747\n", "epoch 123: loss = 1733788.557403, a = 857.137775, b = 148.546201\n", "epoch 124: loss = 1733505.987262, a = 858.690785, b = 148.545666\n", "epoch 125: loss = 1733233.689723, a = 860.213068, b = 148.545142\n", "epoch 126: loss = 1732971.278103, a = 861.705231, b = 148.544628\n", "epoch 127: loss = 1732718.380559, a = 863.167870, b = 148.544125\n", "epoch 128: loss = 1732474.639507, a = 864.601570, b = 148.543631\n", "epoch 129: loss = 1732239.711074, a = 866.006902, b = 148.543147\n", "epoch 130: loss = 1732013.264567, a = 867.384429, b = 148.542673\n", "epoch 131: loss = 1731794.981959, a = 868.734701, b = 148.542208\n", "epoch 132: loss = 1731584.557401, a = 870.058256, b = 148.541752\n", "epoch 133: loss = 1731381.696752, a = 871.355623, b = 148.541306\n", "epoch 134: loss = 1731186.117121, a = 872.627321, b = 148.540868\n", "epoch 135: loss = 1730997.546436, a = 873.873858, b = 148.540438\n", "epoch 136: loss = 1730815.723022, a = 875.095730, b = 148.540018\n", "epoch 137: loss = 1730640.395202, a = 876.293427, b = 148.539605\n", "epoch 138: loss = 1730471.320908, a = 877.467426, b = 148.539201\n", "epoch 139: loss = 1730308.267311, a = 878.618197, b = 148.538805\n", "epoch 140: loss = 1730151.010464, a = 879.746199, b = 148.538416\n", "epoch 141: loss = 1729999.334955, a = 880.851882, b = 148.538036\n", "epoch 142: loss = 1729853.033583, a = 881.935688, b = 148.537663\n", "epoch 143: loss = 1729711.907037, a = 882.998050, b = 148.537297\n", "epoch 144: loss = 1729575.763593, a = 884.039393, b = 148.536938\n", "epoch 145: loss = 1729444.418820, a = 885.060132, b = 148.536587\n", "epoch 146: loss = 1729317.695300, a = 886.060674, 
b = 148.536242\n", "epoch 147: loss = 1729195.422355, a = 887.041420, b = 148.535904\n", "epoch 148: loss = 1729077.435792, a = 888.002761, b = 148.535573\n", "epoch 149: loss = 1728963.577645, a = 888.945081, b = 148.535249\n", "epoch 150: loss = 1728853.695944, a = 889.868757, b = 148.534931\n", "epoch 151: loss = 1728747.644477, a = 890.774156, b = 148.534619\n", "epoch 152: loss = 1728645.282573, a = 891.661642, b = 148.534314\n", "epoch 153: loss = 1728546.474885, a = 892.531568, b = 148.534014\n", "epoch 154: loss = 1728451.091187, a = 893.384282, b = 148.533720\n", "epoch 155: loss = 1728359.006178, a = 894.220124, b = 148.533433\n", "epoch 156: loss = 1728270.099288, a = 895.039428, b = 148.533150\n", "epoch 157: loss = 1728184.254503, a = 895.842521, b = 148.532874\n", "epoch 158: loss = 1728101.360185, a = 896.629725, b = 148.532603\n", "epoch 159: loss = 1728021.308903, a = 897.401353, b = 148.532337\n", "epoch 160: loss = 1727943.997275, a = 898.157714, b = 148.532077\n", "epoch 161: loss = 1727869.325813, a = 898.899110, b = 148.531821\n", "epoch 162: loss = 1727797.198769, a = 899.625836, b = 148.531571\n", "epoch 163: loss = 1727727.523995, a = 900.338184, b = 148.531326\n", "epoch 164: loss = 1727660.212803, a = 901.036437, b = 148.531086\n", "epoch 165: loss = 1727595.179837, a = 901.720875, b = 148.530850\n", "epoch 166: loss = 1727532.342938, a = 902.391770, b = 148.530619\n", "epoch 167: loss = 1727471.623027, a = 903.049391, b = 148.530392\n", "epoch 168: loss = 1727412.943986, a = 903.694001, b = 148.530170\n", "epoch 169: loss = 1727356.232544, a = 904.325856, b = 148.529953\n", "epoch 170: loss = 1727301.418168, a = 904.945210, b = 148.529740\n", "epoch 171: loss = 1727248.432959, a = 905.552309, b = 148.529531\n", "epoch 172: loss = 1727197.211551, a = 906.147396, b = 148.529326\n", "epoch 173: loss = 1727147.691014, a = 906.730709, b = 148.529125\n", "epoch 174: loss = 1727099.810763, a = 907.302481, b = 148.528928\n", "epoch 175: loss = 
1727053.512464, a = 907.862939, b = 148.528735\n", "epoch 176: loss = 1727008.739955, a = 908.412309, b = 148.528546\n", "epoch 177: loss = 1726965.439156, a = 908.950808, b = 148.528360\n", "epoch 178: loss = 1726923.557995, a = 909.478653, b = 148.528179\n", "epoch 179: loss = 1726883.046328, a = 909.996055, b = 148.528000\n", "epoch 180: loss = 1726843.855869, a = 910.503219, b = 148.527826\n", "epoch 181: loss = 1726805.940116, a = 911.000348, b = 148.527655\n", "epoch 182: loss = 1726769.254286, a = 911.487641, b = 148.527487\n", "epoch 183: loss = 1726733.755249, a = 911.965292, b = 148.527322\n", "epoch 184: loss = 1726699.401462, a = 912.433493, b = 148.527161\n", "epoch 185: loss = 1726666.152915, a = 912.892430, b = 148.527003\n", "epoch 186: loss = 1726633.971067, a = 913.342287, b = 148.526848\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch 187: loss = 1726602.818794, a = 913.783243, b = 148.526696\n", "epoch 188: loss = 1726572.660334, a = 914.215474, b = 148.526548\n", "epoch 189: loss = 1726543.461235, a = 914.639153, b = 148.526402\n", "epoch 190: loss = 1726515.188306, a = 915.054449, b = 148.526259\n", "epoch 191: loss = 1726487.809572, a = 915.461529, b = 148.526119\n", "epoch 192: loss = 1726461.294221, a = 915.860553, b = 148.525981\n", "epoch 193: loss = 1726435.612569, a = 916.251683, b = 148.525846\n", "epoch 194: loss = 1726410.736010, a = 916.635074, b = 148.525714\n", "epoch 195: loss = 1726386.636980, a = 917.010879, b = 148.525585\n", "epoch 196: loss = 1726363.288915, a = 917.379249, b = 148.525458\n", "epoch 197: loss = 1726340.666217, a = 917.740330, b = 148.525334\n", "epoch 198: loss = 1726318.744212, a = 918.094267, b = 148.525212\n", "epoch 199: loss = 1726297.499121, a = 918.441201, b = 148.525093\n", "epoch 200: loss = 1726276.908024, a = 918.781270, b = 148.524975\n", "epoch 201: loss = 1726256.948827, a = 919.114611, b = 148.524861\n", "epoch 202: loss = 1726237.600231, a = 919.441356, b = 148.524748\n", 
"epoch 203: loss = 1726218.841703, a = 919.761637, b = 148.524638\n", "epoch 204: loss = 1726200.653451, a = 920.075580, b = 148.524530\n", "epoch 205: loss = 1726183.016388, a = 920.383312, b = 148.524424\n", "epoch 206: loss = 1726165.912113, a = 920.684955, b = 148.524320\n", "epoch 207: loss = 1726149.322881, a = 920.980630, b = 148.524218\n", "epoch 208: loss = 1726133.231583, a = 921.270455, b = 148.524118\n", "epoch 209: loss = 1726117.621717, a = 921.554545, b = 148.524021\n", "epoch 210: loss = 1726102.477370, a = 921.833014, b = 148.523925\n", "epoch 211: loss = 1726087.783192, a = 922.105974, b = 148.523831\n", "epoch 212: loss = 1726073.524378, a = 922.373533, b = 148.523739\n", "epoch 213: loss = 1726059.686650, a = 922.635798, b = 148.523648\n", "epoch 214: loss = 1726046.256229, a = 922.892873, b = 148.523560\n", "epoch 215: loss = 1726033.219826, a = 923.144863, b = 148.523473\n", "epoch 216: loss = 1726020.564620, a = 923.391866, b = 148.523388\n", "epoch 217: loss = 1726008.278237, a = 923.633982, b = 148.523305\n", "epoch 218: loss = 1725996.348741, a = 923.871308, b = 148.523223\n", "epoch 219: loss = 1725984.764612, a = 924.103938, b = 148.523143\n", "epoch 220: loss = 1725973.514733, a = 924.331966, b = 148.523064\n", "epoch 221: loss = 1725962.588374, a = 924.555481, b = 148.522987\n", "epoch 222: loss = 1725951.975181, a = 924.774575, b = 148.522912\n", "epoch 223: loss = 1725941.665156, a = 924.989333, b = 148.522838\n", "epoch 224: loss = 1725931.648652, a = 925.199842, b = 148.522765\n", "epoch 225: loss = 1725921.916352, a = 925.406186, b = 148.522694\n", "epoch 226: loss = 1725912.459264, a = 925.608447, b = 148.522625\n", "epoch 227: loss = 1725903.268703, a = 925.806706, b = 148.522556\n", "epoch 228: loss = 1725894.336285, a = 926.001043, b = 148.522489\n", "epoch 229: loss = 1725885.653913, a = 926.191534, b = 148.522424\n", "epoch 230: loss = 1725877.213768, a = 926.378257, b = 148.522360\n", "epoch 231: loss = 1725869.008296, a = 
926.561285, b = 148.522297\n", "epoch 232: loss = 1725861.030202, a = 926.740692, b = 148.522235\n", "epoch 233: loss = 1725853.272442, a = 926.916548, b = 148.522174\n", "epoch 234: loss = 1725845.728205, a = 927.088926, b = 148.522115\n", "epoch 235: loss = 1725838.390917, a = 927.257893, b = 148.522057\n", "epoch 236: loss = 1725831.254223, a = 927.423516, b = 148.522000\n", "epoch 237: loss = 1725824.311982, a = 927.585863, b = 148.521944\n", "epoch 238: loss = 1725817.558260, a = 927.744997, b = 148.521889\n", "epoch 239: loss = 1725810.987324, a = 927.900983, b = 148.521835\n", "epoch 240: loss = 1725804.593630, a = 928.053883, b = 148.521783\n", "epoch 241: loss = 1725798.371821, a = 928.203757, b = 148.521731\n", "epoch 242: loss = 1725792.316718, a = 928.350666, b = 148.521680\n", "epoch 243: loss = 1725786.423315, a = 928.494668, b = 148.521631\n", "epoch 244: loss = 1725780.686770, a = 928.635821, b = 148.521582\n", "epoch 245: loss = 1725775.102403, a = 928.774181, b = 148.521535\n", "epoch 246: loss = 1725769.665688, a = 928.909804, b = 148.521488\n", "epoch 247: loss = 1725764.372248, a = 929.042743, b = 148.521442\n", "epoch 248: loss = 1725759.217848, a = 929.173052, b = 148.521397\n", "epoch 249: loss = 1725754.198393, a = 929.300783, b = 148.521353\n", "epoch 250: loss = 1725749.309922, a = 929.425986, b = 148.521310\n", "epoch 251: loss = 1725744.548603, a = 929.548712, b = 148.521268\n", "epoch 252: loss = 1725739.910728, a = 929.669010, b = 148.521226\n", "epoch 253: loss = 1725735.392708, a = 929.786928, b = 148.521186\n", "epoch 254: loss = 1725730.991072, a = 929.902512, b = 148.521146\n", "epoch 255: loss = 1725726.702459, a = 930.015810, b = 148.521107\n", "epoch 256: loss = 1725722.523617, a = 930.126866, b = 148.521069\n", "epoch 257: loss = 1725718.451397, a = 930.235724, b = 148.521031\n", "epoch 258: loss = 1725714.482753, a = 930.342429, b = 148.520995\n", "epoch 259: loss = 1725710.614734, a = 930.447022, b = 148.520959\n", "epoch 
260: loss = 1725706.844482, a = 930.549546, b = 148.520923\n", "epoch 261: loss = 1725703.169232, a = 930.650042, b = 148.520889\n", "epoch 262: loss = 1725699.586303, a = 930.748549, b = 148.520855\n", "epoch 263: loss = 1725696.093102, a = 930.845107, b = 148.520822\n", "epoch 264: loss = 1725692.687115, a = 930.939754, b = 148.520789\n", "epoch 265: loss = 1725689.365908, a = 931.032529, b = 148.520757\n", "epoch 266: loss = 1725686.127121, a = 931.123468, b = 148.520726\n", "epoch 267: loss = 1725682.968469, a = 931.212608, b = 148.520695\n", "epoch 268: loss = 1725679.887739, a = 931.299985, b = 148.520665\n", "epoch 269: loss = 1725676.882782, a = 931.385632, b = 148.520635\n", "epoch 270: loss = 1725673.951521, a = 931.469585, b = 148.520606\n", "epoch 271: loss = 1725671.091938, a = 931.551877, b = 148.520578\n", "epoch 272: loss = 1725668.302080, a = 931.632540, b = 148.520550\n", "epoch 273: loss = 1725665.580052, a = 931.711608, b = 148.520523\n", "epoch 274: loss = 1725662.924017, a = 931.789111, b = 148.520496\n", "epoch 275: loss = 1725660.332194, a = 931.865080, b = 148.520470\n", "epoch 276: loss = 1725657.802855, a = 931.939547, b = 148.520445\n", "epoch 277: loss = 1725655.334325, a = 932.012540, b = 148.520420\n", "epoch 278: loss = 1725652.924980, a = 932.084089, b = 148.520395\n", "epoch 279: loss = 1725650.573243, a = 932.154222, b = 148.520371\n", "epoch 280: loss = 1725648.277584, a = 932.222968, b = 148.520347\n", "epoch 281: loss = 1725646.036521, a = 932.290353, b = 148.520324\n", "epoch 282: loss = 1725643.848614, a = 932.356405, b = 148.520301\n", "epoch 283: loss = 1725641.712465, a = 932.421150, b = 148.520279\n", "epoch 284: loss = 1725639.626719, a = 932.484614, b = 148.520257\n", "epoch 285: loss = 1725637.590060, a = 932.546823, b = 148.520236\n", "epoch 286: loss = 1725635.601209, a = 932.607801, b = 148.520215\n", "epoch 287: loss = 1725633.658927, a = 932.667572, b = 148.520194\n", "epoch 288: loss = 1725631.762010, a = 
932.726160, b = 148.520174\n", "epoch 289: loss = 1725629.909289, a = 932.783590, b = 148.520154\n", "epoch 290: loss = 1725628.099627, a = 932.839883, b = 148.520135\n", "epoch 291: loss = 1725626.331922, a = 932.895062, b = 148.520116\n", "epoch 292: loss = 1725624.605102, a = 932.949149, b = 148.520097\n", "epoch 293: loss = 1725622.918128, a = 933.002166, b = 148.520079\n", "epoch 294: loss = 1725621.269989, a = 933.054135, b = 148.520061\n", "epoch 295: loss = 1725619.659702, a = 933.105075, b = 148.520043\n", "epoch 296: loss = 1725618.086312, a = 933.155007, b = 148.520026\n", "epoch 297: loss = 1725616.548894, a = 933.203951, b = 148.520009\n", "epoch 298: loss = 1725615.046544, a = 933.251927, b = 148.519993\n", "epoch 299: loss = 1725613.578388, a = 933.298953, b = 148.519977\n", "epoch 300: loss = 1725612.143573, a = 933.345050, b = 148.519961\n", "epoch 301: loss = 1725610.741272, a = 933.390234, b = 148.519945\n", "epoch 302: loss = 1725609.370679, a = 933.434524, b = 148.519930\n", "epoch 303: loss = 1725608.031013, a = 933.477937, b = 148.519915\n", "epoch 304: loss = 1725606.721512, a = 933.520492, b = 148.519900\n", "epoch 305: loss = 1725605.441435, a = 933.562205, b = 148.519886\n", "epoch 306: loss = 1725604.190064, a = 933.603092, b = 148.519872\n", "epoch 307: loss = 1725602.966697, a = 933.643171, b = 148.519858\n", "epoch 308: loss = 1725601.770653, a = 933.682456, b = 148.519845\n", "epoch 309: loss = 1725600.601270, a = 933.720964, b = 148.519831\n", "epoch 310: loss = 1725599.457903, a = 933.758711, b = 148.519818\n", "epoch 311: loss = 1725598.339923, a = 933.795710, b = 148.519806\n", "epoch 312: loss = 1725597.246722, a = 933.831977, b = 148.519793\n", "epoch 313: loss = 1725596.177703, a = 933.867527, b = 148.519781\n", "epoch 314: loss = 1725595.132289, a = 933.902373, b = 148.519769\n", "epoch 315: loss = 1725594.109917, a = 933.936530, b = 148.519757\n", "epoch 316: loss = 1725593.110038, a = 933.970011, b = 148.519746\n", "epoch 
317: loss = 1725592.132118, a = 934.002830, b = 148.519734\n", "epoch 318: loss = 1725591.175638, a = 934.034999, b = 148.519723\n", "epoch 319: loss = 1725590.240092, a = 934.066532, b = 148.519712\n", "epoch 320: loss = 1725589.324986, a = 934.097441, b = 148.519702\n", "epoch 321: loss = 1725588.429841, a = 934.127738, b = 148.519691\n", "epoch 322: loss = 1725587.554189, a = 934.157436, b = 148.519681\n", "epoch 323: loss = 1725586.697574, a = 934.186546, b = 148.519671\n", "epoch 324: loss = 1725585.859554, a = 934.215081, b = 148.519661\n", "epoch 325: loss = 1725585.039694, a = 934.243051, b = 148.519651\n", "epoch 326: loss = 1725584.237575, a = 934.270467, b = 148.519642\n", "epoch 327: loss = 1725583.452786, a = 934.297341, b = 148.519633\n", "epoch 328: loss = 1725582.684926, a = 934.323683, b = 148.519624\n", "epoch 329: loss = 1725581.933606, a = 934.349504, b = 148.519615\n", "epoch 330: loss = 1725581.198445, a = 934.374814, b = 148.519606\n", "epoch 331: loss = 1725580.479074, a = 934.399623, b = 148.519598\n", "epoch 332: loss = 1725579.775131, a = 934.423941, b = 148.519589\n", "epoch 333: loss = 1725579.086264, a = 934.447779, b = 148.519581\n", "epoch 334: loss = 1725578.412130, a = 934.471144, b = 148.519573\n", "epoch 335: loss = 1725577.752394, a = 934.494048, b = 148.519565\n", "epoch 336: loss = 1725577.106730, a = 934.516498, b = 148.519557\n", "epoch 337: loss = 1725576.474819, a = 934.538504, b = 148.519550\n", "epoch 338: loss = 1725575.856351, a = 934.560074, b = 148.519542\n", "epoch 339: loss = 1725575.251023, a = 934.581218, b = 148.519535\n", "epoch 340: loss = 1725574.658540, a = 934.601943, b = 148.519528\n", "epoch 341: loss = 1725574.078613, a = 934.622259, b = 148.519521\n", "epoch 342: loss = 1725573.510962, a = 934.642172, b = 148.519514\n", "epoch 343: loss = 1725572.955312, a = 934.661691, b = 148.519507\n", "epoch 344: loss = 1725572.411396, a = 934.680825, b = 148.519501\n", "epoch 345: loss = 1725571.878952, a = 
934.699579, b = 148.519494\n", "epoch 346: loss = 1725571.357726, a = 934.717963, b = 148.519488\n", "epoch 347: loss = 1725570.847468, a = 934.735982, b = 148.519482\n", "epoch 348: loss = 1725570.347937, a = 934.753646, b = 148.519476\n", "epoch 349: loss = 1725569.858895, a = 934.770959, b = 148.519470\n", "epoch 350: loss = 1725569.380112, a = 934.787931, b = 148.519464\n", "epoch 351: loss = 1725568.911360, a = 934.804566, b = 148.519458\n", "epoch 352: loss = 1725568.452420, a = 934.820872, b = 148.519453\n", "epoch 353: loss = 1725568.003077, a = 934.836856, b = 148.519447\n", "epoch 354: loss = 1725567.563120, a = 934.852523, b = 148.519442\n", "epoch 355: loss = 1725567.132345, a = 934.867881, b = 148.519436\n", "epoch 356: loss = 1725566.710551, a = 934.882934, b = 148.519431\n", "epoch 357: loss = 1725566.297542, a = 934.897690, b = 148.519426\n", "epoch 358: loss = 1725565.893128, a = 934.912154, b = 148.519421\n", "epoch 359: loss = 1725565.497121, a = 934.926331, b = 148.519416\n", "epoch 360: loss = 1725565.109340, a = 934.940228, b = 148.519411\n", "epoch 361: loss = 1725564.729607, a = 934.953850, b = 148.519407\n", "epoch 362: loss = 1725564.357747, a = 934.967203, b = 148.519402\n", "epoch 363: loss = 1725563.993590, a = 934.980291, b = 148.519398\n", "epoch 364: loss = 1725563.636972, a = 934.993120, b = 148.519393\n", "epoch 365: loss = 1725563.287729, a = 935.005696, b = 148.519389\n", "epoch 366: loss = 1725562.945702, a = 935.018023, b = 148.519385\n", "epoch 367: loss = 1725562.610739, a = 935.030106, b = 148.519380\n", "epoch 368: loss = 1725562.282686, a = 935.041949, b = 148.519376\n", "epoch 369: loss = 1725561.961396, a = 935.053559, b = 148.519372\n", "epoch 370: loss = 1725561.646724, a = 935.064938, b = 148.519368\n", "epoch 371: loss = 1725561.338531, a = 935.076093, b = 148.519365\n", "epoch 372: loss = 1725561.036676, a = 935.087027, b = 148.519361\n", "epoch 373: loss = 1725560.741026, a = 935.097744, b = 148.519357\n", "epoch 
374: loss = 1725560.451449, a = 935.108250, b = 148.519354\n", "epoch 375: loss = 1725560.167816, a = 935.118547, b = 148.519350\n", "epoch 376: loss = 1725559.890000, a = 935.128641, b = 148.519347\n", "epoch 377: loss = 1725559.617879, a = 935.138535, b = 148.519343\n", "epoch 378: loss = 1725559.351332, a = 935.148234, b = 148.519340\n", "epoch 379: loss = 1725559.090241, a = 935.157740, b = 148.519337\n", "epoch 380: loss = 1725558.834492, a = 935.167059, b = 148.519333\n", "epoch 381: loss = 1725558.583972, a = 935.176193, b = 148.519330\n", "epoch 382: loss = 1725558.338570, a = 935.185146, b = 148.519327\n", "epoch 383: loss = 1725558.098179, a = 935.193922, b = 148.519324\n", "epoch 384: loss = 1725557.862695, a = 935.202525, b = 148.519321\n", "epoch 385: loss = 1725557.632013, a = 935.210957, b = 148.519318\n", "epoch 386: loss = 1725557.406033, a = 935.219223, b = 148.519315\n", "epoch 387: loss = 1725557.184657, a = 935.227324, b = 148.519313\n", "epoch 388: loss = 1725556.967789, a = 935.235266, b = 148.519310\n", "epoch 389: loss = 1725556.755334, a = 935.243051, b = 148.519307\n", "epoch 390: loss = 1725556.547200, a = 935.250681, b = 148.519305\n", "epoch 391: loss = 1725556.343298, a = 935.258160, b = 148.519302\n", "epoch 392: loss = 1725556.143538, a = 935.265492, b = 148.519299\n", "epoch 393: loss = 1725555.947836, a = 935.272678, b = 148.519297\n", "epoch 394: loss = 1725555.756105, a = 935.279723, b = 148.519295\n", "epoch 395: loss = 1725555.568265, a = 935.286628, b = 148.519292\n", "epoch 396: loss = 1725555.384233, a = 935.293396, b = 148.519290\n", "epoch 397: loss = 1725555.203932, a = 935.300030, b = 148.519288\n", "epoch 398: loss = 1725555.027283, a = 935.306533, b = 148.519285\n", "epoch 399: loss = 1725554.854212, a = 935.312908, b = 148.519283\n", "epoch 400: loss = 1725554.684644, a = 935.319156, b = 148.519281\n", "epoch 401: loss = 1725554.518507, a = 935.325281, b = 148.519279\n", "epoch 402: loss = 1725554.355730, a = 
935.331284, b = 148.519277\n", "epoch 403: loss = 1725554.196244, a = 935.337169, b = 148.519275\n", "epoch 404: loss = 1725554.039980, a = 935.342937, b = 148.519273\n", "epoch 405: loss = 1725553.886873, a = 935.348591, b = 148.519271\n", "epoch 406: loss = 1725553.736857, a = 935.354133, b = 148.519269\n", "epoch 407: loss = 1725553.589869, a = 935.359566, b = 148.519267\n", "epoch 408: loss = 1725553.445846, a = 935.364891, b = 148.519265\n", "epoch 409: loss = 1725553.304728, a = 935.370111, b = 148.519263\n", "epoch 410: loss = 1725553.166456, a = 935.375227, b = 148.519262\n", "epoch 411: loss = 1725553.030969, a = 935.380242, b = 148.519260\n", "epoch 412: loss = 1725552.898213, a = 935.385158, b = 148.519258\n", "epoch 413: loss = 1725552.768130, a = 935.389977, b = 148.519257\n", "epoch 414: loss = 1725552.640666, a = 935.394701, b = 148.519255\n", "epoch 415: loss = 1725552.515767, a = 935.399331, b = 148.519253\n", "epoch 416: loss = 1725552.393381, a = 935.403869, b = 148.519252\n", "epoch 417: loss = 1725552.273456, a = 935.408317, b = 148.519250\n", "epoch 418: loss = 1725552.155943, a = 935.412678, b = 148.519249\n", "epoch 419: loss = 1725552.040792, a = 935.416952, b = 148.519247\n", "epoch 420: loss = 1725551.927954, a = 935.421142, b = 148.519246\n", "epoch 421: loss = 1725551.817384, a = 935.425249, b = 148.519244\n", "epoch 422: loss = 1725551.709033, a = 935.429274, b = 148.519243\n", "epoch 423: loss = 1725551.602858, a = 935.433220, b = 148.519242\n", "epoch 424: loss = 1725551.498814, a = 935.437088, b = 148.519240\n", "epoch 425: loss = 1725551.396858, a = 935.440879, b = 148.519239\n", "epoch 426: loss = 1725551.296947, a = 935.444595, b = 148.519238\n", "epoch 427: loss = 1725551.199039, a = 935.448238, b = 148.519237\n", "epoch 428: loss = 1725551.103094, a = 935.451809, b = 148.519235\n", "epoch 429: loss = 1725551.009073, a = 935.455309, b = 148.519234\n", "epoch 430: loss = 1725550.916935, a = 935.458739, b = 148.519233\n", "epoch 
431: loss = 1725550.826644, a = 935.462102, b = 148.519232\n", "epoch 432: loss = 1725550.738161, a = 935.465399, b = 148.519231\n", "epoch 433: loss = 1725550.651449, a = 935.468630, b = 148.519229\n", "epoch 434: loss = 1725550.566474, a = 935.471797, b = 148.519228\n", "epoch 435: loss = 1725550.483199, a = 935.474901, b = 148.519227\n", "epoch 436: loss = 1725550.401591, a = 935.477945, b = 148.519226\n", "epoch 437: loss = 1725550.321615, a = 935.480927, b = 148.519225\n", "epoch 438: loss = 1725550.243239, a = 935.483851, b = 148.519224\n", "epoch 439: loss = 1725550.166431, a = 935.486717, b = 148.519223\n", "epoch 440: loss = 1725550.091158, a = 935.489527, b = 148.519222\n", "epoch 441: loss = 1725550.017389, a = 935.492280, b = 148.519221\n", "epoch 442: loss = 1725549.945095, a = 935.494980, b = 148.519220\n", "epoch 443: loss = 1725549.874246, a = 935.497625, b = 148.519219\n", "epoch 444: loss = 1725549.804812, a = 935.500219, b = 148.519219\n", "epoch 445: loss = 1725549.736765, a = 935.502761, b = 148.519218\n", "epoch 446: loss = 1725549.670077, a = 935.505253, b = 148.519217\n", "epoch 447: loss = 1725549.604720, a = 935.507696, b = 148.519216\n", "epoch 448: loss = 1725549.540668, a = 935.510090, b = 148.519215\n", "epoch 449: loss = 1725549.477894, a = 935.512437, b = 148.519214\n", "epoch 450: loss = 1725549.416373, a = 935.514737, b = 148.519214\n", "epoch 451: loss = 1725549.356080, a = 935.516992, b = 148.519213\n", "epoch 452: loss = 1725549.296990, a = 935.519202, b = 148.519212\n", "epoch 453: loss = 1725549.239078, a = 935.521369, b = 148.519211\n", "epoch 454: loss = 1725549.182321, a = 935.523493, b = 148.519211\n", "epoch 455: loss = 1725549.126696, a = 935.525574, b = 148.519210\n", "epoch 456: loss = 1725549.072179, a = 935.527615, b = 148.519209\n", "epoch 457: loss = 1725549.018750, a = 935.529615, b = 148.519208\n", "epoch 458: loss = 1725548.966385, a = 935.531575, b = 148.519208\n", "epoch 459: loss = 1725548.915064, a = 
935.533497, b = 148.519207\n", "epoch 460: loss = 1725548.864766, a = 935.535381, b = 148.519206\n", "epoch 461: loss = 1725548.815470, a = 935.537227, b = 148.519206\n", "epoch 462: loss = 1725548.767155, a = 935.539037, b = 148.519205\n", "epoch 463: loss = 1725548.719803, a = 935.540811, b = 148.519205\n", "epoch 464: loss = 1725548.673394, a = 935.542550, b = 148.519204\n", "epoch 465: loss = 1725548.627910, a = 935.544255, b = 148.519203\n", "epoch 466: loss = 1725548.583330, a = 935.545926, b = 148.519203\n", "epoch 467: loss = 1725548.539638, a = 935.547564, b = 148.519202\n", "epoch 468: loss = 1725548.496816, a = 935.549169, b = 148.519202\n", "epoch 469: loss = 1725548.454846, a = 935.550743, b = 148.519201\n", "epoch 470: loss = 1725548.413712, a = 935.552285, b = 148.519201\n", "epoch 471: loss = 1725548.373396, a = 935.553797, b = 148.519200\n", "epoch 472: loss = 1725548.333882, a = 935.555279, b = 148.519200\n", "epoch 473: loss = 1725548.295154, a = 935.556732, b = 148.519199\n", "epoch 474: loss = 1725548.257196, a = 935.558156, b = 148.519199\n", "epoch 475: loss = 1725548.219993, a = 935.559552, b = 148.519198\n", "epoch 476: loss = 1725548.183531, a = 935.560920, b = 148.519198\n", "epoch 477: loss = 1725548.147793, a = 935.562261, b = 148.519197\n", "epoch 478: loss = 1725548.112766, a = 935.563576, b = 148.519197\n", "epoch 479: loss = 1725548.078435, a = 935.564864, b = 148.519196\n", "epoch 480: loss = 1725548.044787, a = 935.566128, b = 148.519196\n", "epoch 481: loss = 1725548.011808, a = 935.567366, b = 148.519195\n", "epoch 482: loss = 1725547.979484, a = 935.568579, b = 148.519195\n", "epoch 483: loss = 1725547.947802, a = 935.569769, b = 148.519195\n", "epoch 484: loss = 1725547.916751, a = 935.570935, b = 148.519194\n", "epoch 485: loss = 1725547.886316, a = 935.572078, b = 148.519194\n", "epoch 486: loss = 1725547.856486, a = 935.573198, b = 148.519193\n", "epoch 487: loss = 1725547.827248, a = 935.574296, b = 148.519193\n", "epoch 
488: loss = 1725547.798592, a = 935.575373, b = 148.519193\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch 489: loss = 1725547.770504, a = 935.576428, b = 148.519192\n", "epoch 490: loss = 1725547.742975, a = 935.577462, b = 148.519192\n", "epoch 491: loss = 1725547.715992, a = 935.578476, b = 148.519192\n", "epoch 492: loss = 1725547.689545, a = 935.579470, b = 148.519191\n", "epoch 493: loss = 1725547.663624, a = 935.580444, b = 148.519191\n", "epoch 494: loss = 1725547.638217, a = 935.581399, b = 148.519191\n", "epoch 495: loss = 1725547.613314, a = 935.582335, b = 148.519190\n", "epoch 496: loss = 1725547.588906, a = 935.583252, b = 148.519190\n", "epoch 497: loss = 1725547.564983, a = 935.584152, b = 148.519190\n", "epoch 498: loss = 1725547.541534, a = 935.585033, b = 148.519189\n", "epoch 499: loss = 1725547.518551, a = 935.585897, b = 148.519189\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "n_epoch = 500 # epoch size\n", "a, b = 1, 1 # initial parameters\n", "epsilon = 0.01 # learning rate\n", "\n", "for i in range(n_epoch):\n", " for j in range(N):\n", " a = a + epsilon*2*(Y[j] - a*X[j] - b)*X[j]\n", " b = b + epsilon*2*(Y[j] - a*X[j] - b)\n", "\n", " L = 0\n", " for j in range(N):\n", " L = L + (Y[j]-a*X[j]-b)**2\n", " print(\"epoch %4d: loss = %f, a = %f, b = %f\" % (i, L, a, b))\n", " \n", "x_min = np.min(X)\n", "x_max = np.max(X)\n", "y_min = a * x_min + b\n", "y_max = a * x_max + b\n", "\n", "plt.scatter(X, Y, label='original data')\n", "plt.plot([x_min, x_max], [y_min, y_max], 'r', label='model')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 如何可视化迭代过程" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "/* global mpl */\n", "window.mpl = {};\n", "\n", "mpl.get_websocket_type = function () {\n", " if (typeof WebSocket !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof MozWebSocket !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert(\n", " 'Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.'\n", " );\n", " }\n", "};\n", "\n", "mpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = this.ws.binaryType !== undefined;\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById('mpl-warnings');\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent =\n", " 'This browser does not support binary websocket messages. 
' +\n", " 'Performance may be slow.';\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = document.createElement('div');\n", " this.root.setAttribute('style', 'display: inline-block');\n", " this._root_extra_style(this.root);\n", "\n", " parent_element.appendChild(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message('supports_binary', { value: fig.supports_binary });\n", " fig.send_message('send_image_mode', {});\n", " if (mpl.ratio !== 1) {\n", " fig.send_message('set_dpi_ratio', { dpi_ratio: mpl.ratio });\n", " }\n", " fig.send_message('refresh', {});\n", " };\n", "\n", " this.imageObj.onload = function () {\n", " if (fig.image_mode === 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function () {\n", " fig.ws.close();\n", " };\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "};\n", "\n", "mpl.figure.prototype._init_header = function () {\n", " var titlebar = document.createElement('div');\n", " titlebar.classList =\n", " 'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n", " var titletext = document.createElement('div');\n", " titletext.classList = 'ui-dialog-title';\n", " titletext.setAttribute(\n", " 
'style',\n", " 'width: 100%; text-align: center; padding: 3px;'\n", " );\n", " titlebar.appendChild(titletext);\n", " this.root.appendChild(titlebar);\n", " this.header = titletext;\n", "};\n", "\n", "mpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n", "\n", "mpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n", "\n", "mpl.figure.prototype._init_canvas = function () {\n", " var fig = this;\n", "\n", " var canvas_div = (this.canvas_div = document.createElement('div'));\n", " canvas_div.setAttribute(\n", " 'style',\n", " 'border: 1px solid #ddd;' +\n", " 'box-sizing: content-box;' +\n", " 'clear: both;' +\n", " 'min-height: 1px;' +\n", " 'min-width: 1px;' +\n", " 'outline: 0;' +\n", " 'overflow: hidden;' +\n", " 'position: relative;' +\n", " 'resize: both;'\n", " );\n", "\n", " function on_keyboard_event_closure(name) {\n", " return function (event) {\n", " return fig.key_event(event, name);\n", " };\n", " }\n", "\n", " canvas_div.addEventListener(\n", " 'keydown',\n", " on_keyboard_event_closure('key_press')\n", " );\n", " canvas_div.addEventListener(\n", " 'keyup',\n", " on_keyboard_event_closure('key_release')\n", " );\n", "\n", " this._canvas_extra_style(canvas_div);\n", " this.root.appendChild(canvas_div);\n", "\n", " var canvas = (this.canvas = document.createElement('canvas'));\n", " canvas.classList.add('mpl-canvas');\n", " canvas.setAttribute('style', 'box-sizing: content-box;');\n", "\n", " this.context = canvas.getContext('2d');\n", "\n", " var backingStore =\n", " this.context.backingStorePixelRatio ||\n", " this.context.webkitBackingStorePixelRatio ||\n", " this.context.mozBackingStorePixelRatio ||\n", " this.context.msBackingStorePixelRatio ||\n", " this.context.oBackingStorePixelRatio ||\n", " this.context.backingStorePixelRatio ||\n", " 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n", " 'canvas'\n", 
" ));\n", " rubberband_canvas.setAttribute(\n", " 'style',\n", " 'box-sizing: content-box; position: absolute; left: 0; top: 0; z-index: 1;'\n", " );\n", "\n", " var resizeObserver = new ResizeObserver(function (entries) {\n", " var nentries = entries.length;\n", " for (var i = 0; i < nentries; i++) {\n", " var entry = entries[i];\n", " var width, height;\n", " if (entry.contentBoxSize) {\n", " width = entry.contentBoxSize.inlineSize;\n", " height = entry.contentBoxSize.blockSize;\n", " } else {\n", " width = entry.contentRect.width;\n", " height = entry.contentRect.height;\n", " }\n", "\n", " // Keep the size of the canvas and rubber band canvas in sync with\n", " // the canvas container.\n", " canvas.setAttribute('width', width * mpl.ratio);\n", " canvas.setAttribute('height', height * mpl.ratio);\n", " canvas.setAttribute(\n", " 'style',\n", " 'width: ' + width + 'px; height: ' + height + 'px;'\n", " );\n", "\n", " rubberband_canvas.setAttribute('width', width);\n", " rubberband_canvas.setAttribute('height', height);\n", "\n", " // And update the size in Python. 
We ignore the initial 0/0 size\n", " // that occurs as the element is placed into the DOM, which should\n", " // otherwise not happen due to the minimum size styling.\n", " if (width != 0 && height != 0) {\n", " fig.request_resize(width, height);\n", " }\n", " }\n", " });\n", " resizeObserver.observe(canvas_div);\n", "\n", " function on_mouse_event_closure(name) {\n", " return function (event) {\n", " return fig.mouse_event(event, name);\n", " };\n", " }\n", "\n", " rubberband_canvas.addEventListener(\n", " 'mousedown',\n", " on_mouse_event_closure('button_press')\n", " );\n", " rubberband_canvas.addEventListener(\n", " 'mouseup',\n", " on_mouse_event_closure('button_release')\n", " );\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband_canvas.addEventListener(\n", " 'mousemove',\n", " on_mouse_event_closure('motion_notify')\n", " );\n", "\n", " rubberband_canvas.addEventListener(\n", " 'mouseenter',\n", " on_mouse_event_closure('figure_enter')\n", " );\n", " rubberband_canvas.addEventListener(\n", " 'mouseleave',\n", " on_mouse_event_closure('figure_leave')\n", " );\n", "\n", " canvas_div.addEventListener('wheel', function (event) {\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " on_mouse_event_closure('scroll')(event);\n", " });\n", "\n", " canvas_div.appendChild(canvas);\n", " canvas_div.appendChild(rubberband_canvas);\n", "\n", " this.rubberband_context = rubberband_canvas.getContext('2d');\n", " this.rubberband_context.strokeStyle = '#000000';\n", "\n", " this._resize_canvas = function (width, height, forward) {\n", " if (forward) {\n", " canvas_div.style.width = width + 'px';\n", " canvas_div.style.height = height + 'px';\n", " }\n", " };\n", "\n", " // Disable right mouse context menu.\n", " this.rubberband_canvas.addEventListener('contextmenu', function (_e) {\n", " return false;\n", " });\n", "\n", " function set_focus() {\n", " canvas.focus();\n", " canvas_div.focus();\n", " 
}\n", "\n", " window.setTimeout(set_focus, 100);\n", "};\n", "\n", "mpl.figure.prototype._init_toolbar = function () {\n", " var fig = this;\n", "\n", " var toolbar = document.createElement('div');\n", " toolbar.classList = 'mpl-toolbar';\n", " this.root.appendChild(toolbar);\n", "\n", " function on_click_closure(name) {\n", " return function (_event) {\n", " return fig.toolbar_button_onclick(name);\n", " };\n", " }\n", "\n", " function on_mouseover_closure(tooltip) {\n", " return function (event) {\n", " if (!event.currentTarget.disabled) {\n", " return fig.toolbar_button_onmouseover(tooltip);\n", " }\n", " };\n", " }\n", "\n", " fig.buttons = {};\n", " var buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'mpl-button-group';\n", " for (var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " /* Instead of a spacer, we start a new button group. 
*/\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", " buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'mpl-button-group';\n", " continue;\n", " }\n", "\n", " var button = (fig.buttons[name] = document.createElement('button'));\n", " button.classList = 'mpl-widget';\n", " button.setAttribute('role', 'button');\n", " button.setAttribute('aria-disabled', 'false');\n", " button.addEventListener('click', on_click_closure(method_name));\n", " button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n", "\n", " var icon_img = document.createElement('img');\n", " icon_img.src = '_images/' + image + '.png';\n", " icon_img.srcset = '_images/' + image + '_large.png 2x';\n", " icon_img.alt = tooltip;\n", " button.appendChild(icon_img);\n", "\n", " buttonGroup.appendChild(button);\n", " }\n", "\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", "\n", " var fmt_picker = document.createElement('select');\n", " fmt_picker.classList = 'mpl-widget';\n", " toolbar.appendChild(fmt_picker);\n", " this.format_dropdown = fmt_picker;\n", "\n", " for (var ind in mpl.extensions) {\n", " var fmt = mpl.extensions[ind];\n", " var option = document.createElement('option');\n", " option.selected = fmt === mpl.default_extension;\n", " option.innerHTML = fmt;\n", " fmt_picker.appendChild(option);\n", " }\n", "\n", " var status_bar = document.createElement('span');\n", " status_bar.classList = 'mpl-message';\n", " toolbar.appendChild(status_bar);\n", " this.message = status_bar;\n", "};\n", "\n", "mpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n", " // Request matplotlib to resize the figure. 
Matplotlib will then trigger a resize in the client,\n", " // which will in turn request a refresh of the image.\n", " this.send_message('resize', { width: x_pixels, height: y_pixels });\n", "};\n", "\n", "mpl.figure.prototype.send_message = function (type, properties) {\n", " properties['type'] = type;\n", " properties['figure_id'] = this.id;\n", " this.ws.send(JSON.stringify(properties));\n", "};\n", "\n", "mpl.figure.prototype.send_draw_message = function () {\n", " if (!this.waiting) {\n", " this.waiting = true;\n", " this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_save = function (fig, _msg) {\n", " var format_dropdown = fig.format_dropdown;\n", " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n", " fig.ondownload(fig, format);\n", "};\n", "\n", "mpl.figure.prototype.handle_resize = function (fig, msg) {\n", " var size = msg['size'];\n", " if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n", " fig._resize_canvas(size[0], size[1], msg['forward']);\n", " fig.send_message('refresh', {});\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_rubberband = function (fig, msg) {\n", " var x0 = msg['x0'] / mpl.ratio;\n", " var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n", " var x1 = msg['x1'] / mpl.ratio;\n", " var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n", " x0 = Math.floor(x0) + 0.5;\n", " y0 = Math.floor(y0) + 0.5;\n", " x1 = Math.floor(x1) + 0.5;\n", " y1 = Math.floor(y1) + 0.5;\n", " var min_x = Math.min(x0, x1);\n", " var min_y = Math.min(y0, y1);\n", " var width = Math.abs(x1 - x0);\n", " var height = Math.abs(y1 - y0);\n", "\n", " fig.rubberband_context.clearRect(\n", " 0,\n", " 0,\n", " fig.canvas.width / mpl.ratio,\n", " fig.canvas.height / mpl.ratio\n", " );\n", "\n", " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n", "};\n", "\n", "mpl.figure.prototype.handle_figure_label = function (fig, msg) 
{\n", " // Updates the figure title.\n", " fig.header.textContent = msg['label'];\n", "};\n", "\n", "mpl.figure.prototype.handle_cursor = function (fig, msg) {\n", " var cursor = msg['cursor'];\n", " switch (cursor) {\n", " case 0:\n", " cursor = 'pointer';\n", " break;\n", " case 1:\n", " cursor = 'default';\n", " break;\n", " case 2:\n", " cursor = 'crosshair';\n", " break;\n", " case 3:\n", " cursor = 'move';\n", " break;\n", " }\n", " fig.rubberband_canvas.style.cursor = cursor;\n", "};\n", "\n", "mpl.figure.prototype.handle_message = function (fig, msg) {\n", " fig.message.textContent = msg['message'];\n", "};\n", "\n", "mpl.figure.prototype.handle_draw = function (fig, _msg) {\n", " // Request the server to send over a new figure.\n", " fig.send_draw_message();\n", "};\n", "\n", "mpl.figure.prototype.handle_image_mode = function (fig, msg) {\n", " fig.image_mode = msg['mode'];\n", "};\n", "\n", "mpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n", " for (var key in msg) {\n", " if (!(key in fig.buttons)) {\n", " continue;\n", " }\n", " fig.buttons[key].disabled = !msg[key];\n", " fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n", " if (msg['mode'] === 'PAN') {\n", " fig.buttons['Pan'].classList.add('active');\n", " fig.buttons['Zoom'].classList.remove('active');\n", " } else if (msg['mode'] === 'ZOOM') {\n", " fig.buttons['Pan'].classList.remove('active');\n", " fig.buttons['Zoom'].classList.add('active');\n", " } else {\n", " fig.buttons['Pan'].classList.remove('active');\n", " fig.buttons['Zoom'].classList.remove('active');\n", " }\n", "};\n", "\n", "mpl.figure.prototype.updated_canvas_event = function () {\n", " // Called whenever the canvas gets updated.\n", " this.send_message('ack', {});\n", "};\n", "\n", "// A function to construct a web socket function for onmessage handling.\n", "// Called in the figure constructor.\n", 
"mpl.figure.prototype._make_on_message_function = function (fig) {\n", " return function socket_on_message(evt) {\n", " if (evt.data instanceof Blob) {\n", " /* FIXME: We get \"Resource interpreted as Image but\n", " * transferred with MIME type text/plain:\" errors on\n", " * Chrome. But how to set the MIME type? It doesn't seem\n", " * to be part of the websocket stream */\n", " evt.data.type = 'image/png';\n", "\n", " /* Free the memory for the previous frames */\n", " if (fig.imageObj.src) {\n", " (window.URL || window.webkitURL).revokeObjectURL(\n", " fig.imageObj.src\n", " );\n", " }\n", "\n", " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n", " evt.data\n", " );\n", " fig.updated_canvas_event();\n", " fig.waiting = false;\n", " return;\n", " } else if (\n", " typeof evt.data === 'string' &&\n", " evt.data.slice(0, 21) === 'data:image/png;base64'\n", " ) {\n", " fig.imageObj.src = evt.data;\n", " fig.updated_canvas_event();\n", " fig.waiting = false;\n", " return;\n", " }\n", "\n", " var msg = JSON.parse(evt.data);\n", " var msg_type = msg['type'];\n", "\n", " // Call the \"handle_{type}\" callback, which takes\n", " // the figure and JSON message as its only arguments.\n", " try {\n", " var callback = fig['handle_' + msg_type];\n", " } catch (e) {\n", " console.log(\n", " \"No handler for the '\" + msg_type + \"' message type: \",\n", " msg\n", " );\n", " return;\n", " }\n", "\n", " if (callback) {\n", " try {\n", " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n", " callback(fig, msg);\n", " } catch (e) {\n", " console.log(\n", " \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n", " e,\n", " e.stack,\n", " msg\n", " );\n", " }\n", " }\n", " };\n", "};\n", "\n", "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n", "mpl.findpos = function (e) {\n", " //this section is from http://www.quirksmode.org/js/events_properties.html\n", " var targ;\n", " if (!e) {\n", " e 
= window.event;\n", " }\n", " if (e.target) {\n", " targ = e.target;\n", " } else if (e.srcElement) {\n", " targ = e.srcElement;\n", " }\n", " if (targ.nodeType === 3) {\n", " // defeat Safari bug\n", " targ = targ.parentNode;\n", " }\n", "\n", " // pageX,Y are the mouse positions relative to the document\n", " var boundingRect = targ.getBoundingClientRect();\n", " var x = e.pageX - (boundingRect.left + document.body.scrollLeft);\n", " var y = e.pageY - (boundingRect.top + document.body.scrollTop);\n", "\n", " return { x: x, y: y };\n", "};\n", "\n", "/*\n", " * return a copy of an object with only non-object keys\n", " * we need this to avoid circular references\n", " * http://stackoverflow.com/a/24161582/3208463\n", " */\n", "function simpleKeys(original) {\n", " return Object.keys(original).reduce(function (obj, key) {\n", " if (typeof original[key] !== 'object') {\n", " obj[key] = original[key];\n", " }\n", " return obj;\n", " }, {});\n", "}\n", "\n", "mpl.figure.prototype.mouse_event = function (event, name) {\n", " var canvas_pos = mpl.findpos(event);\n", "\n", " if (name === 'button_press') {\n", " this.canvas.focus();\n", " this.canvas_div.focus();\n", " }\n", "\n", " var x = canvas_pos.x * mpl.ratio;\n", " var y = canvas_pos.y * mpl.ratio;\n", "\n", " this.send_message(name, {\n", " x: x,\n", " y: y,\n", " button: event.button,\n", " step: event.step,\n", " guiEvent: simpleKeys(event),\n", " });\n", "\n", " /* This prevents the web browser from automatically changing to\n", " * the text insertion cursor when the button is pressed. 
We want\n", " * to control all of the cursor setting manually through the\n", " * 'cursor' event from matplotlib */\n", " event.preventDefault();\n", " return false;\n", "};\n", "\n", "mpl.figure.prototype._key_event_extra = function (_event, _name) {\n", " // Handle any extra behaviour associated with a key event\n", "};\n", "\n", "mpl.figure.prototype.key_event = function (event, name) {\n", " // Prevent repeat events\n", " if (name === 'key_press') {\n", " if (event.which === this._key) {\n", " return;\n", " } else {\n", " this._key = event.which;\n", " }\n", " }\n", " if (name === 'key_release') {\n", " this._key = null;\n", " }\n", "\n", " var value = '';\n", " if (event.ctrlKey && event.which !== 17) {\n", " value += 'ctrl+';\n", " }\n", " if (event.altKey && event.which !== 18) {\n", " value += 'alt+';\n", " }\n", " if (event.shiftKey && event.which !== 16) {\n", " value += 'shift+';\n", " }\n", "\n", " value += 'k';\n", " value += event.which.toString();\n", "\n", " this._key_event_extra(event, name);\n", "\n", " this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n", " return false;\n", "};\n", "\n", "mpl.figure.prototype.toolbar_button_onclick = function (name) {\n", " if (name === 'download') {\n", " this.handle_save(this, null);\n", " } else {\n", " this.send_message('toolbar_button', { name: name });\n", " }\n", "};\n", "\n", "mpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n", " this.message.textContent = tooltip;\n", "};\n", "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis, CTRL fixes 
aspect\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n", "\n", "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n", "\n", "mpl.default_extension = \"png\";\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ],
"source": [ "%matplotlib nbagg\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import matplotlib.animation as animation\n", "\n", "n_epoch = 3000  # epoch size\n", "a, b = 1, 1  # initial parameters\n", "epsilon = 0.001  # learning rate\n", "N = len(X)  # number of samples\n", "\n", "fig = plt.figure()\n", "imgs = []\n", "\n", "for i in range(n_epoch):\n", "    for j in range(N):\n", "        # residual of sample j under the current parameters, shared by both updates\n", "        r = Y[j] - a*X[j] - b\n", "        a = a + epsilon*2*r*X[j]\n", "        b = b + epsilon*2*r\n", "\n", "    L = 0\n", "    for j in range(N):\n", "        L = L + (Y[j]-a*X[j]-b)**2\n", "    #print(\"epoch %4d: loss = %f, a = %f, b = %f\" % (i, L, a, b))\n", "\n", "    if i % 50 == 0:\n", "        x_min = np.min(X)\n", "        x_max = np.max(X)\n", "        y_min = a * x_min + b\n", "        y_max = a * x_max + b\n", "\n", "        img = 
plt.scatter(X, Y, label='original data')\n", "        img = plt.plot([x_min, x_max], [y_min, y_max], 'r', label='model')\n", "        imgs.append(img)\n", "\n", "ani = animation.ArtistAnimation(fig, imgs)\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 4. How to use batch updates?\n", "\n", "If some samples contain large errors (outliers), updating the parameters from one sample at a time makes each step inaccurate. Computing the update from a single sample is also computationally inefficient. Batch updates address both problems by averaging the gradient over several samples per step.\n", "\n", "\n", "* [Several forms of gradient descent (in Chinese)](https://blog.csdn.net/u010402786/article/details/51188876)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 5. How to fit a polynomial function?\n", "\n", "Suppose we need to design a ballistic missile defense system: by observing a missile's flight path, we predict its future trajectory so that it can be intercepted. From physics, the model is:\n", "$$\n", "y = at^2 + bt + c\n", "$$\n", "We need to solve for the three model parameters $a, b, c$.\n", "\n", "The loss function is defined as:\n", "$$\n", "L = \\sum_{i=1}^N (y_i - at_i^2 - bt_i - c)^2\n", "$$\n" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "#t = np.array([2, 4, 6, 8])\n", "t = np.linspace(0, 10)  # 50 evenly spaced time samples in [0, 10] (no noise added)\n", "\n", "pa = -20\n", "pb = 90\n", "pc = 800\n", "\n", "y = pa*t**2 + pb*t + pc\n", "\n", "\n", "plt.scatter(t, y)\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### 5.1 How to derive the update terms?\n", "\n", "$$\n", "L = \\sum_{i=1}^N (y_i - at_i^2 - bt_i - c)^2\n", "$$\n", "\n", "Taking the partial derivative of $L$ with respect to each parameter gives:\n", "\n", "\\begin{eqnarray}\n", "\\frac{\\partial L}{\\partial a} & = & - 2\\sum_{i=1}^N (y_i - at_i^2 - bt_i -c) t_i^2 \\\\\n", "\\frac{\\partial L}{\\partial b} & = & - 2\\sum_{i=1}^N (y_i - at_i^2 - bt_i -c) t_i \\\\\n", "\\frac{\\partial L}{\\partial c} & = & - 2\\sum_{i=1}^N (y_i - at_i^2 - bt_i -c)\n", "\\end{eqnarray}\n", "\n", "Each parameter is then moved against its gradient, with learning rate $\\epsilon$:\n", "\n", "\\begin{eqnarray}\n", "a & \\leftarrow & a - \\epsilon \\frac{\\partial L}{\\partial a} \\\\\n", "b & \\leftarrow & b - \\epsilon \\frac{\\partial L}{\\partial b} \\\\\n", "c & \\leftarrow & c - \\epsilon \\frac{\\partial L}{\\partial c}\n", "\\end{eqnarray}" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 6. How to solve a linear problem with sklearn?\n" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X: (442, 1)\n", "Y: (442,)\n", "a = 949.435260, b = 152.133484\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn import linear_model\n", "from sklearn import datasets\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# load data\n", "d = datasets.load_diabetes()\n", "\n", "# sklearn expects a 2-D feature matrix, so keep feature 2 as a column\n", "X = d.data[:, np.newaxis, 2]\n", "Y = d.target\n", "print(\"X: \", X.shape)\n", "print(\"Y: \", Y.shape)\n", "\n", "# create regression model\n", "regr = linear_model.LinearRegression()\n", "regr.fit(X, Y)\n", "\n", "# coef_ is an array with one entry per feature; take the scalar slope\n", "a, b = regr.coef_[0], regr.intercept_\n", "print(\"a = %f, b = %f\" % (a, b))\n", "\n", "x_min = np.min(X)\n", "x_max = np.max(X)\n", "y_min = a * x_min + b\n", "y_max = a * x_max + b\n", "\n", "plt.scatter(X, Y)\n", "plt.plot([x_min, x_max], [y_min, y_max], 'r')\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 7. How to fit a polynomial function with sklearn?" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([800., 90., -20.])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fitting polynomial functions\n", "\n", "import numpy as np\n", "from sklearn.preprocessing import PolynomialFeatures\n", "from sklearn.linear_model import LinearRegression\n", "from sklearn.pipeline import Pipeline\n", "\n", "t = np.array([2, 4, 6, 8])\n", "\n", "pa = -20\n", "pb = 90\n", "pc = 800\n", "\n", "y = pa*t**2 + pb*t + pc\n", "\n", "# PolynomialFeatures(degree=2) expands t into [1, t, t^2]; the constant column\n", "# acts as the intercept, hence fit_intercept=False\n", "model = Pipeline([('poly', PolynomialFeatures(degree=2)),\n", "                  ('linear', LinearRegression(fit_intercept=False))])\n", "model = model.fit(t[:, np.newaxis], y)\n", "# coefficients are ordered [c, b, a]\n", "model.named_steps['linear'].coef_\n" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }