|
|
@@ -31,7 +31,7 @@ |
|
|
|
"source": [ |
|
|
|
"### 1.1 示例\n", |
|
|
|
"\n", |
|
|
|
"假设我们有下面的一些观测数据,我们希望找到他们内在的规律。" |
|
|
|
"假设我们有下面的一些观测数据,希望找到它们内在的规律。" |
|
|
|
] |
|
|
|
}, |
|
|
|
{ |
|
|
@@ -82,7 +82,7 @@ |
|
|
|
"$$\n", |
|
|
|
"其中$\\mathbf{X}$为自变量,$\\mathbf{Y}$为因变量。\n", |
|
|
|
"\n", |
|
|
|
"我们希望找到一个模型能够解释这些数据,假设使用最简单的线性模型来拟合数据:\n", |
|
|
|
"希望找到一个模型能够解释这些数据,假设使用最简单的线性模型来拟合数据:\n", |
|
|
|
"$$\n", |
|
|
|
"y = ax + b\n", |
|
|
|
"$$\n", |
|
|
@@ -190,11 +190,10 @@ |
|
|
|
"梯度下降法有很多优点,其中最主要的优点是,**在梯度下降法的求解过程中只需求解损失函数的一阶导数,计算的代价比较小,这使得梯度下降法能在很多大规模数据集上得到应用。**\n", |
|
|
|
"\n", |
|
|
|
"梯度下降法的含义是通过当前点的梯度方向寻找到新的迭代点。梯度下降法的基本思想可以类比为一个下山的过程。假设这样一个场景:\n", |
|
|
|
"* 一个人被困在山上,需要从山上下来(i.e. 找到山的最低点,也就是山谷)。\n", |
|
|
|
"* 一个人被困在山上,需要从山上下来,找到山的最低点,也就是山谷;\n", |
|
|
|
"* 但此时山上的浓雾很大,导致可视度很低。因此,下山的路径就无法全部确定,他必须利用自己周围的信息去找到下山的路径。\n", |
|
|
|
"* 这个时候,他就可以利用梯度下降算法来帮助自己下山。\n", |
|
|
|
" - 具体来说就是,以他当前的所处的位置为基准,寻找这个位置最陡峭的地方,然后朝着山的高度下降的地方走\n", |
|
|
|
" - 然后每走一段距离,都反复采用同一个方法,最后就能成功的抵达山谷。\n", |
|
|
|
"* 以他当前的所处的位置为基准,寻找这个位置最陡峭的地方,然后朝着山的高度下降的地方走\n", |
|
|
|
"* 每走一段距离,都反复采用同一个方法,最后就能成功的抵达山谷。\n", |
|
|
|
"\n", |
|
|
|
"\n", |
|
|
|
"一般情况下,这座山最陡峭的地方是无法通过肉眼立马观察出来的,而是需要一个工具来测量;同时,这个人此时正好拥有测量出最陡峭方向的能力。所以,此人每走一段距离,都需要一段时间来测量所在位置最陡峭的方向,这是比较耗时的。那么为了在太阳下山之前到达山底,就要尽可能的减少测量方向的次数。这是一个两难的选择,如果测量的频繁,可以保证下山的方向是绝对正确的,但又非常耗时;如果测量的过少,又有偏离轨道的风险。所以需要找到一个合适的测量方向的频率,来确保下山的方向不错误,同时又不至于耗时太多!\n", |
|
|
@@ -209,7 +208,7 @@ |
|
|
|
"L = \\sum_{i=1}^{N} (y_i - a x_i - b)^2\n", |
|
|
|
"$$\n", |
|
|
|
"\n", |
|
|
|
"我们更新的策略是:\n", |
|
|
|
"更新的策略是:\n", |
|
|
|
"$$\n", |
|
|
|
"\\theta^1 = \\theta^0 - \\eta \\triangledown L(\\theta)\n", |
|
|
|
"$$\n", |
|
|
@@ -217,7 +216,7 @@ |
|
|
|
"\n", |
|
|
|
"此公式的意义是:$L$是关于$\\theta$的一个函数,我们当前所处的位置为$\\theta_0$点,要从这个点走到L的最小值点,也就是山底。首先我们先确定前进的方向,也就是梯度的反向,然后走一段距离的步长,也就是$\\eta$,走完这个段步长,就到达了$\\theta_1$这个点!\n", |
|
|
|
"\n", |
|
|
|
"更新的策略是:\n", |
|
|
|
"最终的更新方程是:\n", |
|
|
|
"\n", |
|
|
|
"$$\n", |
|
|
|
"a^1 = a^0 + 2 \\eta [ y - (ax+b)]*x \\\\\n", |
|
|
@@ -1410,7 +1409,7 @@ |
|
|
|
"plt.plot(t, y)\n", |
|
|
|
"plt.xlabel(\"time\")\n", |
|
|
|
"plt.ylabel(\"height\")\n", |
|
|
|
"plt.savefig(\"missle_taj.pdf\")\n", |
|
|
|
"plt.savefig(\"fig-res-missle_taj.pdf\")\n", |
|
|
|
"plt.show()" |
|
|
|
] |
|
|
|
}, |
|
|
@@ -1493,7 +1492,7 @@ |
|
|
|
"plt.plot(t, y, 'r-', label='Real data')\n", |
|
|
|
"plt.plot(t, y_est, 'g-x', label='Estimated data')\n", |
|
|
|
"plt.legend()\n", |
|
|
|
"plt.savefig(\"missle_est.pdf\")\n", |
|
|
|
"plt.savefig(\"fig-res-missle_est.pdf\")\n", |
|
|
|
"plt.show()\n" |
|
|
|
] |
|
|
|
}, |
|
|
@@ -1562,7 +1561,7 @@ |
|
|
|
"plt.plot([x_min, x_max], [y_min, y_max], 'r')\n", |
|
|
|
"plt.xlabel(\"X\")\n", |
|
|
|
"plt.ylabel(\"Y\")\n", |
|
|
|
"plt.savefig(\"sklearn_linear_fitting.pdf\")\n", |
|
|
|
"plt.savefig(\"fig-res-sklearn_linear_fitting.pdf\")\n", |
|
|
|
"plt.show()" |
|
|
|
] |
|
|
|
}, |
|
|
@@ -1636,7 +1635,7 @@ |
|
|
|
"name": "python", |
|
|
|
"nbconvert_exporter": "python", |
|
|
|
"pygments_lexer": "ipython3", |
|
|
|
"version": "3.5.4" |
|
|
|
"version": "3.7.9" |
|
|
|
} |
|
|
|
}, |
|
|
|
"nbformat": 4, |
|
|
|