|
|
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
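- "# Linear Models and Gradient Descent\n",
- "This is the first lesson on neural networks. We study a very simple model, linear regression, together with an optimization algorithm, gradient descent, which we use to fit the model. Linear regression is one of the simplest models in supervised learning, and gradient descent is the most widely used optimization algorithm in deep learning, so this is where our deep learning journey begins."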
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
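- "## Univariate Linear Regression\n",
- "The univariate linear model is very simple. Suppose we have inputs $x_i$ and targets $y_i$, where each $i$ indexes one data point. We want to build the model\n",
- "\n",
- "$$\n",
- "\\hat{y}_i = w x_i + b\n",
- "$$\n",
- "\n",
- "where $\\hat{y}_i$ is the prediction. We want $\\hat{y}_i$ to fit the target $y_i$; in other words, we look for the $w$ and $b$ that make the error as small as possible, that is, we minimize\n",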
- "\n",
- "$$\n",
- "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
- "$$"
- ]
- },
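- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As a rough illustration of this objective, the sketch below evaluates the mean squared error of two candidate lines in plain Python. The helper names `predict` and `mse` and the toy data are assumptions made for this example and are not part of the notebook's own code.\n",
- "\n",
- "```python\n",
- "def predict(xs, w, b):\n",
- "    # Linear model: y_hat_i = w * x_i + b for every input x_i\n",
- "    return [w * x + b for x in xs]\n",
- "\n",
- "def mse(y_hat, ys):\n",
- "    # Mean squared error: (1/n) * sum((y_hat_i - y_i)^2)\n",
- "    n = len(ys)\n",
- "    return sum((p - y) ** 2 for p, y in zip(y_hat, ys)) / n\n",
- "\n",
- "# Toy data generated by y = 2x + 1 (an assumption for illustration)\n",
- "xs = [0.0, 1.0, 2.0, 3.0]\n",
- "ys = [1.0, 3.0, 5.0, 7.0]\n",
- "\n",
- "perfect = mse(predict(xs, 2.0, 1.0), ys)  # the generating parameters\n",
- "worse = mse(predict(xs, 1.0, 0.0), ys)    # a poorer guess\n",
- "print(perfect, worse)\n",
- "```\n",
- "\n",
- "The generating parameters give zero error, while the poorer guess gives a larger one. Minimizing the error means searching for the pair $(w, b)$ with the smallest such value."
- ]
- },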
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
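- "So how do we minimize this error?\n",
- "\n",
- "This is where **gradient descent** comes in. It is the first optimization algorithm we meet: very simple, yet very powerful and used throughout deep learning. Let us start from a simple example to understand how it works."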
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Gradient Descent\n",
- "To understand gradient descent, we first need to be clear about what a gradient is, and then see how to use it to descend."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
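- "### Gradient\n",
- "For a function of one variable, the gradient is simply the derivative. For a function of several variables, the gradient is the vector of its partial derivatives. For example, for a function $f(x, y)$ the gradient is\n",
- "\n",
- "$$\n",
- "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n",
- "$$\n",
- "\n",
- "which is written grad $f(x, y)$ or $\\nabla f(x, y)$. The gradient at a particular point $(x_0,\\ y_0)$ is $\\nabla f(x_0,\\ y_0)$.\n",
- "\n",
- "[Figure: the gradient of $f(x) = x^2$ at $x = 1$]"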
- ]
- },
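- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To make the definition concrete, the following sketch approximates both partial derivatives with central finite differences. The example function $f(x, y) = x^2 + 3y$, the helper `numerical_gradient`, and the step size `h` are all assumptions chosen for illustration.\n",
- "\n",
- "```python\n",
- "def f(x, y):\n",
- "    # Example function (an assumption): f(x, y) = x^2 + 3y\n",
- "    return x ** 2 + 3 * y\n",
- "\n",
- "def numerical_gradient(func, x, y, h=1e-6):\n",
- "    # Central differences approximate each partial derivative\n",
- "    df_dx = (func(x + h, y) - func(x - h, y)) / (2 * h)\n",
- "    df_dy = (func(x, y + h) - func(x, y - h)) / (2 * h)\n",
- "    return df_dx, df_dy\n",
- "\n",
- "# Gradient at the point (1, 2); analytically it is (2x, 3) = (2, 3)\n",
- "gx, gy = numerical_gradient(f, 1.0, 2.0)\n",
- "print(gx, gy)\n",
- "```\n",
- "\n",
- "The printed values should agree with the analytical gradient up to a small numerical error."
- ]
- },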
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
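- "What does the gradient mean? Geometrically, the gradient at a point gives the direction in which the function changes fastest. More precisely, for a function $f(x, y)$ at the point $(x_0, y_0)$, the function increases fastest along the direction of $\\nabla f(x_0,\\ y_0)$. So moving along the gradient is the quickest way uphill toward a (local) maximum, and moving against the gradient is the quickest way downhill toward a (local) minimum."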
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
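- "### The Gradient Descent Method\n",
- "With this understanding of the gradient, we can see how gradient descent works. We want to minimize the error above, that is, find the point where the error is smallest, and we can get there by moving against the gradient.\n",
- "\n",
- "Here is an intuitive picture. Suppose we are standing somewhere on a large mountain and do not know the way down, so we decide to take it one step at a time. At each position we compute the gradient, take one step in the negative gradient direction (the steepest way down from where we stand), and then repeat from the new position. We keep going step by step until we believe we have reached the foot of the mountain. Of course, we may not end up at the foot at all, but in some local valley instead.\n",
- "\n",
- "In our problem, this means repeatedly changing $w$ and $b$ along the negative gradient direction until we find the $w$ and $b$ that make the error smallest.\n",
- "\n",
- "At each update we must decide how far to move, just as we must choose the length of each step when walking downhill. This step length is called the learning rate, denoted $\\eta$. The learning rate matters a great deal, and different values lead to different results: if it is too small, descent is very slow; if it is too large, the updates jump back and forth.\n",
- "\n",
- "[Figure: two gradient descent runs, the upper one with a suitable learning rate and the lower one with a learning rate that is too large]\n",
- "\n",
- "In the upper example the learning rate is suitable, while in the lower one it is too large, so the iterates keep jumping around.\n",
- "\n",
- "The update rule is therefore\n",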
- "\n",
- "$$\n",
- "w := w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n",
- "b := b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n",
- "$$\n",
- "\n",
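- "By iterating these updates with a suitable learning rate, we eventually arrive at the best $w$ and $b$. That is the principle of gradient descent.\n",
- "\n",
- "[Figure: illustration of the gradient descent procedure]"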
- ]
- },
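- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The effect of the learning rate can be seen in a small experiment. The following sketch applies the update rule to $f(w) = w^2$, whose derivative is $2w$, with two learning rates. The function `gradient_descent`, the starting point, and the two rates are assumptions chosen for illustration.\n",
- "\n",
- "```python\n",
- "def gradient_descent(lr, steps=20, w0=5.0):\n",
- "    # Minimize f(w) = w^2, whose derivative is f'(w) = 2w\n",
- "    w = w0\n",
- "    history = [w]\n",
- "    for _ in range(steps):\n",
- "        grad = 2 * w\n",
- "        w = w - lr * grad  # w := w - eta * df/dw\n",
- "        history.append(w)\n",
- "    return history\n",
- "\n",
- "small = gradient_descent(lr=0.1)  # shrinks steadily toward the minimum at 0\n",
- "large = gradient_descent(lr=1.1)  # overshoots: the sign flips and |w| grows\n",
- "print(abs(small[-1]), abs(large[-1]))\n",
- "```\n",
- "\n",
- "With the smaller rate each step multiplies $w$ by 0.8, so the iterates approach the minimum. With the larger rate each step multiplies $w$ by -1.2, so the iterates jump across the minimum and move further away."
- ]
- },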
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "That covers the theory. Next we work through an example to learn more about the linear model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<torch._C.Generator at 0x7f53f805ecd0>"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import torch\n",
- "import numpy as np\n",
- "from torch.autograd import Variable\n",
- "\n",
- "torch.manual_seed(2017)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load the data x and y\n",
- "x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],\n",
- " [9.779], [6.182], [7.59], [2.167], [7.042],\n",
- " [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)\n",
- "\n",
- "y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],\n",
- " [3.366], [2.596], [2.53], [1.221], [2.827],\n",
- " [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[<matplotlib.lines.Line2D at 0x7f539410ee48>]"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAD8NJREFUeJzt3X+IHOd9x/HP5yRR++IQt9WRqJLuttCQkpjaShfXrqEYuwY3NXagLrhsXaekHIS0sYuh1DlwSeBKCsX9EUPMYqdR2sVNkE2qmritSAyJoVFYqbJsSYYYqjvLVaqzXct2N3Wr6Ns/ZoVOm7vs7N3uzuwz7xcss/Pco90vy95Hz81+Z9YRIQBAWqaKLgAAMHyEOwAkiHAHgAQR7gCQIMIdABJEuANAggh3AEgQ4Q4ACSLcASBBW4t64u3bt0etVivq6QFgIh06dOjViJjpN6+wcK/Vamq320U9PQBMJNtLeeZxWAYAEkS4A0CC+oa77ctsf9f2c7aP2f7MGnM+ZnvF9pHu7fdGUy4AII88x9zfkXRTRLxte5ukZ20/HRHf6Zn3lYj4/eGXCAAYVN9wj+yC7293d7d1b1wEHgBKLNcxd9tbbB+RdEbSgYg4uMa037B91PY+27vXeZx5223b7ZWVlU2UDQCTp9WSajVpairbtlqje65c4R4RP4yIayTtknSt7at6pvyjpFpE/IKkA5L2rvM4zYioR0R9ZqZvmyYAJKPVkubnpaUlKSLbzs+PLuAH6paJiDckPSPp1p7x1yLine7uo5J+cTjlAUAaFhakTufSsU4nGx+FPN0yM7av7N6/XNItkl7smbNj1e7tkk4Ms0gAmHTLy4ONb1aebpkdkvba3qLsP4OvRsRTtj8rqR0R+yV9yvbtks5Jel3Sx0ZTLgBMptnZ7FDMWuOjkKdb5qikPWuMP7jq/gOSHhhuaQCQjsXF7Bj76kMz09PZ+ChwhioAjEGjITWb0tycZGfbZjMbH4XCLhwGAFXTaIwuzHuxcgeABBHuAJI1zpOGyobDMgCSdOGkoQsfYF44aUga36GRIrFyB5CkcZ80VDaEO4AkjfukobIh3AEkab2Tg0Z10lDZEO4AkrS4mJ0ktNooTxoqG8IdQJLGfdJQ2dAtAyBZ4zxpqGxYuQNAggh3AEgQ4Q4ACSLcASBBhDsAJIhwB4AEEe4AkCDCHRhAlS8hi8nCSUxATlW/hCwmCyt3IKeqX0IWk4VwB3Kq+iVkMVkIdyCnql9CFpOFcAdyqvolZDFZCHcgp6pfQhaThW4ZYABVvoQsJgsrdwBIEOEOAAki3AEgQYQ7ACSIcAeABBHuAJAgwh0AEkS4A0CCCHcASFDfcLd9me3v2n7O9jHbn1ljzk/Y/ortl2wftF0bRbEAgHzyrNzfkXRTRFwt6RpJt9q+rmfOxyX9V0T8nKS/kPRnwy0TADCIvuEembe7u9u6t+iZdoekvd37+yTdbNtDqxIAMJBcx9xtb7F9RNIZSQci4mDPlJ2SXpakiDgn6aykn17jceZtt223V1ZWNlc5AGBducI9In4YEddI2iXpWttXbeTJIqIZEfWIqM/MzGzkIQAAOQzULRMRb0h6RtKtPT96RdJuSbK9VdJ7JL02jAIBAIPL0y0zY/vK7v3LJd0i6cWeafsl3dO9f6ekb0ZE73F5AMCY5Pmyjh2S9treouw/g69GxFO2PyupHRH7JT0m6W9tvyTpdUl3jaxiAEBffcM9Io5K2rPG+IOr7v+PpN8cbmkAgI3iDFUgca2WVKtJU1PZttUquiKMA9+hCiSs1ZLm56VOJ9tfWsr2Jb4LNnWs3IGELSxcDPYLOp1sHGkj3IGELS8PNo50EO5AwmZnBxtHOgh3IGGLi9L09KVj09PZONJGuAMjUoYulUZDajaluTnJzrbNJh+mVgHdMsAIlKlLpdEgzKuIlTswAnSpoGiEOzACdKmgaIQ7MAJ0
qaBohDswAnSpoGiEe0WUoXOjSuhSQdHolqmAMnVuVAldKigSK/cKoHMDqB7CvQLo3ACqh3CvADo3gOoh3CuAzg2gegj3CqBzA6geumUqgs4NoFpYuQNAggh3AEgQ4Q4ACSLcASBBhDsAJIhwB4AEEe4AkCDCHcnjcseoIk5iQtK43DGqipU7ksbljlFVhDuSxuWOUVWEO5LG5Y5RVYQ7ksbljlFVhDuSltLljun6wSDolkHyUrjcMV0/GFTflbvt3bafsX3c9jHb964x50bbZ20f6d4eHE25QDXR9YNB5Vm5n5N0f0Qctv1uSYdsH4iI4z3zvh0Rtw2/RAB0/WBQfVfuEXE6Ig53778l6YSknaMuDMBFdP1gUAN9oGq7JmmPpINr/Ph628/Zftr2h9b59/O227bbKysrAxcLVBVdPxhU7nC3fYWkJyTdFxFv9vz4sKS5iLha0uclfW2tx4iIZkTUI6I+MzOz0ZqBykmp6wfj4YjoP8neJukpSf8cEQ/lmH9SUj0iXl1vTr1ej3a7PUCpAADbhyKi3m9enm4ZS3pM0on1gt32+7rzZPva7uO+NljJAIBhydMtc4OkuyU9b/tId+zTkmYlKSIekXSnpE/YPifpB5Luijx/EgAARqJvuEfEs5LcZ87Dkh4eVlEAgM3h8gMAkCDCHQASRLgDQIIIdwBIEOEOAAki3AEgQYQ7ACSIcAeABBHuAJAgwh0AEkS4A0CCCHcASBDhDgAJItwBIEGEOwAkiHAHgAQR7gCQIMIdABJEuANAggh3AEgQ4Q4ACSLcASBBhDsAJIhwB4AEEe4AkCDCHQASRLgDQIIIdxSu1ZJqNWlqKtu2WkVXBEy+rUUXgGprtaT5eanTyfaXlrJ9SWo0iqsLmHSs3FGohYWLwX5Bp5ONA9g4wh2FWl4ebBxAPoQ7CjU7O9g4gHwIdxRqcVGanr50bHo6GwewcYQ7CtVoSM2mNDcn2dm22eTDVGCz6JZB4RoNwhwYtr4rd9u7bT9j+7jtY7bvXWOObf+17ZdsH7X94dGUCwDII8/K/Zyk+yPisO13Szpk+0BEHF8159ckvb97+yVJX+huAQAF6Ltyj4jTEXG4e/8tSSck7eyZdoekL0fmO5KutL1j6NUCAHIZ6ANV2zVJeyQd7PnRTkkvr9o/pR/9D0C25223bbdXVlYGqxQAkFvucLd9haQnJN0XEW9u5MkiohkR9Yioz8zMbOQhAAA55Ap329uUBXsrIp5cY8orknav2t/VHQMAFCBPt4wlPSbpREQ8tM60/ZJ+p9s1c52ksxFxeoh1AgAGkKdb5gZJd0t63vaR7tinJc1KUkQ8Iunrkj4i6SVJHUm/O/xSAQB59Q33iHhWkvvMCUmfHFZRAIDN4fIDAJAgwh0AEkS4A0CCCHcASBDhDgAJItwBIEGEOwAkiHAHgAQR7gCQIMIdABJEuANAggh3AEgQ4Q4ACSLcASBBhDsAJIhwB4AEEe4AkCDCHQASRLgPUasl1WrS1FS2bbWKrgjjxnsAZZHnC7KRQ6slzc9LnU62v7SU7UtSo1FcXRgf3gMoE2ffbT1+9Xo92u12Ic89CrVa9svca25OOnly3NWgCLwHMA62D0VEvd88DssMyfLyYONID+8BlAnhPiSzs4ONV00VjkXzHkCZEO5DsrgoTU9fOjY9nY1X3YVj0UtLUsTFY9GpBTzvAZQJ4T4kjYbUbGbHV+1s22zyQZokLSxc/JDxgk4nG08J7wGUCR+oYuSmprIVey9bOn9+/PUAk4wPVFEaHIsGxo9wx8hxLBoYP8IdI8exaGD8CPdElL3VsNHITuQ5fz7bEuzAaHH5gQRw2juAXqzcE1CVVkMA+RHuCeC0dwC9CPcE0GoIoBfhngBaDQH06hvutr9o+4ztF9b5+Y22z9o+0r09OPwy8ePQagigV55umS9JeljSl3/MnG9HxG1DqQgb0mgQ5gAu6rtyj4hvSXp9DLUAAIZkWMfcr7f9nO2nbX9o
vUm25223bbdXVlaG9NQAgF7DCPfDkuYi4mpJn5f0tfUmRkQzIuoRUZ+ZmRnCUwMA1rLpcI+INyPi7e79r0vaZnv7pisDAGzYpsPd9vtsu3v/2u5jvrbZxwUAbFzfbhnbj0u6UdJ226ck/YmkbZIUEY9IulPSJ2yfk/QDSXdFUd8AAgCQlCPcI+K3+vz8YWWtkgCAkuAMVQBIEOEOAAki3AEgQYQ7ACSIcAeABBHuAJAgwh0AEkS4A0CCCHcASBDhPqBWS6rVpKmpbNtqFV0RAPyoPN/EhK5WS5qflzqdbH9pKduX+BYkAOXCyn0ACwsXg/2CTicbB4AyIdwHsLw82DgAFIVwH8Ds7GDjAFAUwn0Ai4vS9PSlY9PT2TgAlAnhPoBGQ2o2pbk5yc62zSYfpgIon4kK9zK0ITYa0smT0vnz2ZZgB1BGE9MKSRsiAOQ3MSt32hABIL+JCXfaEAEgv4kJd9oQASC/iQl32hABIL+JCXfaEAEgv4nplpGyICfMAaC/iVm5AwDyI9wBIEGEOwAkiHAHgAQR7gCQIEdEMU9sr0hayjF1u6RXR1zOJOJ1WR+vzdp4XdY3Sa/NXETM9JtUWLjnZbsdEfWi6ygbXpf18dqsjddlfSm+NhyWAYAEEe4AkKBJCPdm0QWUFK/L+nht1sbrsr7kXpvSH3MHAAxuElbuAIABlTLcbe+2/Yzt47aP2b636JrKxPYW2/9m+6miaykT21fa3mf7RdsnbF9fdE1lYfsPu79LL9h+3PZlRddUFNtftH3G9gurxn7K9gHb3+tuf7LIGoehlOEu6Zyk+yPig5Kuk/RJ2x8suKYyuVfSiaKLKKG/kvRPEfHzkq4Wr5EkyfZOSZ+SVI+IqyRtkXRXsVUV6kuSbu0Z+2NJ34iI90v6Rnd/opUy3CPidEQc7t5/S9kv6c5iqyoH27sk/bqkR4uupUxsv0fSr0h6TJIi4n8j4o1iqyqVrZIut71V0rSk/yi4nsJExLckvd4zfIekvd37eyV9dKxFjUApw3012zVJeyQdLLaS0vhLSX8k6XzRhZTMz0pakfQ33UNWj9p+V9FFlUFEvCLpzyUtSzot6WxE/EuxVZXOeyPidPf+9yW9t8hihqHU4W77CklPSLovIt4sup6i2b5N0pmIOFR0LSW0VdKHJX0hIvZI+m8l8Kf1MHSPH9+h7D/An5H0Ltu/XWxV5RVZC+HEtxGWNtxtb1MW7K2IeLLoekriBkm32z4p6e8l3WT774otqTROSToVERf+wtunLOwh/aqkf4+IlYj4P0lPSvrlgmsqm/+0vUOSutszBdezaaUMd9tWduz0REQ8VHQ9ZRERD0TEroioKftA7JsRwQpMUkR8X9LLtj/QHbpZ0vECSyqTZUnX2Z7u/m7dLD5s7rVf0j3d+/dI+ocCaxmKUoa7shXq3cpWpke6t48UXRRK7w8ktWwflXSNpD8tuJ5S6P41s0/SYUnPK/u9T+6MzLxsPy7pXyV9wPYp2x+X9DlJt9j+nrK/dD5XZI3DwBmqAJCgsq7cAQCbQLgDQIIIdwBIEOEOAAki3AEgQYQ7ACSIcAeABBHuAJCg/weHsaZQFbgrMwAAAABJRU5ErkJggg==\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the data\n",
- "import matplotlib.pyplot as plt\n",
- "%matplotlib inline\n",
- "\n",
- "plt.plot(x_train, y_train, 'bo')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Convert to Tensors\n",
- "x_train = torch.from_numpy(x_train)\n",
- "y_train = torch.from_numpy(y_train)\n",
- "\n",
- "# Define the parameters w and b\n",
- "w = Variable(torch.randn(1), requires_grad=True) # random initialization\n",
- "b = Variable(torch.zeros(1), requires_grad=True) # initialized to 0"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Build the linear regression model\n",
- "x_train = Variable(x_train)\n",
- "y_train = Variable(y_train)\n",
- "\n",
- "def linear_model(x):\n",
- " return x * w + b"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "y_ = linear_model(x_train)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "With the steps above, the model is defined. Before updating any parameters, let us look at what the model currently outputs."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f53940f2320>"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
- "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Question: the red points are the predictions and appear to lie along a straight line. Do they in fact all lie on one line?**"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now we need to compute the error function, namely\n",
- "\n",
- "$$\n",
- "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
- "$$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Compute the error\n",
- "def get_loss(y_, y):\n",
- " return torch.mean((y_ - y) ** 2)\n",
- "\n",
- "loss = get_loss(y_, y_train)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(153.3520, grad_fn=<MeanBackward1>)\n"
- ]
- }
- ],
- "source": [
- "# Print the loss to see how large it is\n",
- "print(loss)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
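- "With the error function defined, we next need the gradients with respect to $w$ and $b$. Thanks to PyTorch's automatic differentiation, we do not have to compute them by hand. Interested readers can derive them manually: writing $f(w, b)$ for the error, the gradients with respect to $w$ and $b$ are\n",
- "\n",
- "$$\n",
- "\\frac{\\partial f}{\\partial w} = \\frac{2}{n} \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n",
- "\\frac{\\partial f}{\\partial b} = \\frac{2}{n} \\sum_{i=1}^n (w x_i + b - y_i)\n",
- "$$"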
- ]
- },
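- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "These gradient formulas can be checked numerically. The sketch below, in plain Python, compares them with central finite differences of the error. The function names `loss` and `analytic_grads`, the toy data, and the step size `h` are assumptions for illustration only and do not touch the notebook's own variables.\n",
- "\n",
- "```python\n",
- "def loss(w, b, xs, ys):\n",
- "    # Mean squared error of the line y_hat = w * x + b\n",
- "    n = len(xs)\n",
- "    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / n\n",
- "\n",
- "def analytic_grads(w, b, xs, ys):\n",
- "    # Closed-form partial derivatives of the mean squared error\n",
- "    n = len(xs)\n",
- "    dw = 2.0 / n * sum(x * (w * x + b - y) for x, y in zip(xs, ys))\n",
- "    db = 2.0 / n * sum((w * x + b - y) for x, y in zip(xs, ys))\n",
- "    return dw, db\n",
- "\n",
- "# Toy data and parameter values (assumptions for illustration)\n",
- "xs = [1.0, 2.0, 3.0]\n",
- "ys = [2.0, 4.0, 6.0]\n",
- "w, b, h = 0.5, 0.1, 1e-6\n",
- "\n",
- "dw, db = analytic_grads(w, b, xs, ys)\n",
- "# Central finite differences as an independent check\n",
- "dw_num = (loss(w + h, b, xs, ys) - loss(w - h, b, xs, ys)) / (2 * h)\n",
- "db_num = (loss(w, b + h, xs, ys) - loss(w, b - h, xs, ys)) / (2 * h)\n",
- "print(dw, dw_num, db, db_num)\n",
- "```\n",
- "\n",
- "The analytical and numerical values should agree closely. Automatic differentiation computes the same quantities without this manual work."
- ]
- },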
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Automatic differentiation\n",
- "loss.backward()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([161.0043])\n",
- "tensor([22.8730])\n"
- ]
- }
- ],
- "source": [
- "# Inspect the gradients of w and b\n",
- "print(w.grad)\n",
- "print(b.grad)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Update the parameters once\n",
- "w.data = w.data - 1e-2 * w.grad.data\n",
- "b.data = b.data - 1e-2 * b.grad.data"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "After updating the parameters, let us look at the model's output again."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f53940bc438>"
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAFfJJREFUeJzt3X9w3HWdx/HXO2mgtOXwbDMIliR15uxZUlrClinHiT1osVJUGM4ZO1Gpdxi1lqung4PTP47DIo5z08rgDJIrikBAj6Ie0+E8BKr4gwO2NXDYSuuVpAbQxlYrbVraJu/7YzehDZvud5P97vf72X0+ZjLJbje77+5OX/1839/P5/M1dxcAIBx1SRcAACgNwQ0AgSG4ASAwBDcABIbgBoDAENwAEBiCGwACQ3ADQGAIbgAIzKQ4nnTGjBne0tISx1MDQFXasmXLH9y9McpjYwnulpYWZbPZOJ4aAKqSmfVGfSytEgAITNHgNrPZZtZ93NefzeyzlSgOAPBmRVsl7v6ipPmSZGb1kl6W9P2Y6wIAjKHUHvdlkv7P3SP3YoYdPXpUfX19Onz4cKm/ijFMnjxZM2fOVENDQ9KlAKigUoP7w5IeKPQHZtYhqUOSmpqa3vTnfX19Ov3009XS0iIzK7VOjOLu2rt3r/r6+jRr1qykywFQQZFPTprZKZI+IOnBQn/u7p3unnH3TGPjm2e0HD58WNOnTye0y8TMNH36dI5ggDTo6pJaWqS6utz3rq5YX66UEff7JG1199+P98UI7fLi/QRSoKtL6uiQBgZyt3t7c7clqb09lpcsZTrgco3RJgGAmrVmzRuhPWxgIHd/TCIFt5lNlbRE0vdiqyTlVqxYoY0bNyZdBoC02b27tPvLIFJwu/tBd5/u7vtjq2SUOFtG7q6hoaHyPSGA2lVgMsZJ7y+DVK6cHG4Z9fZK7m+0jCYS3j09PZo9e7Y+9rGPqbW1Vffee68uuugitbW16UMf+pAOHDggSbr55pu1YMECtba2qqOjQ+5epr8VgKp0yy3SlCkn3jdlSu7+mKQyuONqGe3cuVMrV67UT37yE91111167LHHtHXrVmUyGa1bt06StGrVKj377LN64YUXdOjQIW3atGliLwqgurW3S52dUnOzZJb73tkZ24lJKaZNpiYqrpZRc3OzFi5cqE2bNmnbtm26+OKLJUlHjhzRRRddJEnavHmzvvrVr2pgYED79u3Tueeeq/e///0Te2EA1a29PdagHi2Vwd3UlGuPFLp/IqZOnSop1+NesmSJHnjgxEkyhw8f1sqVK5XNZnXOOefopptuYp40gNRJZask7pbRwoUL9fOf/1y/+c1vJEkHDx7Ujh07RkJ6xowZOnDgALNIAKRSKkfcw0cca9bk2iNNTbnQLteRSGNjo+6++24tX75cr7/+uiRp7dq1euc736lPfOITam1t1dve9jYtWLCgPC8IAGVkccyayGQyPvpCCtu3b9e73vWusr9WreN9BaqDmW1x90yUx6ayVQIAGBvBDQCBIbgBIDAENwAEhuAGgMAQ3AAQGIK7gLvvvluvvPLKyO3rrrtO27Ztm/Dz9vT06P777y/599hSFsDx0hvcFb4U0PFGB/eGDRs0Z86cCT/veIMbAI6XzuCOY19XSffdd58uvPBCzZ8/X5/85Cc1ODioFStWqLW1VXPnztX69eu1ceNGZbNZtbe3a/78+Tp06JAWLVqk4QVF06ZN0w033KBzzz1Xixcv1jPPPKNFixbpHe94hx5++GFJuYB+97vfrba2NrW1tekXv/iFJOnGG2/UT3/6U82fP1/r16/X4OCgbrjhBi1YsEDnnXee7rzzTkm5vVRWrVql2bNna/HixdqzZ8+E/t4Aqoy7l/3rggsu8NG2bdv2pvvG1NzsnovsE7+am6M/R4HXv/LKK/3IkSPu7v7pT3/ab7rpJl+8ePHIY/74xz+6u/t73vMe
f/bZZ0fuP/62JH/kkUfc3f2qq67yJUuW+JEjR7y7u9vnzZvn7u4HDx70Q4cOubv7jh07fPj92Lx5sy9btmzkee+8807/0pe+5O7uhw8f9gsuuMB37drlDz30kC9evNiPHTvmL7/8sp9xxhn+4IMPjvn3AhA+SVmPmLGp3Kskjn1dH3/8cW3ZsmVk/5FDhw5p6dKl2rVrl66//notW7ZMl19+edHnOeWUU7R06VJJ0ty5c3XqqaeqoaFBc+fOVU9PjyTp6NGjWrVqlbq7u1VfX68dO3YUfK5HH31Uzz///Ej/ev/+/dq5c6eefPJJLV++XPX19Tr77LN16aWXjvvvDaD6pLNVEsOlgNxd1157rbq7u9Xd3a0XX3xRt912m5577jktWrRI3/jGN3TdddcVfZ6GhoaRq6vX1dXp1FNPHfn52LFjkqT169frzDPP1HPPPadsNqsjR46MWdPtt98+UtNLL70U6T8PoOYleA4sDdIZ3DHs63rZZZdp48aNI/3iffv2qbe3V0NDQ7rmmmu0du1abd26VZJ0+umn67XXXhv3a+3fv19nnXWW6urqdO+992pwcLDg8773ve/VHXfcoaNHj0qSduzYoYMHD+qSSy7Rd7/7XQ0ODurVV1/V5s2bx10LUHViOgcWknS2SmLY13XOnDlau3atLr/8cg0NDamhoUHr1q3T1VdfPXLh4FtvvVVSbvrdpz71KZ122ml66qmnSn6tlStX6pprrtE999yjpUuXjlzA4bzzzlN9fb3mzZunFStWaPXq1erp6VFbW5vcXY2NjfrBD36gq6++Wk888YTmzJmjpqamkavzANDJr21YwavQJIltXQPH+4qaU1eXG2mPZiblB2EhYltXANUrhnNgoSG4AYQl7msbBqCiwR1HW6aW8X6iJrW3S52dUnNzrj3S3Jy7XSP9bSniyUkze4ukDZJaJbmkf3D3ks7aTZ48WXv37tX06dNHptNh/Nxde/fu1eTJk5MuBai89vaaCurRos4quU3SD939783sFElTiv3CaDNnzlRfX5/6+/tL/VWMYfLkyZo5c2bSZQCosKLBbWZnSLpE0gpJcvcjkgqvKDmJhoYGzZo1q9RfAwCMEqXHPUtSv6RvmdkvzWyDmU2NuS4AwBiiBPckSW2S7nD38yUdlHTj6AeZWYeZZc0sSzsEAOITJbj7JPW5+9P52xuVC/ITuHunu2fcPdPY2FjOGgEAxyka3O7+O0m/NbPZ+bsukzTxy8EAAMYl6qyS6yV15WeU7JL08fhKAgCcTKTgdvduSZHW0AMA4sWSdwAIDMENAIEhuAEgMAQ3AASG4AaAwBDcABAYghsAAkNwA0BgCG4ACAzBDQCBIbgBIDAENwAEhuAGgMAQ3AAQGIIbAAJDcANAYAhuAAgMwQ0AgSG4ASAwBDcABIbgBoDAENwAEBiCGwACQ3ADQGAIbgAIDMENhKyrS2ppkerqct+7upKuCBUwKcqDzKxH0muSBiUdc/dMnEUBiKCrS+rokAYGcrd7e3O3Jam9Pbm6ELtSRtx/5+7zCW0gJdaseSO0hw0M5O5HVaNVAoRq9+7S7kfViBrcLulRM9tiZh1xFgQgoqam0u5H1Yga3H/r7m2S3ifpM2Z2yegHmFmHmWXNLNvf31/WIgEUcMst0pQpJ943ZUruflS1SMHt7i/nv++R9H1JFxZ4TKe7Z9w909jYWN4qAbxZe7vU2Sk1N0tmue+dnZyYrAFFZ5WY2VRJde7+Wv7nyyXdHHtlAIprbyeoa1CU6YBnSvq+mQ0//n53/2GsVQEAxlQ0uN19l6R5FagFABAB0wGBpLH6ESWKtHISQExY/YhxYMQNJInVjxgHghtIEqsfMQ4EN5AkVj9iHAhuIEmsfsQ4ENxAklj9iHFgVgmQNFY/okSMuAEgMAQ3AASG4AaAwBDcABAYghsYD/YXQYKYVQKUiv1FkDBG3ECp2F8ECSO4gSiOb4309hZ+DPuLoEJolQDFjG6NjIX9
RVAhjLiBYgq1RkZjfxFUEMENFHOyFgj7iyABtEqAYpqaCve1m5ulnp6KlwMw4gakk8/LZutVpAzBDQyffOztldzfmJc9HN5svYqUMXcv+5NmMhnPZrNlf14gFi0ttEKQODPb4u6ZKI9lxA1w3UcEhuAGuO4jAkNwA5x8RGAiB7eZ1ZvZL81sU5wFARXHyUcEppR53KslbZf0FzHVAiSH6z4iIJFG3GY2U9IySRviLQcAUEzUVsnXJH1B0lCMtQAAIiga3GZ2paQ97r6lyOM6zCxrZtn+/v6yFQgAOFGUEffFkj5gZj2SviPpUjO7b/SD3L3T3TPunmlsbCxzmQCAYUWD292/6O4z3b1F0oclPeHuH4m9MqAYrvuIGsXugAgT131EDWOvEoSJ/UVQZdirBNWP/UVQwwhuhIn9RVDDCG6Eif1FUMMIboSJ/UVQw5hVgnCxvwhqFCNuAAgMwQ0AgSG4ASAwBDcABIbgBoDAENwAMEGV3u+M6YAAMAFJ7HfGiBsAJmDNmjdCe9jAQO7+uBDciA/7ZaMGJLHfGcGNeAwfP/b2Su5vHD8S3iiDNI0JktjvjOBGPJI4fkRNSNuYIIn9zghuxIP9shGTtI0JktjvjODGxBU6bmW/bMQkjWOC9vbchZeGhnLf4977jODGxIx13HrFFeyXjVgwJiC4MVFjHbc+8gj7ZSMWXEOD4MZEney4tdLHjxOQplkKODmuocHKSUxUU1Phq60HdNyaxMo3TEytX0ODETcmpgqOW9M2SwEohuDGxFTBcWsaZykAJ0OrBBMX+HFrFXR7UGMYcaPmVUG3BzWmaHCb2WQze8bMnjOzX5nZv1aiMKBSqqDbgxoTpVXyuqRL3f2AmTVI+pmZ/Ze7/0/MtQEVE3i3BzWmaHC7u0s6kL/ZkP/yOIsCAIwtUo/bzOrNrFvSHkk/cvenCzymw8yyZpbt7+8vd50AgLxIwe3ug+4+X9JMSReaWWuBx3S6e8bdM42NjeWuE0ABrPisTSXNKnH3P0naLGlpPOUAiCpt+1KjcqLMKmk0s7fkfz5N0hJJv467MCDN0jDSZcVn7Yoyq+QsSd82s3rlgv4/3H1TvGUB6ZWWvU1Y8Vm7LDdppLwymYxns9myPy+QBi0thVdaNjfnNkGstTpQHma2xd0zUR7LyskqkIbD9lqSlpEuKz5rF8EdOE5QVV5arsDCis/aRaskcBwuV97oHreUG+kSmpgIWiU1ZPduabm69JJaNKg6vaQWLVcXJ6hixEgXSWNb18CtemuXbt3boanKDf9a1Kt/V4dmvFWSSJK4sLcJksSIO3Bf1pqR0B42VQP6sqp/Mi8nZVGrGHEHbtq+wj2Rse6vFmmZSw0kgRF36NIyxaHCWDWIWkZwp12xfkCNTuZNy1zqcqDlg1IR3GkWZZJ2jU5xqJYDDebhYzyYx51mTNIeU7XMpeYjxjDmcVeLauoHlFm1HGjwEWM8mFWSZk1NhYdjofUDYlINc6n5iDEejLjTrEZPPNYSPmKMB8GdZtXSD8CY+IgxHpycBIAU4OQkAFQxghsAAkNwA0BgCG4ACAzBDQCBIbgBIDAENwAEhuBGbNiuFIgHe5UgFlyhBohP0RG3mZ1jZpvNbJuZ/crMVleiMISNK9QA8Yky4j4m6fPuvtXMTpe0xcx+5O7bYq4NAWO7UiA+RUfc7v6qu2/N//yapO2S3h53YQhbtVyhBkijkk5OmlmLpPMlPR1HManGmbaSsF0pEJ/IwW1m0yQ9JOmz7v7nAn/eYWZZM8v29/eXs8bkcWHAkrFdKRCfSNu6mlmDpE2S/tvd1xV7fNVt68qFAQHErKzbupqZSbpL0vYooV2VONMGIEWitEoulvRRSZeaWXf+64qY60oXzrQBSJEos0p+5u7m7ue5+/z81yOVKK7ixjoByZk2ACnCyslhUZb6rVmTa480NeVCmzNtABLANSeHcQISQIK45uR4cAISQCAI7mFFTkCy/qa28fkjTQjuYSc5
Acn6m5Or9lDj80fa0OM+XldXwROQtL/HNvqcrpT7/66aVkny+aMSSulxE9wR1NXlRlqjmUlDQ5WvJ01qIdT4/FEJnJwssyTX36S9DVEL53RZf4W0IbgjSGr9TQi91VoINdZfIW0I7giS2ukuhKvI1EKosdMh0oYed4qF0lsd45wugBKU0uNmyXuKNTUVPvGXtjZEeztBDVQSrZIUq4U2BIDSEdwpRm8VQCG0SlKONgSA0RhxA0BgCO7jpH2xCwBItEpGRLmOAgCkASPuvBAWuwCAlKLgTrpNUQt7bgCoDqkI7q4u6bGPd+nHvS065nX6cW+LHvt4V0XDuxb23ABQHVIR3E+v7tLXj3aoRb2qk6tFvfr60Q49vbpyyc1iFwChSEVwf27vGk3ViQ3mqRrQ5/ZWrsHMYhcAoUjFJlNDVqc6vbmOIZnqPEW7KQFATIK7kMLA9MKN5LHuB4BalorgnnbbLTp2yokN5mOnTNG022gwA8BoRYPbzL5pZnvM7IXYqmhv16RvnthgnvRNGswAUEjRHreZXSLpgKR73L01ypNyIQUAKE1Ze9zu/qSkfROuCgBQFqnocQMAoitbcJtZh5llzSzb399frqcFAIxStuB29053z7h7prGxsVxPCwAYhVYJAAQmyqySByQtkjRD0u8l/Yu731Xkd/olFbg++QlmSPpD5EprB+/L2HhvxsZ7U1hI70uzu0dqV8Sy5D3SC5tlo059qSW8L2PjvRkb701h1fq+0CoBgMAQ3AAQmCSDuzPB104z3pex8d6MjfemsKp8XxLrcQMAxodWCQAEpqLBbWbnmNlmM9tmZr8ys9WVfP0QmFm9mf3SzDYlXUuamNlbzGyjmf3azLab2UVJ15QGZvbP+X9LL5jZA2Y2OemaklJoJ1Mze6uZ/cjMdua//2WSNZZLpUfcxyR93t3nSFoo6TNmNqfCNaTdaknbky4ihW6T9EN3/2tJ88R7JDN7u6R/kpTJ79xZL+nDyVaVqLslLR11342SHnf3v5L0eP528Coa3O7+qrtvzf/8mnL/+N5eyRrSzMxmSlomaUPStaSJmZ0h6RJJd0mSux9x9z8lW1VqTJJ0mplNkjRF0isJ15OYMXYy/aCkb+d//rakqypaVEwS63GbWYuk8yU9nVQNKfQ1SV+QxIU2TzRLUr+kb+XbSBvMbGrSRSXN3V+W9G+Sdkt6VdJ+d3802apS50x3fzX/8+8knZlkMeWSSHCb2TRJD0n6rLv/OYka0sbMrpS0x923JF1LCk2S1CbpDnc/X9JBVckh70Tk+7UfVO4/trMlTTWzjyRbVXp5bgpdVUyjq3hwm1mDcqHd5e7fq/Trp9jFkj5gZj2SviPpUjO7L9mSUqNPUp+7Dx+dbVQuyGvdYkkvuXu/ux+V9D1Jf5NwTWnzezM7S5Ly3/ckXE9ZVHpWiSnXp9zu7usq+dpp5+5fdPeZ7t6i3AmmJ9yd0ZMkd/+dpN+a2ez8XZdJ2pZgSWmxW9JCM5uS/7d1mThpO9rDkq7N/3ytpP9MsJayqfSI+2JJH1VuNNmd/7qiwjUgTNdL6jKz5yXNl/TlhOtJXP4IZKOkrZL+V7l/z1W5UjCK/E6mT0mabWZ9ZvaPkr4iaYmZ7VTuCOUrSdZYLqycBIDAsHISAAJDcANAYAhuAAgMwQ0AgSG4ASAwBDcABIbgBoDAENwAEJj/B3whq/8kFWOXAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "y_ = linear_model(x_train)\n",
- "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
- "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
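- "As the example shows, after one update the red predictions have dropped below the blue points and still do not fit the true values well, so we need several more updates."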
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "epoch: 0, loss: 3.135772228240967\n",
- "epoch: 1, loss: 0.355089008808136\n",
- "epoch: 2, loss: 0.30295446515083313\n",
- "epoch: 3, loss: 0.30131959915161133\n",
- "epoch: 4, loss: 0.3006228804588318\n",
- "epoch: 5, loss: 0.2999469041824341\n",
- "epoch: 6, loss: 0.299274742603302\n",
- "epoch: 7, loss: 0.2986060082912445\n",
- "epoch: 8, loss: 0.2979407012462616\n",
- "epoch: 9, loss: 0.29727888107299805\n"
- ]
- }
- ],
- "source": [
- "for e in range(10): # perform 10 updates\n",
- "    y_ = linear_model(x_train)\n",
- "    loss = get_loss(y_, y_train)\n",
- "    \n",
- "    w.grad.zero_() # remember to zero the gradient\n",
- "    b.grad.zero_() # remember to zero the gradient\n",
- "    loss.backward()\n",
- "    \n",
- "    w.data = w.data - 1e-2 * w.grad.data # update w\n",
- "    b.data = b.data - 1e-2 * b.grad.data # update b\n",
- "    print('epoch: {}, loss: {}'.format(e, loss.item()))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f53942d5550>"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "y_ = linear_model(x_train)\n",
- "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
- "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
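- "After 10 updates, the red predictions fit the blue true values fairly well.\n",
- "\n",
- "You have now learned your first machine learning model. Keep going and complete the short exercise below."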
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
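- "**Exercise:**\n",
- "\n",
- "Restart the notebook and rerun the linear regression model above, trying different numbers of training iterations and different learning rates to see how the results change."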
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Polynomial Regression"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
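- "Let us go one step further and look at polynomial regression. What is polynomial regression? It is very simple. Recall the linear regression model above\n",
- "\n",
- "$$\n",
- "\\hat{y} = w x + b\n",
- "$$\n",
- "\n",
- "This is a first-degree polynomial in $x$. The model is simple and cannot fit more complex data, so we can use a higher-degree model, for example\n",
- "\n",
- "$$\n",
- "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 + \\cdots\n",
- "$$\n",
- "\n",
- "This can fit more complex data and is called a polynomial model; it uses higher powers of $x$. Multivariate regression has the same form, except that besides $x$ it uses further variables such as $y$, $z$, and so on. The loss function of these models is the same as for simple linear regression."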
- ]
- },
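- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A polynomial model is still linear in its weights once the powers of $x$ are treated as separate input features. The sketch below shows this in plain Python. The helpers `poly_features` and `poly_predict` and the coefficient values are assumptions chosen for illustration.\n",
- "\n",
- "```python\n",
- "def poly_features(x, degree=3):\n",
- "    # Map a scalar x to the feature row [x, x^2, ..., x^degree]\n",
- "    return [x ** i for i in range(1, degree + 1)]\n",
- "\n",
- "def poly_predict(x, weights, bias):\n",
- "    # y_hat = bias + sum_i weights[i] * x^(i+1), which is linear in the weights\n",
- "    feats = poly_features(x, len(weights))\n",
- "    return bias + sum(wi * fi for wi, fi in zip(weights, feats))\n",
- "\n",
- "row = poly_features(2.0)\n",
- "y = poly_predict(2.0, [0.5, 3.0, 2.4], 0.9)\n",
- "print(row, y)\n",
- "```\n",
- "\n",
- "Because the prediction is a weighted sum of fixed features, the same squared-error loss and the same gradient descent updates apply as in the univariate case."
- ]
- },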
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "First we define a target function to fit; here it is a cubic polynomial."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n"
- ]
- }
- ],
- "source": [
- "# Define the target polynomial\n",
- "\n",
- "w_target = np.array([0.5, 3, 2.4]) # target weights\n",
- "b_target = np.array([0.9]) # target bias\n",
- "\n",
- "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n",
- "    b_target[0], w_target[0], w_target[1], w_target[2]) # build the formula string\n",
- "\n",
- "print(f_des)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Let us first plot this polynomial."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f53910144e0>"
- ]
- },
- "execution_count": 23,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the curve of this function\n",
- "x_sample = np.arange(-3, 3.1, 0.1)\n",
- "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n",
- "\n",
- "plt.plot(x_sample, y_sample, label='real curve')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next we build the dataset, which needs $x$ and $y$. Because the target is a cubic polynomial, we use $x,\\ x^2,\\ x^3$ as the input features."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Build the data x and y\n",
- "# x is the matrix with columns [x, x^2, x^3]\n",
- "# y is the function value [y]\n",
- "\n",
- "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n",
- "x_train = torch.from_numpy(x_train).float() # convert to a float tensor\n",
- "\n",
- "y_train = torch.from_numpy(y_sample).float().unsqueeze(1) # convert to a float tensor"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "torch.Size([61, 3])\n"
- ]
- }
- ],
- "source": [
- "print(x_train.size())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next we define the parameters to optimize, namely the $w_i$ in the function above."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define the parameters and the model\n",
- "w = Variable(torch.randn(3, 1), requires_grad=True)\n",
- "b = Variable(torch.zeros(1), requires_grad=True)\n",
- "\n",
- "# Wrap x and y in Variables\n",
- "x_train = Variable(x_train)\n",
- "y_train = Variable(y_train)\n",
- "\n",
- "def multi_linear(x):\n",
- " return torch.mm(x, w) + b"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can plot the model before any update against the true model for comparison."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f5390feddd8>"
- ]
- },
- "execution_count": 28,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the model before any update\n",
- "y_pred = multi_linear(x_train)\n",
- "\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The two curves clearly differ. Let us compute the error between them."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(413.9844, grad_fn=<MeanBackward1>)\n"
- ]
- }
- ],
- "source": [
- "# Compute the error; it is the same as for the univariate linear model, using get_loss defined earlier\n",
- "loss = get_loss(y_pred, y_train)\n",
- "print(loss)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Automatic differentiation\n",
- "loss.backward()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([[ -34.1391],\n",
- " [-146.6133],\n",
- " [-215.9149]])\n",
- "tensor([-27.0838])\n"
- ]
- }
- ],
- "source": [
- "# Inspect the gradients of w and b\n",
- "print(w.grad)\n",
- "print(b.grad)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Update the parameters once\n",
- "w.data = w.data - 0.001 * w.grad.data\n",
- "b.data = b.data - 0.001 * b.grad.data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f5390fb8a20>"
- ]
- },
- "execution_count": 33,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the model after one update\n",
- "y_pred = multi_linear(x_train)\n",
- "\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Because we have only updated once, the two curves still differ. Let us run 100 iterations."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "epoch 20, Loss: 73.67843\n",
- "epoch 40, Loss: 17.97095\n",
- "epoch 60, Loss: 4.94101\n",
- "epoch 80, Loss: 1.87171\n",
- "epoch 100, Loss: 1.12812\n"
- ]
- }
- ],
- "source": [
- "# Perform 100 parameter updates\n",
- "for e in range(100):\n",
- "    y_pred = multi_linear(x_train)\n",
- "    loss = get_loss(y_pred, y_train)\n",
- "    \n",
- "    w.grad.data.zero_()\n",
- "    b.grad.data.zero_()\n",
- "    loss.backward()\n",
- "    \n",
- "    # Update the parameters\n",
- "    w.data = w.data - 0.001 * w.grad.data\n",
- "    b.data = b.data - 0.001 * b.grad.data\n",
- "    if (e + 1) % 20 == 0:\n",
- "        print('epoch {}, Loss: {:.5f}'.format(e+1, loss.item()))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "After these updates the loss is much smaller. Let us plot the curves again for comparison."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f5390f019e8>"
- ]
- },
- "execution_count": 35,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the result after the updates\n",
- "y_pred = multi_linear(x_train)\n",
- "\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "After 100 updates, the fitted curve almost coincides with the true curve."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": true
- },
- "source": [
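- "**Exercise: the example above is a cubic polynomial. Try fitting it with a quadratic polynomial and see how well you can do.**\n",
- "\n",
- "**Hint: use the parameter `w = torch.randn(2, 1)` and rebuild the x dataset accordingly.**"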
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|