{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Least Squares\n", "\n", "## 1. Basic Principle of Least Squares\n", "\n", "The method of least squares is a mathematical optimization technique that finds the best-fitting function for a set of data by minimizing the sum of squared errors. It is commonly used for curve fitting and for estimating model parameters. Many other optimization problems, such as energy minimization or entropy maximization, can also be expressed in least-squares form.\n", "\n", "![ls_theory](images/least_squares.png)\n", "\n", "The general form of the least-squares principle is:\n", "$$\n", "L = \\sum (V_{obv} - V_{target}(\\theta))^2\n", "$$\n", "where\n", "* $V_{obv}$ is the set of observed sample values\n", "* $V_{target}$ is the output of the assumed fitting function\n", "* $\\theta$ denotes the parameters of the model\n", "* $L$ is the objective function\n", "\n", "If adjusting the model parameters $\\theta$ drives $L$ to its minimum, the fitted function is as close as possible to the observations, which means we have found the optimal model.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1 Example\n", "\n", "Suppose we have the following observations and want to find the pattern underlying them." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEHCAYAAACp9y31AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAU8ElEQVR4nO3df6yeZX3H8ffXw1EPuu2ANIS21DYZaUdwUj0hOLJNClicRBrnnEZNsxD5RzdGTLXsR8wSF+u6oSZzLI2i3SSCg64QZVYCdWZmOk8tDvnRwBBKD2jrpM5pM0r57o/zHCinzzk9z3Oe+9dzv19J0/Pc58d9nbvp57nu6/pe1x2ZiSSpPV5SdQMkSeUy+CWpZQx+SWoZg1+SWsbgl6SWMfglqWVOKfoEETECTAJTmXlFRKwCbgZeBewB3puZz8z3M84444xcuXJl0U2VpKGyZ8+eH2fmktnHCw9+4BrgQeCXO68/DnwiM2+OiL8HrgJumO8HrFy5ksnJyWJbKUlDJiIe73a80KGeiFgOvAX4TOd1AOuAWztfsh3YUGQbJEkvVvQY/yeBDwHPdV6/Cjicmc92Xh8AlhXcBknScQoL/oi4AjiYmXv6/P6rI2IyIiYPHTo04NZJUnsV2eO/CHhrRDzG9GTuOuBTwHhEzMwtLAemun1zZm7LzInMnFiy5IS5CUlSnwoL/sy8LjOXZ+ZK4J3APZn5bmA38PbOl20Ebi+qDZKkE5VR1TPbh4GbI+KjwF7gsxW0QZJqa+feKbbu2seTh4+wdHyMTetXs2Ht4KZDSwn+zPw68PXOx48CF5RxXklqmp17p7hux30cOXoMgKnDR7hux30AAwt/V+5KUo1s3bXv+dCfceToMbbu2jewcxj8klQjTx4+0tPxfhj8klQjS8fHejreD4Nfkmpk0/rVjI2OvOjY2OgIm9avHtg5qqjqkSTNYWYCt/FVPZKkhduwdtlAg342h3okqWUMfklqGYNfklrG4JeklnFyV9JQK3rfm7qcsxcGv6ShVca+N3U4Z68c6pE0tMrY96YO5+yVwS9paJWx700dztkrh3okDa2l42NMdQncbvveDGpcvpdzVsUev6ShtdB9b2bG5acOHyF5YVx+596uT4YdyDmrZPBLGlob1i7jY297DcvGxwhg2fgYH3vba07oyQ9yXH6h56ySQz2ShtpC9r0Z9Lh80XvtLJY9fkmtV8Ye+HVi8EtqvYvXLCFmHavbuPwg
GfySWm3n3ilu2zNFHncsgN99fb2HaxbD4JfUat0mdhPY/dChahpUAoNfUqs1YcHVoFnVI6nV5ltwVffN1vplj19Sq8214OriNUsGtqirbgx+Sa0214Kr3Q8dqv1ma/1yqEdS63VbcHXtLfd2/dphGPu3xy9JXQzzoi6DX5K6aMJma/1yqEeSupgZ+hnGqh6DX9LQGVQZZt03W+uXwS9pqDThmbfdlLlmwDF+SUOlCc+8nW2QD4JZCINf0lBp4hYMZb9ZGfyShkoTyzDLfrMy+CUNlSaWYZb9ZmXwSxoqTXjm7Wxlv1lZ1SNp6DStDLPsNQMGvyTVQJlvVoUN9UTEyyPiPyLiexFxf0T8Ref4qoj4dkQ8EhG3RMRLi2qDJOlERY7x/x+wLjNfC5wPXB4RFwIfBz6Rmb8KPA1cVWAbJEmzFBb8Oe1/Oy9HO38SWAfc2jm+HdhQVBskSScqtKonIkYi4l7gIHAX8F/A4cx8tvMlB4Cug1oRcXVETEbE5KFDw/vQY0kqW6GTu5l5DDg/IsaBfwbW9PC924BtABMTE1lIAyU1xrA+/7YKpVT1ZObhiNgNvAEYj4hTOr3+5UDzH2ApqVBN3XitrgoL/ohYAhzthP4YcBnTE7u7gbcDNwMbgduLaoPUJPZoT7wGF69Zwu6HDjHVZeuCmb1s2naNBqHIHv9ZwPaIGGF6LuFLmfnliHgAuDkiPgrsBT5bYBukRrBH2/0afOFb++f9njpvvFZnhQV/Zv4nsLbL8UeBC4o6r9RE8+3O2Jbg73YNTma+vWy8g5qbK3elGmjiVsKD1uvvOt9eNt5Bzc9N2qQaaOJWwoPWy+96so3XmvgwljIZ/FINNHEr4UHrdg1mG31JcNqpozx5+Ahbd+2b8wlV3kHNz+CXaqCJWwkPWrdr8J4LVzz/enxsFAKe/sXRkz6e0Duo+UVm/ddGTUxM5OTkZNXNkFSBmUnabiWdMP0G8c3N6074nuPH+GH6Dqptb6YRsSczJ2Yfd3JXUm11C/DZug3flL2/fdMY/JJqayElnkvHx+Ys3TTouzP4JdXWySZjx0ZHuHjNEks3e+TkrqTamm8ydmYCfPdDhyzd7JHBL6m25ipz/eTvn883N69jw9pllm72weCXVFsLKXO1dLN3jvFLqrWTTdJuWr+6a+lmmxa/9crgl9Rolm72zuCXdFJ13+nS0s3eGPyS5uVOl8PH4Jc0rzo+K6DudyB1Z/BLBRqGgKpbuWRd7kCa/G9rOadUkD/beR/X3nIvU4ePnHQ3yTqrW7lkHfban3nzaeq/rcEvFWDn3ilu+tZ+Zu9928QVpXV7VkAd7kDq8OazGA71SB2DvHXfumvfCaE/o2krSutWLrl0fKzrFs1l3oHU4c1nMQx+icGPG88XAE1cUVqncsk6LNiqw5vPYjjUIzH4W/e5AiDAFaWLVIenldVt+KtX9vglBn/r3q1XGsC7L1xRm55zk1V9B1K34a9eGfwSg791b3ow6OSqfvNZDINfYrqHvunW73H02AtTsqMjsahb97mCoddJ5CbXi6ueDH5pxuwynLnKchah10nkuixW0nBxcldiekjm6HMvTvqjz+XA67J7nURuer246snglyivLrvX8zS9Xlz15FCPRHl12b2ep2714s43DAd7/BLl1WX3ep461Ys3fX8avcDglyhvUVCv56nDYqUZzjcMj8gsoHRhwCYmJnJycrLqZkittmrzV7oWOgXwgy1vKbs5WoCI2JOZE7OP2+OXtCB1255Z/TP4JS1IneYbtDhW9UgnYSXLtEFvQ+F1rY7BL82jypWzdQzGQe1P44rkajnUI82jqkqWYS+dtEKoWga/NI+qVs4OezC6IrlahQV/RJwdEbsj4oGIuD8irukcPz0i7oqIhzt/n1ZUG6TFqqqSZdiD0QqhahXZ438W+GBmngtcCLw/Is4FNgN3Z+Y5wN2d11ItVVXJMuzBaIVQtQoL/sx8KjO/2/n4Z8CDwDLgSmB758u2
AxuKaoO0WFWtnB32YKzTiuQ2KmXlbkSsBL4BnAfsz8zxzvEAnp55Pet7rgauBlixYsXrH3/88cLbKc2oQ0XNYtpQh/arenOt3C08+CPilcC/An+ZmTsi4vDxQR8RT2fmvOP8btmgMs0uNYTp3nZTeqRNb78Gp5ItGyJiFLgNuCkzd3QO/ygizup8/izgYJFtkHrV9IqaprdfxSuyqieAzwIPZub1x33qDmBj5+ONwO1FtUHqR9MraprefhWvyB7/RcB7gXURcW/nz+8AW4DLIuJh4NLOa6k2ml5R0/T2q3hFVvX8W2ZGZv56Zp7f+XNnZv53Zl6Smedk5qWZ+ZOi2iD1o+kVNU1vv4rnXj3SLIPejKxsTW+/iueDWCRpSPkgFkkSME/wR8SdnYVXkqQhMt8Y/+eAr0XEduCvMvNoSW3SEHIlaf+8dhq0OYM/M/8pIv4F+HNgMiL+EXjuuM9fP9f3SsfzoRv989qpCCcb438G+DnwMuCXZv2RFsSVpP3z2qkIc/b4I+Jy4HqmV9q+LjN/UVqrNFRcSdo/r52KMF+P/0+B38vMzYa+FsOVpP3z2qkIcwZ/Zv5mZt5fZmM0nFxJ2j+vnYrgyl0VzpWk/Zvv2lnto365cldqIPfc10K4clcaIlb7aDEMfqmBrPbRYhj8UgNZ7aPFMPilBrLaR4thVY/UQFZKaTEMfqmhNqxdZtCrLwZ/zVmrLWnQDP4ac2dGSUVwcrfGrNWWVAR7/DVmrXb7OLSnMtjjrzFrtdtlZmhv6vARkheG9nbunaq6aRoyBn+NzVWrffGaJVy05R5Wbf4KF225x2AYEg7tqSwO9dRYt1rti9cs4bY9U074DiGH9lQWg7/mZtdqX7Tlnjl7hQZ/sy0dH2OqS8g7tKdBc6inBDv3Tg1saMZe4fByGwaVxR5/wQZdi9+0XqFVKgvnNgwqi8FfsPkm7Pr5D71p/equD+CoY6/QBWi9cxsGlcGhnoINemhmw9plfOxtr2HZ+BgBLBsfq+1Tl6xSkerJHn/BihiaaUqv0PkIqZ7s8ReszRN2C12ANsjJb0knZ4+/YG2bsDt+Mnf81FFGXxIcfS6f//zsNz3nAaTyGfwlaMrQzGLNDvGnf3GU0ZFgfGyUnx452vVNb9CT35JObuiD33LC8nQL8aPHkle87BTu/cibun6P8wBS+YZ6jN9Nr8rVT4i7EZ1UvqEN/p17p/jgl75nOWGJ+gnxNk9+S1UZyuCf6ekfy+z6eYcRitFPiDdpXYI0LAob44+IG4ErgIOZeV7n2OnALcBK4DHgHZn59KDP3W2s+XgOIxSj3wqmtkx+S3VR5OTu54G/Bf7huGObgbszc0tEbO68/vCgTzxfj95hhGIZ4lL9FRb8mfmNiFg56/CVwBs7H28Hvk4BwT/XatmRiFYPI1jhJAnKH+M/MzOf6nz8Q+DMIk4y11jz37zjta0NOiucJM2orI4/MzMius++AhFxNXA1wIoVK3r62W1bLbsQRSyU8g5Caqayg/9HEXFWZj4VEWcBB+f6wszcBmwDmJiYmPMNYi6ONb/YoBdKzbXVwuTjP2H3Q4d8M5BqrOyhnjuAjZ2PNwK3l3z+1hr0Qqm57iBu+tZ+h5Okmiss+CPii8C/A6sj4kBEXAVsAS6LiIeBSzuvVYJBL5Sa605h9q2ZC+ak+imyquddc3zqkqLOqbkNet5jrsqpblwwJ9XL0G/SphcMct6j2yMggxN7/OCCOaluhnLLBhWv21YL775whfvuSA1gj79Ew1b+2O0OYuLVpw/V7ygNI4N/ABYS6G150pRltFL9OdSzSAtdETvfAipJKpPBv0gLDXSfNCWpLgz+RVpooPukKUl1YfAv0kID3SdNSaoLg3+RFhroPmlKUl1Y1bNIvayIteJFUh0Y/ANgoEtqEod6JKllDH5JahmDX5JaxuCXpJYx+CWpZQx+SWoZg1+SWsbgl6SWcQFXCYbtASySms3gL1hbHsAiqTkc6imYD2CR
VDcGf8F8AIukujH4C+YDWCTVjcFfMB/AIqlunNwtWC/79UtSGQz+Erhfv6Q6cahHklrG4JekljH4JallDH5JahmDX5JaxuCXpJYx+CWpZQx+SWoZg1+SWsbgl6SWMfglqWUMfklqmUqCPyIuj4h9EfFIRGyuog2S1FalB39EjACfBt4MnAu8KyLOLbsdktRWVfT4LwAeycxHM/MZ4GbgygraIUmtVEXwLwOeOO71gc6xF4mIqyNiMiImDx06VFrjJGnY1XZyNzO3ZeZEZk4sWbKk6uZI0tCoIvingLOPe728c0ySVIIqgv87wDkRsSoiXgq8E7ijgnZIUiuV/szdzHw2Ij4A7AJGgBsz8/6y2yFJbVXJw9Yz807gzirOLUltV9vJXUlSMQx+SWoZg1+SWsbgl6SWMfglqWUMfklqmUrKOeto594ptu7ax5OHj7B0fIxN61ezYe0JWwhJUuMZ/EyH/nU77uPI0WMATB0+wnU77gMw/CUNHYd6gK279j0f+jOOHD3G1l37KmqRJBXH4AeePHykp+OS1GQGP7B0fKyn45LUZAY/sGn9asZGR150bGx0hE3rV1fUIkkqjpO7vDCBa1WPpDYw+Ds2rF1m0EtqBYd6JKllDH5JahmDX5JaxuCXpJYx+CWpZSIzq27DSUXEIeDnwI+rbkvFzqDd16Dtvz94Ddr++0Nv1+DVmblk9sFGBD9ARExm5kTV7ahS269B239/8Bq0/feHwVwDh3okqWUMfklqmSYF/7aqG1ADbb8Gbf/9wWvQ9t8fBnANGjPGL0kajCb1+CVJA9CI4I+IyyNiX0Q8EhGbq25PmSLi7IjYHREPRMT9EXFN1W2qSkSMRMTeiPhy1W0pW0SMR8StEfFQRDwYEW+ouk1li4hrO/8Hvh8RX4yIl1fdpiJFxI0RcTAivn/csdMj4q6IeLjz92n9/OzaB39EjACfBt4MnAu8KyLOrbZVpXoW+GBmngtcCLy/Zb//8a4BHqy6ERX5FPDVzFwDvJaWXYeIWAb8ETCRmecBI8A7q21V4T4PXD7r2Gbg7sw8B7i787pntQ9+4ALgkcx8NDOfAW4Grqy4TaXJzKcy87udj3/G9H/41u0fHRHLgbcAn6m6LWWLiF8Bfgv4LEBmPpOZhyttVDVOAcYi4hTgVODJittTqMz8BvCTWYevBLZ3Pt4ObOjnZzch+JcBTxz3+gAtDD6AiFgJrAW+XXFTqvBJ4EPAcxW3owqrgEPA5zpDXZ+JiFdU3agyZeYU8NfAfuAp4KeZ+bVqW1WJMzPzqc7HPwTO7OeHNCH4BUTEK4HbgD/OzP+puj1liogrgIOZuafqtlTkFOB1wA2ZuZbp7UvaNtd1GtO93VXAUuAVEfGealtVrZwuyeyrLLMJwT8FnH3c6+WdY60REaNMh/5Nmbmj6vZU4CLgrRHxGNNDfesi4gvVNqlUB4ADmTlzp3cr028EbXIp8IPMPJSZR4EdwG9U3KYq/CgizgLo/H2wnx/ShOD/DnBORKyKiJcyPaFzR8VtKk1EBNNjuw9m5vVVt6cKmXldZi7PzJVM//vfk5mt6e1l5g+BJyJidefQJcADFTapCvuBCyPi1M7/iUto2QR3xx3Axs7HG4Hb+/khtX/mbmY+GxEfAHYxPZN/Y2beX3GzynQR8F7gvoi4t3PsTzLzzuqapAr8IXBTp/PzKPAHFbenVJn57Yi4Ffgu05VuexnyVbwR8UXgjcAZEXEA+AiwBfhSRFwFPA68o6+f7cpdSWqXJgz1SJIGyOCXpJYx+CWpZQx+SWoZg1+SWsbgl3rU2TH1BxFxeuf1aZ3XKytumrQgBr/Uo8x8AriB6ZpqOn9vy8zHKmuU1APr+KU+dLbR2APcCLwPOL+zlYBUe7VfuSvVUWYejYhNwFeBNxn6ahKHeqT+vZnpLYLPq7ohUi8MfqkPEXE+cBnTT0W7dmbHRKkJDH6pR53dIW9g+tkI+4GtTD8kRGoEg1/q3fuA/Zl5V+f1
3wG/FhG/XWGbpAWzqkeSWsYevyS1jMEvSS1j8EtSyxj8ktQyBr8ktYzBL0ktY/BLUssY/JLUMv8PXGSx1qZ3pgkAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# 生成数据\n", "data_num = 50\n", "X = np.random.rand(data_num, 1)*10\n", "Y = X * 3 + 4 + 4*np.random.randn(data_num,1)\n", "\n", "# 画出数据的分布\n", "plt.scatter(X, Y)\n", "plt.xlabel(\"X\")\n", "plt.ylabel(\"Y\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.2 数学原理\n", "有$N$个观测数据为:\n", "$$\n", "\\mathbf{X} = \\{x_1, x_2, ..., x_N \\} \\\\\n", "\\mathbf{Y} = \\{y_1, y_2, ..., y_N \\}\n", "$$\n", "其中$\\mathbf{X}$为自变量,$\\mathbf{Y}$为因变量。\n", "\n", "希望找到一个模型能够解释这些数据,假设使用最简单的线性模型来拟合数据:\n", "$$\n", "y = ax + b\n", "$$\n", "那么问题就变成求解参数$a$, $b$能够使得模型输出尽可能和观测数据有比较小的误差。\n", "\n", "如何构建函数来评估模型输出与观测数据之间的误差是一个关键问题,这里我们使用观测数据与模型输出的平方和来作为评估函数(也被称为损失函数Loss function):\n", "$$\n", "L = \\sum_{i=1}^{N} \\{y_i - (a x_i + b)\\}^2 \\\\\n", "L = \\sum_{i=1}^{N} (y_i - a x_i - b)^2 \n", "$$\n", "\n", "使误差函数最小,那么我们就可以求出模型的参数:\n", "$$\n", "\\frac{\\partial L}{\\partial a} = -2 \\sum_{i=1}^{N} (y_i - a x_i - b) x_i \\\\\n", "\\frac{\\partial L}{\\partial b} = -2 \\sum_{i=1}^{N} (y_i - a x_i - b)\n", "$$\n", "\n", "即当偏微分为0时,误差函数为最小,因此我们可以得到:\n", "$$\n", "-2 \\sum_{i=1}^{N} (y_i - a x_i - b) x_i = 0 \\\\\n", "-2 \\sum_{i=1}^{N} (y_i - a x_i - b) = 0 \\\\\n", "$$\n", "\n", "将上式调整一下顺序可以得到:\n", "$$\n", "a \\sum x_i^2 + b \\sum x_i = \\sum y_i x_i \\\\\n", "a \\sum x_i + b N = \\sum y_i\n", "$$\n", "\n", "上式中$\\sum x_i^2$, $\\sum x_i$, $\\sum y_i$, $\\sum y_i x_i$都是已知的数据,而参数$a$, $b$是我们想要求得未知参数。通过求解二元一次方程组,我们即可求出模型的最优参数。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.3 求解程序" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a = 3.074602, b = 3.761519\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "N = X.shape[0]\n", "\n", "S_X2 = np.sum(X*X)\n", "S_X = np.sum(X)\n", "S_XY = np.sum(X*Y)\n", "S_Y = np.sum(Y)\n", "\n", "A1 = np.array([[S_X2, S_X], \n", " [S_X, N]])\n", "B1 = np.array([S_XY, S_Y])\n", "\n", "# numpy.linalg模块包含线性代数的函数。\n", "# 使用这个模块,可以计算逆矩阵、求特征值、解线性方程组以及求解行列式等。\n", "coeff = np.linalg.inv(A1).dot(B1)\n", "\n", "print('a = %f, b = %f' % (coeff[0], coeff[1]))\n", "\n", "x_min = np.min(X)\n", "x_max = np.max(X)\n", "y_min = coeff[0] * x_min + coeff[1]\n", "y_max = coeff[0] * x_max + coeff[1]\n", "\n", "plt.scatter(X, Y, label='original data')\n", "plt.plot([x_min, x_max], [y_min, y_max], 'r', label='model')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 如何使用迭代的方法求出模型参数\n", "\n", "当数据比较多的时候,或者模型比较复杂,无法直接使用解析的方式求出模型参数。因此更为常用的方式是,通过迭代的方式逐步逼近模型的参数。\n", "\n", "### 2.1 梯度下降法\n", "在机器学习算法中,对于很多监督学习模型,需要对原始的模型构建损失函数,接下来便是通过优化算法对损失函数进行优化,以便寻找到最优的参数。在求解机器学习参数的优化算法中,使用较多的是基于梯度下降的优化算法(Gradient Descent, GD)。\n", "\n", "梯度下降法有很多优点,其中最主要的优点是,**在梯度下降法的求解过程中只需求解损失函数的一阶导数,计算的代价比较小,这使得梯度下降法能在很多大规模数据集上得到应用。**\n", "\n", "梯度下降法的含义是通过当前点的梯度方向寻找到新的迭代点。梯度下降法的基本思想可以类比为一个下山的过程。假设这样一个场景:\n", "* 一个人被困在山上,需要从山上下来,找到山的最低点,也就是山谷;\n", "* 但此时山上的浓雾很大,导致可视度很低。因此,下山的路径就无法全部确定,他必须利用自己周围的信息去找到下山的路径。\n", "* 以他当前的所处的位置为基准,寻找这个位置最陡峭的地方,然后朝着山的高度下降的地方走\n", "* 每走一段距离,都反复采用同一个方法,最后就能成功的抵达山谷。\n", "\n", "\n", "一般情况下,这座山最陡峭的地方是无法通过肉眼立马观察出来的,而是需要一个工具来测量;同时,这个人此时正好拥有测量出最陡峭方向的能力。所以,此人每走一段距离,都需要一段时间来测量所在位置最陡峭的方向,这是比较耗时的。那么为了在太阳下山之前到达山底,就要尽可能的减少测量方向的次数。这是一个两难的选择,如果测量的频繁,可以保证下山的方向是绝对正确的,但又非常耗时;如果测量的过少,又有偏离轨道的风险。所以需要找到一个合适的测量方向的频率,来确保下山的方向不错误,同时又不至于耗时太多!\n", "\n", "\n", "![gradient_descent](images/gradient_descent.png)\n", "\n", "如上图所示,得到了最优解。$x$,$y$表示的是$\\theta_0$和$\\theta_1$,$z$方向表示的是花费函数,很明显出发点不同,最后到达的收敛点可能不一样。当然如果是碗状的,那么收敛点就应该是一样的。\n", "\n", "对于最小二乘的损失函数\n", "$$\n", "L = \\sum_{i=1}^{N} (y_i - a x_i - b)^2\n", "$$\n", "\n", "更新的策略是:\n", "$$\n", 
"\\theta^1 = \\theta^0 - \\eta \\triangledown L(\\theta)\n", "$$\n", "其中$\\theta$代表了模型中的参数,例如$a$, $b$\n", "\n", "此公式的意义是:$L$是关于$\\theta$的一个函数,我们当前所处的位置为$\\theta_0$点,要从这个点走到L的最小值点,也就是山底。首先我们先确定前进的方向,也就是梯度的反向,然后走一段距离的步长,也就是$\\eta$,走完这个段步长,就到达了$\\theta_1$这个点!\n", "\n", "最终的更新方程是:\n", "\n", "$$\n", "a^1 = a^0 + 2 \\eta [ y - (ax+b)]*x \\\\\n", "b^1 = b^0 + 2 \\eta [ y - (ax+b)] \n", "$$\n", "\n", "下面就这个公式的几个常见的疑问:\n", "\n", "* **$\\eta$是什么含义?**\n", "$\\eta$在梯度下降算法中被称作为学习率或者步长,意味着我们可以通过$\\eta$来控制每一步走的距离,以保证不要步子跨的太大,错过了最低点。同时也要保证不要走的太慢,导致太阳下山了,还没有走到山下。所以$\\eta$的选择在梯度下降法中往往是很重要的。\n", "![gd_stepsize](images/gd_stepsize.png)\n", "\n", "* **为什么要梯度要乘以一个负号?**\n", "梯度前加一个负号,就意味着朝着梯度相反的方向前进!梯度的方向实际就是函数在此点上升最快的方向,而我们需要朝着下降最快的方向走,自然就是负的梯度的方向,所以此处需要加上负号。\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 示例代码" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0: loss = 964.462432, a = 3.605114, b = 1.371685\n", "epoch 100: loss = 856.680487, a = 3.044999, b = 3.374801\n", "epoch 200: loss = 900.182987, a = 3.261635, b = 3.670704\n", "epoch 300: loss = 841.107710, a = 3.082096, b = 3.703232\n", "epoch 400: loss = 848.034818, a = 3.023324, b = 3.703601\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import random\n", "\n", "n_epoch = 500 # epoch size\n", "a, b = 1, 1 # initial parameters\n", "epsilon = 0.001 # learning rate\n", "\n", "for i in range(n_epoch):\n", " data_idx = list(range(N))\n", " random.shuffle(data_idx)\n", " \n", " for j in data_idx:\n", " a = a + epsilon*2*(Y[j] - a*X[j] - b)*X[j]\n", " b = b + epsilon*2*(Y[j] - a*X[j] - b)\n", "\n", " L = 0\n", " for j in range(N):\n", " L = L + (Y[j]-a*X[j]-b)**2\n", " \n", " if i % 100 == 0:\n", " print(\"epoch %4d: loss = %f, a = %f, b = %f\" % (i, L, a, b))\n", " \n", "x_min = np.min(X)\n", "x_max = np.max(X)\n", "y_min = a * x_min + b\n", "y_max = a * x_max + b\n", "\n", "plt.scatter(X, Y, label='original data')\n", "plt.plot([x_min, x_max], [y_min, y_max], 'r', label='model')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 如何可视化迭代过程" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "/* global mpl */\n", "window.mpl = {};\n", "\n", "mpl.get_websocket_type = function () {\n", " if (typeof WebSocket !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof MozWebSocket !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert(\n", " 'Your browser does not have WebSocket support. ' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.'\n", " );\n", " }\n", "};\n", "\n", "mpl.figure = function (figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = this.ws.binaryType !== undefined;\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById('mpl-warnings');\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent =\n", " 'This browser does not support binary websocket messages. ' +\n", " 'Performance may be slow.';\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = document.createElement('div');\n", " this.root.setAttribute('style', 'display: inline-block');\n", " this._root_extra_style(this.root);\n", "\n", " parent_element.appendChild(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message('supports_binary', { value: fig.supports_binary });\n", " fig.send_message('send_image_mode', {});\n", " if (fig.ratio !== 1) {\n", " fig.send_message('set_dpi_ratio', { dpi_ratio: fig.ratio });\n", " }\n", " fig.send_message('refresh', {});\n", " };\n", "\n", " this.imageObj.onload = function () {\n", " if (fig.image_mode === 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " 
fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function () {\n", " fig.ws.close();\n", " };\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "};\n", "\n", "mpl.figure.prototype._init_header = function () {\n", " var titlebar = document.createElement('div');\n", " titlebar.classList =\n", " 'ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix';\n", " var titletext = document.createElement('div');\n", " titletext.classList = 'ui-dialog-title';\n", " titletext.setAttribute(\n", " 'style',\n", " 'width: 100%; text-align: center; padding: 3px;'\n", " );\n", " titlebar.appendChild(titletext);\n", " this.root.appendChild(titlebar);\n", " this.header = titletext;\n", "};\n", "\n", "mpl.figure.prototype._canvas_extra_style = function (_canvas_div) {};\n", "\n", "mpl.figure.prototype._root_extra_style = function (_canvas_div) {};\n", "\n", "mpl.figure.prototype._init_canvas = function () {\n", " var fig = this;\n", "\n", " var canvas_div = (this.canvas_div = document.createElement('div'));\n", " canvas_div.setAttribute(\n", " 'style',\n", " 'border: 1px solid #ddd;' +\n", " 'box-sizing: content-box;' +\n", " 'clear: both;' +\n", " 'min-height: 1px;' +\n", " 'min-width: 1px;' +\n", " 'outline: 0;' +\n", " 'overflow: hidden;' +\n", " 'position: relative;' +\n", " 'resize: both;'\n", " );\n", "\n", " function on_keyboard_event_closure(name) {\n", " return function (event) {\n", " return fig.key_event(event, name);\n", " };\n", " }\n", "\n", " canvas_div.addEventListener(\n", " 'keydown',\n", " on_keyboard_event_closure('key_press')\n", " );\n", " canvas_div.addEventListener(\n", " 'keyup',\n", " on_keyboard_event_closure('key_release')\n", " );\n", "\n", " this._canvas_extra_style(canvas_div);\n", " this.root.appendChild(canvas_div);\n", "\n", " var canvas = (this.canvas = document.createElement('canvas'));\n", " canvas.classList.add('mpl-canvas');\n", " 
canvas.setAttribute('style', 'box-sizing: content-box;');\n", "\n", " this.context = canvas.getContext('2d');\n", "\n", " var backingStore =\n", " this.context.backingStorePixelRatio ||\n", " this.context.webkitBackingStorePixelRatio ||\n", " this.context.mozBackingStorePixelRatio ||\n", " this.context.msBackingStorePixelRatio ||\n", " this.context.oBackingStorePixelRatio ||\n", " this.context.backingStorePixelRatio ||\n", " 1;\n", "\n", " this.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband_canvas = (this.rubberband_canvas = document.createElement(\n", " 'canvas'\n", " ));\n", " rubberband_canvas.setAttribute(\n", " 'style',\n", " 'box-sizing: content-box; position: absolute; left: 0; top: 0; z-index: 1;'\n", " );\n", "\n", " // Apply a ponyfill if ResizeObserver is not implemented by browser.\n", " if (this.ResizeObserver === undefined) {\n", " if (window.ResizeObserver !== undefined) {\n", " this.ResizeObserver = window.ResizeObserver;\n", " } else {\n", " var obs = _JSXTOOLS_RESIZE_OBSERVER({});\n", " this.ResizeObserver = obs.ResizeObserver;\n", " }\n", " }\n", "\n", " this.resizeObserverInstance = new this.ResizeObserver(function (entries) {\n", " var nentries = entries.length;\n", " for (var i = 0; i < nentries; i++) {\n", " var entry = entries[i];\n", " var width, height;\n", " if (entry.contentBoxSize) {\n", " if (entry.contentBoxSize instanceof Array) {\n", " // Chrome 84 implements new version of spec.\n", " width = entry.contentBoxSize[0].inlineSize;\n", " height = entry.contentBoxSize[0].blockSize;\n", " } else {\n", " // Firefox implements old version of spec.\n", " width = entry.contentBoxSize.inlineSize;\n", " height = entry.contentBoxSize.blockSize;\n", " }\n", " } else {\n", " // Chrome <84 implements even older version of spec.\n", " width = entry.contentRect.width;\n", " height = entry.contentRect.height;\n", " }\n", "\n", " // Keep the size of the canvas and rubber band canvas in sync with\n", " // the canvas 
container.\n", " if (entry.devicePixelContentBoxSize) {\n", " // Chrome 84 implements new version of spec.\n", " canvas.setAttribute(\n", " 'width',\n", " entry.devicePixelContentBoxSize[0].inlineSize\n", " );\n", " canvas.setAttribute(\n", " 'height',\n", " entry.devicePixelContentBoxSize[0].blockSize\n", " );\n", " } else {\n", " canvas.setAttribute('width', width * fig.ratio);\n", " canvas.setAttribute('height', height * fig.ratio);\n", " }\n", " canvas.setAttribute(\n", " 'style',\n", " 'width: ' + width + 'px; height: ' + height + 'px;'\n", " );\n", "\n", " rubberband_canvas.setAttribute('width', width);\n", " rubberband_canvas.setAttribute('height', height);\n", "\n", " // And update the size in Python. We ignore the initial 0/0 size\n", " // that occurs as the element is placed into the DOM, which should\n", " // otherwise not happen due to the minimum size styling.\n", " if (fig.ws.readyState == 1 && width != 0 && height != 0) {\n", " fig.request_resize(width, height);\n", " }\n", " }\n", " });\n", " this.resizeObserverInstance.observe(canvas_div);\n", "\n", " function on_mouse_event_closure(name) {\n", " return function (event) {\n", " return fig.mouse_event(event, name);\n", " };\n", " }\n", "\n", " rubberband_canvas.addEventListener(\n", " 'mousedown',\n", " on_mouse_event_closure('button_press')\n", " );\n", " rubberband_canvas.addEventListener(\n", " 'mouseup',\n", " on_mouse_event_closure('button_release')\n", " );\n", " rubberband_canvas.addEventListener(\n", " 'dblclick',\n", " on_mouse_event_closure('dblclick')\n", " );\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband_canvas.addEventListener(\n", " 'mousemove',\n", " on_mouse_event_closure('motion_notify')\n", " );\n", "\n", " rubberband_canvas.addEventListener(\n", " 'mouseenter',\n", " on_mouse_event_closure('figure_enter')\n", " );\n", " rubberband_canvas.addEventListener(\n", " 'mouseleave',\n", " on_mouse_event_closure('figure_leave')\n", " );\n", "\n", " 
canvas_div.addEventListener('wheel', function (event) {\n", " if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " on_mouse_event_closure('scroll')(event);\n", " });\n", "\n", " canvas_div.appendChild(canvas);\n", " canvas_div.appendChild(rubberband_canvas);\n", "\n", " this.rubberband_context = rubberband_canvas.getContext('2d');\n", " this.rubberband_context.strokeStyle = '#000000';\n", "\n", " this._resize_canvas = function (width, height, forward) {\n", " if (forward) {\n", " canvas_div.style.width = width + 'px';\n", " canvas_div.style.height = height + 'px';\n", " }\n", " };\n", "\n", " // Disable right mouse context menu.\n", " this.rubberband_canvas.addEventListener('contextmenu', function (_e) {\n", " event.preventDefault();\n", " return false;\n", " });\n", "\n", " function set_focus() {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "};\n", "\n", "mpl.figure.prototype._init_toolbar = function () {\n", " var fig = this;\n", "\n", " var toolbar = document.createElement('div');\n", " toolbar.classList = 'mpl-toolbar';\n", " this.root.appendChild(toolbar);\n", "\n", " function on_click_closure(name) {\n", " return function (_event) {\n", " return fig.toolbar_button_onclick(name);\n", " };\n", " }\n", "\n", " function on_mouseover_closure(tooltip) {\n", " return function (event) {\n", " if (!event.currentTarget.disabled) {\n", " return fig.toolbar_button_onmouseover(tooltip);\n", " }\n", " };\n", " }\n", "\n", " fig.buttons = {};\n", " var buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'mpl-button-group';\n", " for (var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " /* Instead of a spacer, we 
start a new button group. */\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", " buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'mpl-button-group';\n", " continue;\n", " }\n", "\n", " var button = (fig.buttons[name] = document.createElement('button'));\n", " button.classList = 'mpl-widget';\n", " button.setAttribute('role', 'button');\n", " button.setAttribute('aria-disabled', 'false');\n", " button.addEventListener('click', on_click_closure(method_name));\n", " button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n", "\n", " var icon_img = document.createElement('img');\n", " icon_img.src = '_images/' + image + '.png';\n", " icon_img.srcset = '_images/' + image + '_large.png 2x';\n", " icon_img.alt = tooltip;\n", " button.appendChild(icon_img);\n", "\n", " buttonGroup.appendChild(button);\n", " }\n", "\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", "\n", " var fmt_picker = document.createElement('select');\n", " fmt_picker.classList = 'mpl-widget';\n", " toolbar.appendChild(fmt_picker);\n", " this.format_dropdown = fmt_picker;\n", "\n", " for (var ind in mpl.extensions) {\n", " var fmt = mpl.extensions[ind];\n", " var option = document.createElement('option');\n", " option.selected = fmt === mpl.default_extension;\n", " option.innerHTML = fmt;\n", " fmt_picker.appendChild(option);\n", " }\n", "\n", " var status_bar = document.createElement('span');\n", " status_bar.classList = 'mpl-message';\n", " toolbar.appendChild(status_bar);\n", " this.message = status_bar;\n", "};\n", "\n", "mpl.figure.prototype.request_resize = function (x_pixels, y_pixels) {\n", " // Request matplotlib to resize the figure. 
Matplotlib will then trigger a resize in the client,\n", " // which will in turn request a refresh of the image.\n", " this.send_message('resize', { width: x_pixels, height: y_pixels });\n", "};\n", "\n", "mpl.figure.prototype.send_message = function (type, properties) {\n", " properties['type'] = type;\n", " properties['figure_id'] = this.id;\n", " this.ws.send(JSON.stringify(properties));\n", "};\n", "\n", "mpl.figure.prototype.send_draw_message = function () {\n", " if (!this.waiting) {\n", " this.waiting = true;\n", " this.ws.send(JSON.stringify({ type: 'draw', figure_id: this.id }));\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_save = function (fig, _msg) {\n", " var format_dropdown = fig.format_dropdown;\n", " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n", " fig.ondownload(fig, format);\n", "};\n", "\n", "mpl.figure.prototype.handle_resize = function (fig, msg) {\n", " var size = msg['size'];\n", " if (size[0] !== fig.canvas.width || size[1] !== fig.canvas.height) {\n", " fig._resize_canvas(size[0], size[1], msg['forward']);\n", " fig.send_message('refresh', {});\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_rubberband = function (fig, msg) {\n", " var x0 = msg['x0'] / fig.ratio;\n", " var y0 = (fig.canvas.height - msg['y0']) / fig.ratio;\n", " var x1 = msg['x1'] / fig.ratio;\n", " var y1 = (fig.canvas.height - msg['y1']) / fig.ratio;\n", " x0 = Math.floor(x0) + 0.5;\n", " y0 = Math.floor(y0) + 0.5;\n", " x1 = Math.floor(x1) + 0.5;\n", " y1 = Math.floor(y1) + 0.5;\n", " var min_x = Math.min(x0, x1);\n", " var min_y = Math.min(y0, y1);\n", " var width = Math.abs(x1 - x0);\n", " var height = Math.abs(y1 - y0);\n", "\n", " fig.rubberband_context.clearRect(\n", " 0,\n", " 0,\n", " fig.canvas.width / fig.ratio,\n", " fig.canvas.height / fig.ratio\n", " );\n", "\n", " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n", "};\n", "\n", "mpl.figure.prototype.handle_figure_label = function (fig, msg) 
{\n", " // Updates the figure title.\n", " fig.header.textContent = msg['label'];\n", "};\n", "\n", "mpl.figure.prototype.handle_cursor = function (fig, msg) {\n", " var cursor = msg['cursor'];\n", " switch (cursor) {\n", " case 0:\n", " cursor = 'pointer';\n", " break;\n", " case 1:\n", " cursor = 'default';\n", " break;\n", " case 2:\n", " cursor = 'crosshair';\n", " break;\n", " case 3:\n", " cursor = 'move';\n", " break;\n", " }\n", " fig.rubberband_canvas.style.cursor = cursor;\n", "};\n", "\n", "mpl.figure.prototype.handle_message = function (fig, msg) {\n", " fig.message.textContent = msg['message'];\n", "};\n", "\n", "mpl.figure.prototype.handle_draw = function (fig, _msg) {\n", " // Request the server to send over a new figure.\n", " fig.send_draw_message();\n", "};\n", "\n", "mpl.figure.prototype.handle_image_mode = function (fig, msg) {\n", " fig.image_mode = msg['mode'];\n", "};\n", "\n", "mpl.figure.prototype.handle_history_buttons = function (fig, msg) {\n", " for (var key in msg) {\n", " if (!(key in fig.buttons)) {\n", " continue;\n", " }\n", " fig.buttons[key].disabled = !msg[key];\n", " fig.buttons[key].setAttribute('aria-disabled', !msg[key]);\n", " }\n", "};\n", "\n", "mpl.figure.prototype.handle_navigate_mode = function (fig, msg) {\n", " if (msg['mode'] === 'PAN') {\n", " fig.buttons['Pan'].classList.add('active');\n", " fig.buttons['Zoom'].classList.remove('active');\n", " } else if (msg['mode'] === 'ZOOM') {\n", " fig.buttons['Pan'].classList.remove('active');\n", " fig.buttons['Zoom'].classList.add('active');\n", " } else {\n", " fig.buttons['Pan'].classList.remove('active');\n", " fig.buttons['Zoom'].classList.remove('active');\n", " }\n", "};\n", "\n", "mpl.figure.prototype.updated_canvas_event = function () {\n", " // Called whenever the canvas gets updated.\n", " this.send_message('ack', {});\n", "};\n", "\n", "// A function to construct a web socket function for onmessage handling.\n", "// Called in the figure constructor.\n", 
"mpl.figure.prototype._make_on_message_function = function (fig) {\n", " return function socket_on_message(evt) {\n", " if (evt.data instanceof Blob) {\n", " var img = evt.data;\n", " if (img.type !== 'image/png') {\n", " /* FIXME: We get \"Resource interpreted as Image but\n", " * transferred with MIME type text/plain:\" errors on\n", " * Chrome. But how to set the MIME type? It doesn't seem\n", " * to be part of the websocket stream */\n", " img.type = 'image/png';\n", " }\n", "\n", " /* Free the memory for the previous frames */\n", " if (fig.imageObj.src) {\n", " (window.URL || window.webkitURL).revokeObjectURL(\n", " fig.imageObj.src\n", " );\n", " }\n", "\n", " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n", " img\n", " );\n", " fig.updated_canvas_event();\n", " fig.waiting = false;\n", " return;\n", " } else if (\n", " typeof evt.data === 'string' &&\n", " evt.data.slice(0, 21) === 'data:image/png;base64'\n", " ) {\n", " fig.imageObj.src = evt.data;\n", " fig.updated_canvas_event();\n", " fig.waiting = false;\n", " return;\n", " }\n", "\n", " var msg = JSON.parse(evt.data);\n", " var msg_type = msg['type'];\n", "\n", " // Call the \"handle_{type}\" callback, which takes\n", " // the figure and JSON message as its only arguments.\n", " try {\n", " var callback = fig['handle_' + msg_type];\n", " } catch (e) {\n", " console.log(\n", " \"No handler for the '\" + msg_type + \"' message type: \",\n", " msg\n", " );\n", " return;\n", " }\n", "\n", " if (callback) {\n", " try {\n", " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n", " callback(fig, msg);\n", " } catch (e) {\n", " console.log(\n", " \"Exception inside the 'handler_\" + msg_type + \"' callback:\",\n", " e,\n", " e.stack,\n", " msg\n", " );\n", " }\n", " }\n", " };\n", "};\n", "\n", "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n", "mpl.findpos = function (e) {\n", " //this section is from 
http://www.quirksmode.org/js/events_properties.html\n", " var targ;\n", " if (!e) {\n", " e = window.event;\n", " }\n", " if (e.target) {\n", " targ = e.target;\n", " } else if (e.srcElement) {\n", " targ = e.srcElement;\n", " }\n", " if (targ.nodeType === 3) {\n", " // defeat Safari bug\n", " targ = targ.parentNode;\n", " }\n", "\n", " // pageX,Y are the mouse positions relative to the document\n", " var boundingRect = targ.getBoundingClientRect();\n", " var x = e.pageX - (boundingRect.left + document.body.scrollLeft);\n", " var y = e.pageY - (boundingRect.top + document.body.scrollTop);\n", "\n", " return { x: x, y: y };\n", "};\n", "\n", "/*\n", " * return a copy of an object with only non-object keys\n", " * we need this to avoid circular references\n", " * http://stackoverflow.com/a/24161582/3208463\n", " */\n", "function simpleKeys(original) {\n", " return Object.keys(original).reduce(function (obj, key) {\n", " if (typeof original[key] !== 'object') {\n", " obj[key] = original[key];\n", " }\n", " return obj;\n", " }, {});\n", "}\n", "\n", "mpl.figure.prototype.mouse_event = function (event, name) {\n", " var canvas_pos = mpl.findpos(event);\n", "\n", " if (name === 'button_press') {\n", " this.canvas.focus();\n", " this.canvas_div.focus();\n", " }\n", "\n", " var x = canvas_pos.x * this.ratio;\n", " var y = canvas_pos.y * this.ratio;\n", "\n", " this.send_message(name, {\n", " x: x,\n", " y: y,\n", " button: event.button,\n", " step: event.step,\n", " guiEvent: simpleKeys(event),\n", " });\n", "\n", " /* This prevents the web browser from automatically changing to\n", " * the text insertion cursor when the button is pressed. 
We want\n", " * to control all of the cursor setting manually through the\n", " * 'cursor' event from matplotlib */\n", " event.preventDefault();\n", " return false;\n", "};\n", "\n", "mpl.figure.prototype._key_event_extra = function (_event, _name) {\n", " // Handle any extra behaviour associated with a key event\n", "};\n", "\n", "mpl.figure.prototype.key_event = function (event, name) {\n", " // Prevent repeat events\n", " if (name === 'key_press') {\n", " if (event.key === this._key) {\n", " return;\n", " } else {\n", " this._key = event.key;\n", " }\n", " }\n", " if (name === 'key_release') {\n", " this._key = null;\n", " }\n", "\n", " var value = '';\n", " if (event.ctrlKey && event.key !== 'Control') {\n", " value += 'ctrl+';\n", " }\n", " else if (event.altKey && event.key !== 'Alt') {\n", " value += 'alt+';\n", " }\n", " else if (event.shiftKey && event.key !== 'Shift') {\n", " value += 'shift+';\n", " }\n", "\n", " value += 'k' + event.key;\n", "\n", " this._key_event_extra(event, name);\n", "\n", " this.send_message(name, { key: value, guiEvent: simpleKeys(event) });\n", " return false;\n", "};\n", "\n", "mpl.figure.prototype.toolbar_button_onclick = function (name) {\n", " if (name === 'download') {\n", " this.handle_save(this, null);\n", " } else {\n", " this.send_message('toolbar_button', { name: name });\n", " }\n", "};\n", "\n", "mpl.figure.prototype.toolbar_button_onmouseover = function (tooltip) {\n", " this.message.textContent = tooltip;\n", "};\n", "\n", "///////////////// REMAINING CONTENT GENERATED BY embed_js.py /////////////////\n", "// prettier-ignore\n", "var _JSXTOOLS_RESIZE_OBSERVER=function(A){var t,i=new WeakMap,n=new WeakMap,a=new WeakMap,r=new WeakMap,o=new Set;function s(e){if(!(this instanceof s))throw new TypeError(\"Constructor requires 'new' operator\");i.set(this,e)}function h(){throw new TypeError(\"Function is not a constructor\")}function c(e,t,i,n){e=0 in arguments?Number(arguments[0]):0,t=1 in 
arguments?Number(arguments[1]):0,i=2 in arguments?Number(arguments[2]):0,n=3 in arguments?Number(arguments[3]):0,this.right=(this.x=this.left=e)+(this.width=i),this.bottom=(this.y=this.top=t)+(this.height=n),Object.freeze(this)}function d(){t=requestAnimationFrame(d);var s=new WeakMap,p=new Set;o.forEach((function(t){r.get(t).forEach((function(i){var r=t instanceof window.SVGElement,o=a.get(t),d=r?0:parseFloat(o.paddingTop),f=r?0:parseFloat(o.paddingRight),l=r?0:parseFloat(o.paddingBottom),u=r?0:parseFloat(o.paddingLeft),g=r?0:parseFloat(o.borderTopWidth),m=r?0:parseFloat(o.borderRightWidth),w=r?0:parseFloat(o.borderBottomWidth),b=u+f,F=d+l,v=(r?0:parseFloat(o.borderLeftWidth))+m,W=g+w,y=r?0:t.offsetHeight-W-t.clientHeight,E=r?0:t.offsetWidth-v-t.clientWidth,R=b+v,z=F+W,M=r?t.width:parseFloat(o.width)-R-E,O=r?t.height:parseFloat(o.height)-z-y;if(n.has(t)){var k=n.get(t);if(k[0]===M&&k[1]===O)return}n.set(t,[M,O]);var S=Object.create(h.prototype);S.target=t,S.contentRect=new c(u,d,M,O),s.has(i)||(s.set(i,[]),p.add(i)),s.get(i).push(S)}))})),p.forEach((function(e){i.get(e).call(e,s.get(e),e)}))}return s.prototype.observe=function(i){if(i instanceof window.Element){r.has(i)||(r.set(i,new Set),o.add(i),a.set(i,window.getComputedStyle(i)));var n=r.get(i);n.has(this)||n.add(this),cancelAnimationFrame(t),t=requestAnimationFrame(d)}},s.prototype.unobserve=function(i){if(i instanceof window.Element&&r.has(i)){var n=r.get(i);n.has(this)&&(n.delete(this),n.size||(r.delete(i),o.delete(i))),n.size||r.delete(i),o.size||cancelAnimationFrame(t)}},A.DOMRectReadOnly=c,A.ResizeObserver=s,A.ResizeObserverEntry=h,A}; // eslint-disable-line\n", "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Left button pans, Right button 
zooms\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\\nx/y fixes axis, CTRL fixes aspect\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n", "\n", "mpl.extensions = [\"eps\", \"jpeg\", \"pgf\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n", "\n", "mpl.default_extension = \"png\";/* global mpl */\n", "\n", "var comm_websocket_adapter = function (comm) {\n", " // Create a \"websocket\"-like object which calls the given IPython comm\n", " // object with the appropriate methods. Currently this is a non binary\n", " // socket, so there is still some room for performance tuning.\n", " var ws = {};\n", "\n", " ws.binaryType = comm.kernel.ws.binaryType;\n", " ws.readyState = comm.kernel.ws.readyState;\n", " function updateReadyState(_event) {\n", " if (comm.kernel.ws) {\n", " ws.readyState = comm.kernel.ws.readyState;\n", " } else {\n", " ws.readyState = 3; // Closed state.\n", " }\n", " }\n", " comm.kernel.ws.addEventListener('open', updateReadyState);\n", " comm.kernel.ws.addEventListener('close', updateReadyState);\n", " comm.kernel.ws.addEventListener('error', updateReadyState);\n", "\n", " ws.close = function () {\n", " comm.close();\n", " };\n", " ws.send = function (m) {\n", " //console.log('sending', m);\n", " comm.send(m);\n", " };\n", " // Register the callback with on_msg.\n", " comm.on_msg(function (msg) {\n", " //console.log('receiving', msg['content']['data'], msg);\n", " var data = msg['content']['data'];\n", " if (data['blob'] !== undefined) {\n", " data = {\n", " data: new Blob(msg['buffers'], { type: data['blob'] }),\n", " };\n", " }\n", " // Pass the mpl event to the overridden (by mpl) onmessage function.\n", " ws.onmessage(data);\n", " });\n", " return ws;\n", "};\n", "\n", "mpl.mpl_figure_comm = function (comm, msg) {\n", " // This is the function which gets called when the 
mpl process\n", " // starts-up an IPython Comm through the \"matplotlib\" channel.\n", "\n", " var id = msg.content.data.id;\n", " // Get hold of the div created by the display call when the Comm\n", " // socket was opened in Python.\n", " var element = document.getElementById(id);\n", " var ws_proxy = comm_websocket_adapter(comm);\n", "\n", " function ondownload(figure, _format) {\n", " window.open(figure.canvas.toDataURL());\n", " }\n", "\n", " var fig = new mpl.figure(id, ws_proxy, ondownload, element);\n", "\n", " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n", " // web socket which is closed, not our websocket->open comm proxy.\n", " ws_proxy.onopen();\n", "\n", " fig.parent_element = element;\n", " fig.cell_info = mpl.find_output_cell(\"
\");\n", " if (!fig.cell_info) {\n", " console.error('Failed to find cell for figure', id, fig);\n", " return;\n", " }\n", " fig.cell_info[0].output_area.element.on(\n", " 'cleared',\n", " { fig: fig },\n", " fig._remove_fig_handler\n", " );\n", "};\n", "\n", "mpl.figure.prototype.handle_close = function (fig, msg) {\n", " var width = fig.canvas.width / fig.ratio;\n", " fig.cell_info[0].output_area.element.off(\n", " 'cleared',\n", " fig._remove_fig_handler\n", " );\n", " fig.resizeObserverInstance.unobserve(fig.canvas_div);\n", "\n", " // Update the output cell to use the data from the current canvas.\n", " fig.push_to_output();\n", " var dataURL = fig.canvas.toDataURL();\n", " // Re-enable the keyboard manager in IPython - without this line, in FF,\n", " // the notebook keyboard shortcuts fail.\n", " IPython.keyboard_manager.enable();\n", " fig.parent_element.innerHTML =\n", " '';\n", " fig.close_ws(fig, msg);\n", "};\n", "\n", "mpl.figure.prototype.close_ws = function (fig, msg) {\n", " fig.send_message('closing', msg);\n", " // fig.ws.close()\n", "};\n", "\n", "mpl.figure.prototype.push_to_output = function (_remove_interactive) {\n", " // Turn the data on the canvas into data in the output cell.\n", " var width = this.canvas.width / this.ratio;\n", " var dataURL = this.canvas.toDataURL();\n", " this.cell_info[1]['text/html'] =\n", " '';\n", "};\n", "\n", "mpl.figure.prototype.updated_canvas_event = function () {\n", " // Tell IPython that the notebook contents must change.\n", " IPython.notebook.set_dirty(true);\n", " this.send_message('ack', {});\n", " var fig = this;\n", " // Wait a second, then push the new image to the DOM so\n", " // that it is saved nicely (might be nice to debounce this).\n", " setTimeout(function () {\n", " fig.push_to_output();\n", " }, 1000);\n", "};\n", "\n", "mpl.figure.prototype._init_toolbar = function () {\n", " var fig = this;\n", "\n", " var toolbar = document.createElement('div');\n", " toolbar.classList = 'btn-toolbar';\n", 
" this.root.appendChild(toolbar);\n", "\n", " function on_click_closure(name) {\n", " return function (_event) {\n", " return fig.toolbar_button_onclick(name);\n", " };\n", " }\n", "\n", " function on_mouseover_closure(tooltip) {\n", " return function (event) {\n", " if (!event.currentTarget.disabled) {\n", " return fig.toolbar_button_onmouseover(tooltip);\n", " }\n", " };\n", " }\n", "\n", " fig.buttons = {};\n", " var buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'btn-group';\n", " var button;\n", " for (var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " /* Instead of a spacer, we start a new button group. */\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", " buttonGroup = document.createElement('div');\n", " buttonGroup.classList = 'btn-group';\n", " continue;\n", " }\n", "\n", " button = fig.buttons[name] = document.createElement('button');\n", " button.classList = 'btn btn-default';\n", " button.href = '#';\n", " button.title = name;\n", " button.innerHTML = '';\n", " button.addEventListener('click', on_click_closure(method_name));\n", " button.addEventListener('mouseover', on_mouseover_closure(tooltip));\n", " buttonGroup.appendChild(button);\n", " }\n", "\n", " if (buttonGroup.hasChildNodes()) {\n", " toolbar.appendChild(buttonGroup);\n", " }\n", "\n", " // Add the status bar.\n", " var status_bar = document.createElement('span');\n", " status_bar.classList = 'mpl-message pull-right';\n", " toolbar.appendChild(status_bar);\n", " this.message = status_bar;\n", "\n", " // Add the close button to the window.\n", " var buttongrp = document.createElement('div');\n", " buttongrp.classList = 'btn-group inline pull-right';\n", " button = 
document.createElement('button');\n", " button.classList = 'btn btn-mini btn-primary';\n", " button.href = '#';\n", " button.title = 'Stop Interaction';\n", " button.innerHTML = '';\n", " button.addEventListener('click', function (_evt) {\n", " fig.handle_close(fig, {});\n", " });\n", " button.addEventListener(\n", " 'mouseover',\n", " on_mouseover_closure('Stop Interaction')\n", " );\n", " buttongrp.appendChild(button);\n", " var titlebar = this.root.querySelector('.ui-dialog-titlebar');\n", " titlebar.insertBefore(buttongrp, titlebar.firstChild);\n", "};\n", "\n", "mpl.figure.prototype._remove_fig_handler = function (event) {\n", " var fig = event.data.fig;\n", " if (event.target !== this) {\n", " // Ignore bubbled events from children.\n", " return;\n", " }\n", " fig.close_ws(fig, {});\n", "};\n", "\n", "mpl.figure.prototype._root_extra_style = function (el) {\n", " el.style.boxSizing = 'content-box'; // override notebook setting of border-box.\n", "};\n", "\n", "mpl.figure.prototype._canvas_extra_style = function (el) {\n", " // this is important to make the div 'focusable\n", " el.setAttribute('tabindex', 0);\n", " // reach out to IPython and tell the keyboard manager to turn it's self\n", " // off when our div gets focus\n", "\n", " // location in version 3\n", " if (IPython.notebook.keyboard_manager) {\n", " IPython.notebook.keyboard_manager.register_events(el);\n", " } else {\n", " // location in version 2\n", " IPython.keyboard_manager.register_events(el);\n", " }\n", "};\n", "\n", "mpl.figure.prototype._key_event_extra = function (event, _name) {\n", " var manager = IPython.notebook.keyboard_manager;\n", " if (!manager) {\n", " manager = IPython.keyboard_manager;\n", " }\n", "\n", " // Check for shift+enter\n", " if (event.shiftKey && event.which === 13) {\n", " this.canvas_div.blur();\n", " // select the cell after this one\n", " var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n", " IPython.notebook.select(index + 1);\n", " }\n", "};\n", 
"\n", "mpl.figure.prototype.handle_save = function (fig, _msg) {\n", " fig.ondownload(fig, null);\n", "};\n", "\n", "mpl.find_output_cell = function (html_output) {\n", " // Return the cell and output element which can be found *uniquely* in the notebook.\n", " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n", " // IPython event is triggered only after the cells have been serialised, which for\n", " // our purposes (turning an active figure into a static one), is too late.\n", " var cells = IPython.notebook.get_cells();\n", " var ncells = cells.length;\n", " for (var i = 0; i < ncells; i++) {\n", " var cell = cells[i];\n", " if (cell.cell_type === 'code') {\n", " for (var j = 0; j < cell.output_area.outputs.length; j++) {\n", " var data = cell.output_area.outputs[j];\n", " if (data.data) {\n", " // IPython >= 3 moved mimebundle to data attribute of output\n", " data = data.data;\n", " }\n", " if (data['text/html'] === html_output) {\n", " return [cell, data, j];\n", " }\n", " }\n", " }\n", " }\n", "};\n", "\n", "// Register the function which deals with the matplotlib target/channel.\n", "// The kernel may be null if the page has been refreshed.\n", "if (IPython.notebook.kernel !== null) {\n", " IPython.notebook.kernel.comm_manager.register_target(\n", " 'matplotlib',\n", " mpl.mpl_figure_comm\n", " );\n", "}\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib nbagg\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.animation as animation\n", "\n", "n_epoch = 300 # epoch size\n", "a, b = 1, 1 # initial parameters\n", "epsilon = 0.0001 # learning rate\n", "\n", "fig = plt.figure()\n", "imgs = []\n", "\n", "for i in range(n_epoch):\n", " data_idx = list(range(N))\n", " random.shuffle(data_idx)\n", " \n", " for j in data_idx[:10]:\n", " a = a + 
epsilon*2*(Y[j] - a*X[j] - b)*X[j]\n", " b = b + epsilon*2*(Y[j] - a*X[j] - b)\n", "\n", "\n", " if i<80 and i % 5 == 0:\n", " x_min = np.min(X)\n", " x_max = np.max(X)\n", " y_min = a * x_min + b\n", " y_max = a * x_max + b\n", "\n", " img = plt.scatter(X, Y, label='original data')\n", " img = plt.plot([x_min, x_max], [y_min, y_max], 'r', label='model')\n", " imgs.append(img)\n", " \n", "ani = animation.ArtistAnimation(fig, imgs)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. How to use mini-batch updates?\n", "\n", "Some samples may contain large errors (outliers), so updating the parameters from a single sample at a time makes each step inaccurate; computing every update from only one sample is also computationally inefficient. Mini-batch updates address both problems by averaging the gradient over a small group of samples in each step.\n", "\n", "\n", "* [Variants of gradient descent](https://blog.csdn.net/u010402786/article/details/51188876)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. How to fit a polynomial function?\n", "\n", "Suppose we need to design a ballistic missile defense system that observes a missile's flight path and predicts its future trajectory so that the missile can be intercepted. From physics, the model is:\n", "$$\n", "y = at^2 + bt + c\n", "$$\n", "We need to solve for the three model parameters $a, b, c$.\n", "\n", "The loss function is defined as:\n", "$$\n", "L = \\sum_{i=1}^N (y_i - at_i^2 - bt_i - c)^2\n", "$$\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# true parameters of the trajectory model y = pa*t^2 + pb*t + pc\n", "pa = -20\n", "pb = 90\n", "pc = 800\n", "\n", "# sample times and noisy height observations (Gaussian noise, std 15)\n", "t = np.linspace(0, 10)\n", "y = pa*t**2 + pb*t + pc + np.random.randn(np.size(t))*15\n", "\n", "\n", "plt.plot(t, y)\n", "plt.xlabel(\"time\")\n", "plt.ylabel(\"height\")\n", "plt.savefig(\"fig-res-missle_taj.pdf\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.1 How to derive the update terms?\n", "\n", "Start from the loss function:\n", "$$\n", "L = \\sum_{i=1}^N (y_i - at_i^2 - bt_i - c)^2\n", "$$\n", "\n", "Taking the partial derivative with respect to each parameter gives:\n", "\\begin{eqnarray}\n", "\\frac{\\partial L}{\\partial a} & = & - 2\\sum_{i=1}^N (y_i - at_i^2 - bt_i -c) t_i^2 \\\\\n", "\\frac{\\partial L}{\\partial b} & = & - 2\\sum_{i=1}^N (y_i - at_i^2 - bt_i -c) t_i \\\\\n", "\\frac{\\partial L}{\\partial c} & = & - 2\\sum_{i=1}^N (y_i - at_i^2 - bt_i -c)\n", "\\end{eqnarray}\n", "\n", "Gradient descent moves each parameter against its gradient, for example $a \\leftarrow a - \\epsilon \\frac{\\partial L}{\\partial a}$, where $\\epsilon$ is the learning rate." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.2 Program" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0: loss = 2.47402e+07, a = -3.30002, b = 6.91887, c = 4.30695\n", "epoch 500: loss = 1.09272e+06, a = -29.2101, b = 226.526, c = 431.351\n", "epoch 1000: loss = 366835, a = -24.1553, b = 158.837, c = 599.169\n", "epoch 1500: loss = 135985, a = -21.2759, b = 120.282, c = 694.733\n", "epoch 2000: loss = 63728, a = -19.6362, b = 98.3262, c = 749.154\n", "epoch 2500: loss = 41779.5, a = -18.7024, b = 85.8231, c = 780.145\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "n_epoch = 3000 # number of epochs\n", "a, b, c = 1.0, 1.0, 1.0 # initial parameters\n", "epsilon = 0.0001 # learning rate\n", "\n", "N = np.size(t)\n", "\n", "for i in range(n_epoch):\n", "    # one pass over the data, updating the parameters after every sample;\n", "    # b and c are computed with the already-updated a (and b)\n", "    for j in range(N):\n", "        a = a + epsilon*2*(y[j] - a*t[j]**2 - b*t[j] - c)*t[j]**2\n", "        b = b + epsilon*2*(y[j] - a*t[j]**2 - b*t[j] - c)*t[j]\n", "        c = c + epsilon*2*(y[j] - a*t[j]**2 - b*t[j] - c)\n", "\n", "    # total squared error over all samples\n", "    L = 0\n", "    for j in range(N):\n", "        L = L + (y[j] - a*t[j]**2 - b*t[j] - c)**2\n", "\n", "    if i % 500 == 0:\n", "        print(\"epoch %4d: loss = %10g, a = %10g, b = %10g, c = %10g\" % (i, L, a, b, c))\n", "\n", "\n", "# evaluate the fitted model at the sample times\n", "y_est = a*t**2 + b*t + c\n", "\n", "\n", "plt.plot(t, y, 'r-', label='Real data')\n", "plt.plot(t, y_est, 'g-x', label='Estimated data')\n", "plt.legend()\n", "plt.savefig(\"fig-res-missle_est.pdf\")\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. How to solve a linear problem with sklearn?\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X: (100, 1)\n", "Y: (100, 1)\n", "a = 2.773500, b = 5.311307\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "from sklearn import linear_model\n", "import numpy as np\n", "\n", "# generate synthetic data: Y = 3*X + 4 plus Gaussian noise (std 8)\n", "data_num = 100\n", "X = np.random.rand(data_num, 1)*10\n", "Y = X * 3 + 4 + 8*np.random.randn(data_num,1)\n", "\n", "print(\"X: \", X.shape)\n", "print(\"Y: \", Y.shape)\n", "\n", "# create and fit the linear regression model\n", "regr = linear_model.LinearRegression()\n", "regr.fit(X, Y)\n", "\n", "# extract the slope and intercept as scalars\n", "a, b = np.squeeze(regr.coef_), np.squeeze(regr.intercept_)\n", "\n", "print(\"a = %f, b = %f\" % (a, b))\n", "\n", "# end points of the fitted line\n", "x_min = np.min(X)\n", "x_max = np.max(X)\n", "y_min = a * x_min + b\n", "y_max = a * x_max + b\n", "\n", "plt.scatter(X, Y)\n", "plt.plot([x_min, x_max], [y_min, y_max], 'r')\n", "plt.xlabel(\"X\")\n", "plt.ylabel(\"Y\")\n", "plt.savefig(\"fig-res-sklearn_linear_fitting.pdf\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. How to fit a polynomial function with sklearn?"
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([800., 90., -20.])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fitting polynomial functions\n", "\n", "import numpy as np\n", "from sklearn.preprocessing import PolynomialFeatures\n", "from sklearn.linear_model import LinearRegression\n", "from sklearn.pipeline import Pipeline\n", "\n", "t = np.array([2, 4, 6, 8])\n", "\n", "# true parameters of y = pa*t^2 + pb*t + pc\n", "pa = -20\n", "pb = 90\n", "pc = 800\n", "\n", "y = pa*t**2 + pb*t + pc\n", "\n", "# expand t into the features [1, t, t^2], then fit a linear model without a separate intercept\n", "model = Pipeline([('poly', PolynomialFeatures(degree=2)),\n", "                  ('linear', LinearRegression(fit_intercept=False))])\n", "model = model.fit(t[:, np.newaxis], y)\n", "\n", "# coefficients are ordered by increasing degree: [pc, pb, pa]\n", "model.named_steps['linear'].coef_\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "* [Gradient descent](https://blog.csdn.net/u010402786/article/details/51188876)\n", "* [How to understand the least squares method?](https://blog.csdn.net/ccnt_2012/article/details/81127117)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 2 }