|
|
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Implementing Linear Models in PyTorch\n",
- "\n",
- "This section briefly reviews the linear regression model and shows how to use PyTorch to build one and fit its parameters."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
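- "## 1. Simple Linear Regression\n",
- "Simple (one-variable) linear regression is straightforward. Suppose we have inputs $x_i$ and targets $y_i$, where each $i$ indexes one data point. We want to build the model\n",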
- "\n",
- "$$\n",
- "\\hat{y}_i = w x_i + b\n",
- "$$\n",
- "\n",
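- "where $\\hat{y}_i$ is the prediction. We want $\\hat{y}_i$ to fit the target $y_i$; put simply, we look for the $w$ and $b$ that make the error as small as possible, i.e. that minimize\n",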
- "\n",
- "$$\n",
- "\\frac{1}{n} \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
- "$$"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "So how do we minimize this error?"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
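- "## 2. Gradient Descent\n",
- "\n",
- "To understand gradient descent we first need the notion of a gradient. For a function of one variable the gradient is simply the derivative; for a function of several variables it is the vector of partial derivatives. For example, the gradient of a function $f(x, y)$ is\n",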
- "\n",
- "$$\n",
- "(\\frac{\\partial f}{\\partial x},\\ \\frac{\\partial f}{\\partial y})\n",
- "$$\n",
- "\n",
- "It is written $\\operatorname{grad} f(x, y)$ or $\\nabla f(x, y)$. The gradient at a specific point $(x_0,\\ y_0)$ is $\\nabla f(x_0,\\ y_0)$.\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
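- "What does the gradient mean? Geometrically, the gradient at a point gives the direction in which the function changes fastest. Specifically, for $f(x, y)$ at the point $(x_0, y_0)$, the function increases fastest along the direction of $\\nabla f(x_0,\\ y_0)$. Moving along the gradient therefore climbs toward a (local) maximum most quickly, and moving against the gradient descends toward a (local) minimum most quickly.\n",
- "\n",
- "For simple linear regression, this means repeatedly adjusting $w$ and $b$ in the direction opposite to the gradient until we reach the $w$ and $b$ that minimize the error.\n",
- "\n",
- "Each update also needs a step size, i.e. how far to move downhill at each step. This step size is called the learning rate and is denoted $\\eta$. Different learning rates lead to different behavior: if it is too small, descent is very slow; if it is too large, the iterates oscillate noticeably.\n",
- "\n",
- "The update rule is then\n",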
- "\n",
- "$$\n",
- "w = w - \\eta \\frac{\\partial f(w,\\ b)}{\\partial w} \\\\\n",
- "b = b - \\eta \\frac{\\partial f(w,\\ b)}{\\partial b}\n",
- "$$\n",
- "\n",
- "By iterating these updates we eventually arrive at an optimal $w$ and $b$."
- ]
- },
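- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As a small worked illustration of the update rule above, here is a sketch of two gradient descent steps. The function $f(w) = w^2$, the starting point $w_0 = 1$ and the learning rate $\\eta = 0.1$ are assumptions chosen purely for illustration; they are not taken from the regression example in this notebook.\n",
- "\n",
- "$$\n",
- "\\begin{aligned}\n",
- "f(w) &= w^2, \\qquad \\frac{\\partial f}{\\partial w} = 2w \\\\\n",
- "w_1 &= w_0 - \\eta \\cdot 2 w_0 = 1 - 0.1 \\times 2 = 0.8 \\\\\n",
- "w_2 &= w_1 - \\eta \\cdot 2 w_1 = 0.8 - 0.1 \\times 1.6 = 0.64\n",
- "\\end{aligned}\n",
- "$$\n",
- "\n",
- "In this toy case each step multiplies $w$ by $1 - 2\\eta = 0.8$, so the iterates shrink toward the minimizer $w = 0$. With $\\eta > 1$ the factor $|1 - 2\\eta|$ would exceed 1 and the iterates would grow instead, which is the kind of instability an overly large learning rate causes."
- ]
- },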
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 3. PyTorch Implementation\n",
- "\n",
- "That covers the theory. Next we work through an example to study the linear model in practice."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<torch._C.Generator at 0x7fb8b00453f0>"
- ]
- },
- "execution_count": 1,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import torch\n",
- "import numpy as np\n",
- "\n",
- "torch.manual_seed(2021)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[<matplotlib.lines.Line2D at 0x7fb7aef5dc70>]"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAORUlEQVR4nO3db4hs913H8c9ncwl2Y2tC7lo08e42Qmu1GIxjrakU2/gvqTQIeRCcIAZhEbQWn9jqQhHkgoIPoogtQ7BFsjZiTAVF04qlKtSkzLZp/jQqaXp3m0TN3ihVsg/iTb4+OLPZzWTuzpmbOb/znTPvFyyzc+bc4bu/2fu5v3vO748jQgCAvFbaLgAAcDKCGgCSI6gBIDmCGgCSI6gBILlTTbzp6dOnY2Njo4m3BoBO2tnZOR8Ra5NeaySoNzY2NBwOm3hrAOgk27sXe41LHwCQHEENAMkR1ACQHEENAMkR1ACQHEENJLG9LW1sSCsr1eP2dtsVIYtGhucBmM32trS5KR0cVM93d6vnktTvt1cXcqBHDSSwtXUU0ocODqrjAEENJLC3N9txLBeCGkjgzJnZjmO5ENRAAmfPSqurrz62ulodBwhqIIF+XxoMpPV1ya4eBwNuJKLCqA8giX6fYMZk9KgBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGugglkwtq+n2ZsIL0DEsmVpWifZ2RMznnY7p9XoxHA7n/r4AptvYqMJi3Pq6dO5c6Wq6b17tbXsnInqTXuPSB9AxLJlaVon2JqiBjmHJ1LJKtDdBDXQMS6aWVaK9CWqgY1gytawS7c3NRABIgJuJQAsYy4x5qRXUtn/N9uO2H7P9Kdvf0nRhwCI7HFu7uytFVI933CGdPk1gY3ZTg9r2NZJ+VVIvIt4h6TJJtzddGLDItraOJkAc9/zzVYAT1phF3UsfpyS9wfYpSauSnm2uJGDxnTSG9uCgCvJxXCrBxUwN6oh4RtLvSdqT9O+SvhkRnx0/z/am7aHt4f7+/vwrBRbItDG040E+6VIJPW8cqnPp4ypJt0p6i6TvlHSF7TvGz4uIQUT0IqK3trY2/0qBBTJpbO1x40E+6VLJxXreWD51Ln38uKSvR8R+RPyfpPsl3dhsWcBiOxxbe/XVr31t0mQIpn3jJHWCek/Su2yv2rakmyQ90WxZWCZdvTbb70vnz0v33DN9MgTTvnGSOteoH5J0n6QvSXp09GcGDdeFJbEM12b7/WoVtZdfrh4nzVhj2jdOwsxEtIolOY9sb1fXpPf2qp702bNM+14mJ81MJKjRqpWVqic9zq56oMCyYAo50uLaLDAdQY1WcW0WmI6gRqtYkhOYjs1t0bp+n2AGTkKPGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6ixkLq62UAX8NnMH0GNhbMImw0sa1gtwmeziFiPGgsn+2YDh2F1fLPa1dXlWGwq+2eTGRsHoFOybzawzGGV/bPJjI0D0CnZNxtY5h3Fs382i4qgxsLJvtnAModV9s9mURHUWDjZNxtY5rDK/tksKq5RAw1gR3HM6qRr1OzwAjSAXWswT1z6AIDkCGoASI6gBoDkCGoASI6gBoDkCGoASI6gBoDkCGoASI6gBoDkCGoASI6gBoDkCGoASC51UC/rvnOZ8ZkA5U1dPc/22yT92bFD10n6aETc1VRR0mv3nTvcJFNiVbK28JkA7ZhpPWrbl0l6RtIPR8SEXeEq81iPepn3ncuKzwRozjz3TLxJ0tdOCul5WeZ957LiMwHaMWtQ3y7pU5NesL1pe2h7uL+//7oLW+Z957LiMwHaUTuobV8u6QOS/nzS6xExiIheRPTW1tZed2HLvO9cVnwmQDtm6VHfLOlLEfGfTRVzHJtk5sNnArSj9s1E2/dK+kxEfGLauWxuCwCzed03E22vSvoJSffPszAAwHS1gjoiDiLi6oj4ZtMF
AcCiKDUBbOqEFwDAa5WcAJZ6CjkAZLW1dRTShw4OquPzRlADwCUoOQGMoAaAS1ByAhhBDQCXoOQEMIIaAC5ByQlgjPoAgEvU75eZmUuPGkBndHVjC3rUADqhyxtb0KMG0AklxzWXRlAD6IQub2xBUAPohC5vbEFQA+iELm9sQVAD6IQub2zBqA8AnVFqXHNp9KgBIDmCGgCSI6gBIDmCGgCSI6gBIDmCGgCSI6iBGXR1dTbkxjhqoKYur86G3OhRAzV1eXU25EZQAzV1eXU25EZQAzV1eXU25EZQt4gbU4uly6uzITeCuiWHN6Z2d6WIoxtThHVeXV6dDbk5Iub+pr1eL4bD4dzft0s2NqpwHre+Lp07V7oaAG2zvRMRvUmv0aNuCTemANRFULeEG1MA6iKoW8KNKQB1EdQt4cYUgLqYQt6irm4bBGC+6FEDQHIENQAkVyuobV9p+z7b/2L7Cds/0nRhAIBK3WvUvy/pgYi4zfblklan/QEAwHxMDWrbb5L0Hkm/IEkR8aKkF5stCwBwqM6lj+sk7Uv6hO0v277b9hXjJ9netD20Pdzf3597oQCwrOoE9SlJN0j6WET8gKQXJH1k/KSIGERELyJ6a2trcy4TAJZXnaB+WtLTEfHQ6Pl9qoIbAFDA1KCOiP+Q9A3bbxsduknSVxutCgDwirqjPj4oaXs04uMpSXc2VxIA4LhaQR0RD0uauE4qAKBZzEwEgOQIagBIjqAGgOQIagBIjqAGxmxvV5sPr6xUj+wMj7axcQBwzPa2tLkpHRxUz3d3q+cSmzygPfSogWO2to5C+tDBQXUcaAtBDRyztzfbcaAEgho45syZ2Y4DJRDUwDFnz0qrY9tirK5Wx4G2ENQdwmiF16/flwYDaX1dsqvHwYAbiWgXoz46gtEK89Pv02bIhR51RzBaAegugrojGK0AdBdB3RGMVgC6i6DuCEYrAN1FUHcEoxWA7mLUR4cwWgHoJnrUAJAcQQ0AyRHUAJAcQQ0AyRHUAJAcQQ0AyRHUAJAcQQ0AyRHUAJAcQY1LwiYFQDlMIcfM2KQAKIseNWbGJgVAWQQ1ZsYmBUBZBDVmxiYFQFkENWbGJgVAWQQ1ZsYmBUBZjPrAJWGTAqAcetQAkFytHrXtc5L+V9JLki5ERK/JogAAR2a59PHeiDjfWCUAgIm49AEAydUN6pD0Wds7tjcnnWB70/bQ9nB/f39+FQLAkqsb1O+OiBsk3Szpl22/Z/yEiBhERC8iemtra3MtEgCWWa2gjohnR4/PSfq0pHc2WRQA4MjUoLZ9he03Hn4v6SclPdZ0YQCASp1RH2+W9Gnbh+f/aUQ80GhVAIBXTA3qiHhK0vUFagEATMDwPABIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOQIagBIjqAGgOTSBPX2trSxIa2sVI/b221XBAA5nGq7AKkK5c1N6eCger67Wz2XpH6/vboAIIMUPeqtraOQPnRwUB0HgGWXIqj39mY7DgDLJEVQnzkz23EAWCYpgvrsWWl19dXHVler4wCw7GoHte3LbH/Z9l/Pu4h+XxoMpPV1ya4eBwNuJAKANNuojw9JekLSm5oopN8nmAFgklo9atvXSnq/pLubLQcAMK7upY+7JP26pJcvdoLtTdtD28P9/f151AYAUI2gtv0zkp6LiJ2TzouIQUT0IqK3trY2twIBYNnV6VG/W9IHbJ+TdK+k99m+p9GqAACvmBrUEfEbEXFtRGxIul3S5yLijsYrAwBIamitj52dnfO2d2f4I6clnW+ilgVDO1RoB9rg0DK1w/rFXnBElCxkchH2MCJ6bdfRNtqhQjvQBodoh0qKmYkAgIsjqAEguSxBPWi7gCRohwrt
QBscoh2U5Bo1AODisvSoAQAXQVADQHLFgtr2T9v+V9tP2v7IhNdt+w9Grz9i+4ZStZVUox36o5//EdtfsH19G3U2bVo7HDvvh2y/ZPu2kvWVUqcdbP+Y7YdtP277H0rXWEKNvxffZvuvbH9l1A53tlFnayKi8S9Jl0n6mqTrJF0u6SuSvnfsnFsk/a0kS3qXpIdK1Fbyq2Y73CjpqtH3Ny9rOxw773OS/kbSbW3X3dLvw5WSvirpzOj5t7ddd0vt8JuSfnf0/Zqk/5J0edu1l/oq1aN+p6QnI+KpiHhR1Zoht46dc6ukP4nKg5KutP0dheorZWo7RMQXIuK/R08flHRt4RpLqPP7IEkflPQXkp4rWVxBddrh5yTdHxF7khQRXWyLOu0Qkt5o25K+VVVQXyhbZntKBfU1kr5x7PnTo2OznrPoZv0Zf1HV/zK6Zmo72L5G0s9K+njBukqr8/vwVklX2f687R3bP1+sunLqtMMfSnq7pGclPSrpQxFx0WWXu6aRtT4m8IRj4+MC65yz6Gr/jLbfqyqof7TRitpRpx3ukvThiHip6kR1Up12OCXpByXdJOkNkv7Z9oMR8W9NF1dQnXb4KUkPS3qfpO+W9He2/yki/qfh2lIoFdRPS/quY8+vVfUv46znLLpaP6Pt71e1m87NEfF8odpKqtMOPUn3jkL6tKRbbF+IiL8sUmEZdf9enI+IFyS9YPsfJV0vqUtBXacd7pT0O1FdpH7S9tclfY+kL5YpsWWFbhackvSUpLfo6GbB942d8369+mbiF9u+gN9SO5yR9KSkG9uut812GDv/k+rmzcQ6vw9vl/T3o3NXJT0m6R1t195CO3xM0m+Nvn+zpGcknW679lJfRXrUEXHB9q9I+oyqO7x/HBGP2/6l0esfV3Vn/xZVIXWg6l/QTqnZDh+VdLWkPxr1Ji9Ex1YPq9kOnVenHSLiCdsPSHpE1VZ4d0fEY+1VPX81fx9+W9InbT+qqjP34YhYluVPmUIOANkxMxEAkiOoASA5ghoAkiOoASA5ghoAkiOoASA5ghoAkvt/nElIdlbTfhoAAAAASUVORK5CYII=\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Generate the training data\n",
- "x_train = np.random.rand(20, 1)\n",
- "y_train = x_train * 3 + 4 + 3*np.random.rand(20,1)\n",
- "\n",
- "# Plot the data\n",
- "import matplotlib.pyplot as plt\n",
- "%matplotlib inline\n",
- "\n",
- "plt.plot(x_train, y_train, 'bo')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Convert to tensors\n",
- "x_train = torch.from_numpy(x_train)\n",
- "y_train = torch.from_numpy(y_train)\n",
- "\n",
- "# Define the parameters w and b\n",
- "w = torch.randn(1, requires_grad=True)  # random initialization\n",
- "b = torch.zeros(1, requires_grad=True)  # initialize to 0"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Build the linear regression model\n",
- "def linear_model(x):\n",
- "    return x * w + b\n",
- "\n",
- "def logistc_regression(x):\n",
- "    return torch.sigmoid(x*w+b)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "y_ = linear_model(x_train)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "With these steps the model is defined. Before updating any parameters, let's look at what the model currently outputs."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7aee58d00>"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWs0lEQVR4nO3df2xd5X3H8c/XiSE4BYoSDwGp7XQaawKJl8RZQ4vISiAJBK0g+KNgBmmpIggwNA3RsqgrUpp1SFPa0hWoxTI04g41oaXalBVWfpSt5ZdDHVoSSFiwgwMTjkEZxESJ4+/+OL754V7b59r3nPPcc98vybq5517uffzc5MNzn5/m7gIAhKsm6wIAAEZHUANA4AhqAAgcQQ0AgSOoASBwk5N40enTp3tTU1MSLw0AubR169Z97l5f7LFEgrqpqUkdHR1JvDQA5JKZdY/0GF0fABA4ghoAAkdQA0DgEumjLubw4cPq6enRwYMH03rL3JsyZYpmzJih2trarIsCIEGpBXVPT49OPfVUNTU1yczSetvccnf19fWpp6dHM2fOzLo4ABKUWtfHwYMHNW3aNEK6TMxM06ZN4xtKjrS3S01NUk1NdNvennWJEIrUWtSSCOkyoz7zo71dWrVK6u+P7nd3R/clqbU1u3IhDAwmAgFYs+ZYSBf090fXAYI6pqamJu3bty/rYiCn9uwp7TqqS7BBnWR/nbtrcHCwfC8ITFBDQ2nXUV2CDOpCf113t+R+rL9uImHd1dWlWbNmafXq1Zo/f77Wrl2rhQsXau7cufrmN7959HlXXnmlFixYoPPOO09tbW1l+G2Asa1bJ9XVnXitri66DgQZ1En1173xxhu64YYbdO+992rv3r166aWX1NnZqa1bt+q5556TJG3YsEFbt25VR0eH7rvvPvX19U3sTYEYWlultjapsVEyi27b2hhIRCTVWR9xJdVf19jYqEWLFunOO+/Uk08+qXnz5kmSPvroI+3atUsXXXSR7rvvPv30pz+VJL399tvatWuXpk2bNrE3BmJobSWYUVyQLeqk+uumTp0qKeqjvvvuu9XZ2anOzk69+eabuummm/Tss8/qF7/4hZ5//nlt27ZN8+bNY54yKhJzstOVdH0HGdRJ99ctW7ZMGzZs0EcffSRJ2rt3r9577z3t379fZ5xxhurq6vT666/rhRdeKM8bAilKYowHI0ujvoMM6qT765YuXarrrrtOF1xwgebMmaNrrrlGH374oZYvX66BgQHNnTtX3/jGN7Ro0aLyvCGQIuZkpyuN+jZ3L9+rDWlpafHhBwfs2LFDs2bNKvt7VTvqFcPV1EQtu+HMJGalll+56tvMtrp7S9H3iPkCf2Vmr5nZ78zsX81sSvy3B6pTVv3EzMlOVxr1PWZQm9k5kv5SUou7ny9pkqQvla8IQP4U67e8/npp+vTkA5s52elKo77j9lFPlnSKmU2WVCfpnfIVAcifYv2WktTXN/JAU7la4MzJTlca9T1mULv7Xkn/IGmPpHcl7Xf3J8tXBFS7PE4lG23Of7GBpnLPHGhtlbq6oj7Sri5COmlJ13ecro8zJH1R0kxJZ0uaambXF3neKjPrMLOO3t7e8pYSuZXXqWRj9U8OD3JmamA0cbo+LpH0lrv3uvthST+R9LnhT3L3NndvcfeW+vr6cpcTOZXXgCrWb3m84UHO7nkYTZyg3iNpkZnVWbRT/RJJO5ItVrYefvhhvfPOsW74r371q9q+ffuEX7erq0s/+tGPSv7vVq5cqc2bN0/4/UOU14Aq9FsW232g2EATMzUwmjh91C9K2izpFUm/Hfpvkt9WLsOOy+FB/dBDD2n27NkTft3xBnWe5TmgWlulffukjRvHHmhipgZG5e5l/1mwYIEPt3379t+7NqKNG93r6tyjbsvop64uuj4BjzzyiC9cuNCbm5t91apVPjAw4DfeeKOfd955fv755/v69et906ZNPnXqVD/33HO9ubnZ
+/v7ffHixf7yyy+7u/vUqVP9rrvu8vnz5/uSJUv8xRdf9MWLF/vMmTP9Zz/7mbu7v/XWW37hhRf6vHnzfN68ef6rX/3K3d0/+9nP+mmnnebNzc2+fv16HxgY8DvvvNNbWlp8zpw5/uCDD7q7++DgoN96660+a9Ysv/zyy/2yyy7zTZs2Ff2dSqrXACX0UVekjRvdGxvdzaLbaqyDaiapw0fI1DCDurHxxH+5hZ/GxtJ+82Hvf8UVV/ihQ4fc3f2WW27xe+65xy+55JKjz/nggw/c3U8I5uH3JfmWLVvc3f3KK6/0Sy+91A8dOuSdnZ3e3Nzs7u4HDhzwjz/+2N3dd+7c6YX6eOaZZ3zFihVHX/eHP/yhr1271t3dDx486AsWLPDdu3f7Y4895pdccokPDAz43r17/fTTT89tULsTUID76EEd5DanSXRcPvXUU9q6dasWLlwoSfr444+1fPly7d69W7fffrtWrFihpUuXjvk6J510kpYvXy5JmjNnjk4++WTV1tZqzpw56urqkiQdPnxYt912mzo7OzVp0iTt3Lmz6Gs9+eSTevXVV4/2P+/fv1+7du3Sc889p2uvvVaTJk3S2WefrYsvvnjcv3clYHtPYHRhBnVDQzRPq9j1cXJ33Xjjjfr2t799wvV169bpiSee0A9+8AP9+Mc/1oYNG0Z9ndra2qOnf9fU1Ojkk08++ueBgQFJ0ne+8x2deeaZ2rZtmwYHBzVlSvEV9+6u73//+1q2bNkJ17ds2cIJ4wCOCnL3vCRGVpYsWaLNmzfrvffekyS9//776u7u1uDgoK6++mqtXbtWr7zyiiTp1FNP1Ycffjju99q/f7/OOuss1dTU6JFHHtGRI0eKvu6yZcv0wAMP6PDhw5KknTt36sCBA7rooov06KOP6siRI3r33Xf1zDPPjLsseZXHRTJ5wWdTfmG2qAvfg9esibo7GhqikJ7A9+PZs2frW9/6lpYuXarBwUHV1tZq/fr1uuqqq44edFtoba9cuVI333yzTjnlFD3//PMlv9fq1at19dVXa9OmTfrCF75w9MCCuXPnavLkyWpubtbKlSt1xx13qKurS/Pnz5e7q76+Xo8//riuuuoqPf3005ozZ47OPfdcLV68eNy/dx4VFskU5l8XFslI4XShtLeX9a9vxaiEz6YSsc1phavGem1qKt4z1tgYLd/N2vCwkqIvhNWw30bon03IJrzNKRCS0BfJ5HW1ZRyhfzaViqBGxQl9kUw1h1Xon02lSjWok+hmqWbVWp+hr+Kr5rAK/bOpVKkF9ZQpU9TX11e14VJu7q6+vr4Rp/7lWej7LVdzWIX+2VSq1AYTDx8+rJ6eHh08eLDs71etpkyZohkzZqi2tjbromCYap31gfEbbTAxtaAGAIyMWR8AUMEIagAIXNBBzVLU8PCZAOkLcwm5WIoaIj4TIBvBDiayFDU8fCZAcipyMLGaV3eFis8EyEawQV3Nq7tCxWcCZCPYoK7m1V2h4jMBTpTW4HqwQc1S1PDwmQDHFAbXu7ujQ10Lg+tJhHWwg4kAELJyD65X5GAiAIQszcF1ghpAbqS5ICvNwXWCGkAupNlnLKU7uE5QA8iFtI9AS3NwncFEALlQUxO1pIczkwYH0y9PqRhMBJB7eV6QRVADyIU8L8giqAHkQp4XZBHUQAnYjztsra3RYpPBweg2DyEtBbwfNRAa9uNGVmhRAzGlPf0LKCCoM8TX6MrCftzICkGdkbRXUWHi8jz9C2EjqDPC1+jKk+fpXwgbQZ0RvkZXnjxP/0LYmPWRkYaG4nvZ8jU6bK2tBDPSF6tFbWafNLPNZva6me0wswuSLlje8TUaQFxxuz6+J+nn7v4ZSc2SdiRXpOrA12gAcY25e56ZnSZpm6RPe8yt9tg9DwBKM9Hd8z4tqVfSP5vZb8zsITObWuRNVplZh5l19Pb2TrDIAICCOEE9WdJ8SQ+4+zxJByR9ffiT3L3N3VvcvaW+vr7MxQSA6hUn
qHsk9bj7i0P3NysKbgBACsYManf/X0lvm9kfD11aIml7oqUCABwVd9bH7ZLazexVSX8i6e8SKxGQMfZgQWhiLXhx905JRUcjgTxhK1OEiCXkOUJLcOLYgwUhYgl5TtASLA/2YEGIaFHnBC3B8mArU4SIoM4JWoLlwR4sCBFBnRO0BMuDPVgQIoI6J2gJlk9eT7JG5SKoc4KWIJBfzPrIETa1B/KJFjUABI6gxriwuAZID10fKBmLa4B00aJGyVhcA6SLoEbJWFwDpIugRslYXAOki6BGyVhcA6SLoEbJWFwDpItZHxgXFtcA6aFFDQCBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIELJqjb26WmJqmmJrptb8+6RAAQhiCO4mpvl1atkvr7o/vd3dF9ieOeACCIFvWaNcdCuqC/P7oOANUudlCb2SQz+42Z/Xu5C7FnT2nXAaCalNKivkPSjiQK0dBQ2nUAqCaxgtrMZkhaIemhJAqxbp1UV3fitbq66DoAVLu4LervSrpL0uBITzCzVWbWYWYdvb29JRWitVVqa5MaGyWz6LatjYFEAJBiBLWZXSHpPXffOtrz3L3N3VvcvaW+vr7kgrS2Sl1d0uBgdEtIA0AkTov685L+3My6JD0q6WIz25hoqQAAR40Z1O5+t7vPcPcmSV+S9LS7X594yQAAkgKZRw0AGFlJKxPd/VlJzyZSEgBAUbSoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIEjqAEgcAQ1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIEjqAEgcAQ1AASOoAaAwBHUADBe7e1SU5NkJk2eHN02NUXXy2hyWV8NAKpFe7u0apXU3x/dP3Ikuu3ujq5LUmtrWd6KFjUAjMeaNcdCerj+/ujxMhkzqM3sU2b2jJntMLPXzOyOsr07AFSqPXsm9ngJ4rSoByT9tbvPkrRI0q1mNrtsJQCAStTQMLHHSzBmULv7u+7+ytCfP5S0Q9I5ZSsBAJRLYXCvpiaRQb0TrFsn1dUVf6yuLnq8TErqozazJknzJL1Y5LFVZtZhZh29vb1lKh4AxFQY3OvultyPDeolFdatrVJbm9TYGN2fNCm6bWyMrpdpIFGSzN3jPdHsE5J+KWmdu/9ktOe2tLR4R0dHGYoHADE1NUXhPFxjo9TVlXZpSmZmW929pdhjsVrUZlYr6TFJ7WOFNABkYqTBuzIO6mUlzqwPk/RPkna4+/rkiwQA4zDS4F0ZB/WyEqdF/XlJfyHpYjPrHPq5POFyAahUaQ7oHa/Y4F6ZB/WyMubKRHf/b0mWQlkAVLrhq/USWKU3osLrr1kTdXc0NEQhnfT7piD2YGIpGEwEqlSFD+hlacKDiQAQS44H9LJEUAMonxwP6GWJoAbyiAG9XCGogbxJe4Xe8Y5frWeWyCq9asRgIpA3DOhVJAYTgUqwevWxU0ImT47ujwcDerlDUANJi9NfvHq19MADx04JOXIkuj+esGZAL3cIaiAJ7e3S9OlR6/j668fuL25rK/46I10fDQN6uUNQA+XW3i595StSX1/xx4sd01RoSQ830vXRMKCXOwQ1UIo43Rhr1kiHDo3+OsP7iwt7GQ830vWxtLZGA4eDg9EtIV3RCGpguJHCOO60tziDdsP7iwv7YQw30nVUFYIaON5oYVzs1Oli3RhjDdoV6y++/37plluOtaAnTYru33//xH4f5ALzqIHjjTYHec+eKLyHM4u6GAoKfdTFuj+mTZO+9z26IvB7mEeN/ElqifRoc5DjTntrbZU2bIhCuWDaNGnjRmnfPkIaJSOoUXmS
XCI9WhiXMu2ttTUKZffoh4DGBBDUyF6preO4fcXjMVoYM+0NGaGPGtkafiKIFAXjaAFYUxOvr3giZcrhKSEI22h91AQ1sjWeDYTYdAg5xGAiwjWeDYRYIo0qQ1CjNIX+5MIOb2YTm3Uxng2E6CtGlSGoEd/xsy2kY/tQTGTWxXhbxyyRRhUhqBFfsdkWBeOddUHrGBgTg4mIb6TZFgXlmnUBVCEGE6tF0geajrWHBRvTA4kgqPMijQNNi/UnFzDrAkgMQZ0XSa7WKzi+P1k6ttMb/cpAogjqLJWzqyKtA00Lsy3cpYGB6JZZF0CiCOqslLurggNNgdwiqLNS7q4KVusBuUVQZ6XcXRXMRwZya3LWBahaDQ3FNxaaSFdFayvBDOQQLeqs0FUBICaCOit0VQCIia6PLNFVASAGWtQAELiwgzrpvSsAoALECmozW25mb5jZm2b29URKMjyUV69Ofu8KAKgAY25zamaTJO2UdKmkHkkvS7rW3beP9N+UvM1psQNOzYpvqcm5eAByaKLbnP6ppDfdfbe7H5L0qKQvlrOARVfpjfQ/kHLvXQEAgYsT1OdIevu4+z1D105gZqvMrMPMOnp7e0srRSnhy94VAKpMnKC2Itd+r7nr7m3u3uLuLfX19aWVYqTwtWFvzYIQAFUoTlD3SPrUcfdnSHqnrKUYaZXezTezIARA1Yuz4OVlSX9kZjMl7ZX0JUnXlbUUhfBdsybqBmloiMKbUAaAsYPa3QfM7DZJT0iaJGmDu79W9pKwSg8Aioq1hNzdt0jaknBZAABFhL0yEQBAUANA6AhqAAgcQQ0AgRtzr49xvahZr6Qi50yNaLqkfWUvSOWhHiLUA3VQUE310OjuRVcLJhLUpTKzjpE2I6km1EOEeqAOCqiHCF0fABA4ghoAAhdKULdlXYBAUA8R6oE6KKAeFEgfNQBgZKG0qAEAIyCoASBwqQX1WAfkWuS+ocdfNbP5aZUtTTHqoXXo93/VzH5tZs1ZlDNpcQ9MNrOFZnbEzK5Js3xpiVMPZvZnZtZpZq+Z2S/TLmMaYvy7ON3M/s3Mtg3Vw5ezKGdm3D3xH0Xbo/6PpE9LOknSNkmzhz3nckn/oehEmUWSXkyjbGn+xKyHz0k6Y+jPl1VrPRz3vKcV7dx4TdblzujvwyclbZfUMHT/D7Iud0b18DeS7h36c72k9yWdlHXZ0/pJq0Ud54DcL0r6F4+8IOmTZnZWSuVLy5j14O6/dvcPhu6+oOhEnbyJe2Dy7ZIek/RemoVLUZx6uE7ST9x9jyS5ex7rIk49uKRTzcwkfUJRUA+kW8zspBXUcQ7IjXWIboUr9Xe8SdG3jLwZsx7M7BxJV0l6MMVypS3O34dzJZ1hZs+a2VYzuyG10qUnTj38o6RZio4B/K2kO9x9MJ3iZS/WwQFlEOeA3FiH6Fa42L+jmX1BUVBfmGiJshGnHr4r6WvufsSGH3KcH3HqYbKkBZKWSDpF0vNm9oK770y6cCmKUw/LJHVKuljSH0r6TzP7L3f/v4TLFoS0gjrOAbnJH6KbvVi/o5nNlfSQpMvcvS+lsqUpTj20SHp0KKSnS7rczAbc/fFUSpiOuP8u9rn7AUkHzOw5Sc2S8hTUcerhy5L+3qNO6jfN7C1Jn5H0UjpFzFhKgwWTJe2WNFPHBgvOG/acFTpxMPGlrDvwM6qHBklvSvpc1uXNsh6GPf9h5XMwMc7fh1mSnhp6bp2k30k6P+uyZ1APD0i6Z+jPZyo6aHt61mVP6yeVFrWPcECumd089PiDikb2L1cUUv2K/g+aKzHr4W8lTZN0/1BrcsBztntYzHrIvTj14O47zOznkl6VNCjpIXf/XXalLr+Yfx/WSnrYzH6rqDH3NXevlu1PWUIOAKFjZSIABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIH7f906gbeq7S1IAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
- "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Now we need to compute the error function, namely\n",
- "\n",
- "$$\n",
- "E = \\sum_{i=1}^n(\\hat{y}_i - y_i)^2\n",
- "$$"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(704.5194, dtype=torch.float64, grad_fn=<SumBackward0>)\n"
- ]
- }
- ],
- "source": [
- "# Compute the error (sum of squared residuals)\n",
- "def get_loss(y_, y):\n",
- "    return torch.sum((y_ - y) ** 2)\n",
- "\n",
- "loss = get_loss(y_, y_train)\n",
- "print(loss)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
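- "With the error function defined, the next step is to compute the gradients with respect to $w$ and $b$. Thanks to PyTorch's automatic differentiation, we obtain them without deriving anything by hand. For reference, the hand-derived gradients of $E$ with respect to $w$ and $b$ are\n",
- "\n",
- "$$\n",
- "\\frac{\\partial E}{\\partial w} = 2 \\sum_{i=1}^n x_i(w x_i + b - y_i) \\\\\n",
- "\\frac{\\partial E}{\\partial b} = 2 \\sum_{i=1}^n (w x_i + b - y_i)\n",
- "$$"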
- ]
- },
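- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "For readers who want to see where such hand-derived gradients come from, the following is a short chain-rule sketch for the summed error $E = \\sum_{i=1}^n(\\hat{y}_i - y_i)^2$ used in the code. It assumes the sum form without a $1/n$ factor; for a mean squared error each gradient would simply carry an extra factor of $1/n$. Writing $\\hat{y}_i = w x_i + b$:\n",
- "\n",
- "$$\n",
- "\\begin{aligned}\n",
- "\\frac{\\partial E}{\\partial w} &= \\sum_{i=1}^n 2(\\hat{y}_i - y_i)\\,\\frac{\\partial \\hat{y}_i}{\\partial w} = 2\\sum_{i=1}^n x_i (w x_i + b - y_i) \\\\\n",
- "\\frac{\\partial E}{\\partial b} &= \\sum_{i=1}^n 2(\\hat{y}_i - y_i)\\,\\frac{\\partial \\hat{y}_i}{\\partial b} = 2\\sum_{i=1}^n (w x_i + b - y_i)\n",
- "\\end{aligned}\n",
- "$$\n",
- "\n",
- "These are the quantities that `loss.backward()` accumulates into `w.grad` and `b.grad`."
- ]
- },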
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Backpropagate to compute the gradients automatically\n",
- "loss.backward()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([-117.3280])\n",
- "tensor([-234.3059])\n"
- ]
- }
- ],
- "source": [
- "# Inspect the gradients of w and b\n",
- "print(w.grad)\n",
- "print(b.grad)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Update the parameters once\n",
- "w.data = w.data - 1e-2 * w.grad.data\n",
- "b.data = b.data - 1e-2 * b.grad.data"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "After updating the parameters, let's look at the model's output again."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac5d4a30>"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "y_ = linear_model(x_train)\n",
- "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
- "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
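- "As the plot above shows, after one update the red predicted points sit below the blue points and still do not fit the true values well, so we need to perform several more updates."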
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "epoch: 19, loss: 21.218688263809952\n",
- "epoch: 39, loss: 19.55484974487415\n",
- "epoch: 59, loss: 18.824963796393106\n",
- "epoch: 79, loss: 18.50477882805245\n",
- "epoch: 99, loss: 18.364321569910107\n"
- ]
- }
- ],
- "source": [
- "for e in range(100):  # run 100 updates\n",
- "    y_ = linear_model(x_train)\n",
- "    loss = get_loss(y_, y_train)\n",
- "\n",
- "    w.grad.zero_()  # note: zero the accumulated gradient\n",
- "    b.grad.zero_()  # note: zero the accumulated gradient\n",
- "    loss.backward()\n",
- "\n",
- "    w.data = w.data - 1e-2 * w.grad.data  # update w\n",
- "    b.data = b.data - 1e-2 * b.grad.data  # update b\n",
- "    if (e + 1) % 20 == 0:\n",
- "        print('epoch: {}, loss: {}'.format(e, loss.item()))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac548be0>"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8/fFQqAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWwklEQVR4nO3df2zc9X3H8dfbTiA4SSlNPARkiWESjEB+EBuWtmsIkIbQoAKi0spMId2QgQzGJrUDFmllSilDmpqSrvxwEUWrTdEIBXUTa1n5MZAITW1qKCUlocEOCWxxDEohTpbEee+Pr+045i7+Xvy97/dz33s+JOt8d9/cfe5zl5c/9/l+fpi7CwAQrpqsCwAAODKCGgACR1ADQOAIagAIHEENAIGbUI4HnT59ujc0NJTjoQEglzo7O3e6e32h+8oS1A0NDero6CjHQwNALplZT7H76PoAgMAR1AAQOIIaAAJXlj7qQvbv369t27Zp7969aT1l7k2aNEkzZszQxIkTsy4KgDJKLai3bdumqVOnqqGhQWaW1tPmlrurr69P27Zt06mnnpp1cQCUUWpdH3v37tW0adMI6YSYmaZNm8Y3lBxpb5caGqSamuiyvT3rEiEUqbWoJRHSCaM+86O9XWppkfr7o+s9PdF1SWpuzq5cCAMnE4EArFp1KKSH9PdHtwMEdUwNDQ3auXNn1sVATm3dWtrtqC7BBnU5++vcXQcPHkzuAYFxmjmztNtRXYIM6qH+up4eyf1Qf914wrq7u1tnnnmmVq5cqQULFmj16tU699xzNXfuXH3jG98YPu7yyy9XY2OjzjrrLLW2tibwaoCx3XmnVFd3+G11ddHtQJBBXa7+ujfffFPXXHON7r77bm3fvl0bNmxQV1eXOjs79cILL0iSHnroIXV2dqqjo0Nr165VX1/f+J4UiKG5WWptlWbNksyiy9ZWTiQikuqoj7jK1V83a9YsLVy4UF/72tf09NNP65xzzpEkffTRR9q8ebMWLVqktWvX6oknnpAkvfPOO9q8ebOmTZs2vicGYmhuJphRWJBBPXNm1N1R6PbxmDx5sqSoj/r222/X9ddff9j9zz//vH7+859r/fr1qqur0+LFixmnDCBzQXZ9lLu/7uKLL9ZDDz2kjz76SJK0fft27dixQ7t27dIJJ5yguro6/fa3v9XLL7+czBMCwDgE2aIe+vq3alXU3TFzZhTSSX0tXLp0qTZu3KhPf/rTkqQpU6aora1Ny5Yt0/3336+5c+fqjDPO0MKFC5N5QgAYB3P3xB+0qanJR28csHHjRp155pmJP1e1o16BfDCzTndvKnRfkF0fAIBDCGoACBxBDQCBI6iBHGLJ1HSVu76DHPUB4OixZGq60qhvWtRAzrBkaora23X+tQ36sN+0XxM0INPbatBl/e2J1jdBXcDDDz+sd999d/j6ddddpzfeeGPcj9vd3a1HHnmk5H+3YsUKrVu3btzPj+rAkqkpGWxKzxjoUY2kCRpQjaQG9ej7atFne5Lr/wg3qDPsZBsd1A8++KBmz5497sc92qAGSsGSqSkp9NVl0GT16+7a5JrUYQZ1OdY5ldTW1qbzzjtP8+fP1/XXX6+BgQGtWLFCZ599tubMmaM1a9Zo3bp16ujoUHNzs+bPn689e/Zo8eLFGprAM2XKFN16661qbGzUkiVLtGHDBi1evFinnXaafvKTn0iKAvlzn/ucFixYoAULFuill16SJN1222168cUXNX/+fK1Zs0YDAwP6+te/Przc6gMPPCApWovkpptu0uzZs7V8+XLt2LFjXK8b1YUlU1MyxleUUwYS/Arj7on/NDY2+mhvvPHGx24ratYs9yiiD/+ZNSv+YxR4/ksvvdT37dvn7u433nij33HHHb5kyZLhYz744AN3dz///PP9l7/85fDtI69L8qeeesrd3S+//HL//Oc/7/v27fOuri6fN2+eu7vv3r3b9+zZ4+7umzZt8qH6eO65
53z58uXDj/vAAw/46tWr3d1979693tjY6Fu2bPHHH3/clyxZ4gcOHPDt27f78ccf74899ljR1wWM1tYW/Xcxiy7b2rIuUQ4Vy6mjzCtJHV4kU8Mc9VGGTrZnnnlGnZ2dOvfccyVJe/bs0bJly7RlyxbdfPPNWr58uZYuXTrm4xxzzDFatmyZJGnOnDk69thjNXHiRM2ZM0fd3d2SpP379+umm25SV1eXamtrtWnTpoKP9fTTT+u1114b7n/etWuXNm/erBdeeEFXXXWVamtrdfLJJ+vCCy886teN6sSSqSm4887Dh3uMlPBXmDC7PsrQyebuuvbaa9XV1aWuri69+eabuueee/Tqq69q8eLF+t73vqfrrrtuzMeZOHHi8O7fNTU1OvbYY4d/P3DggCRpzZo1OvHEE/Xqq6+qo6ND+/btK1qm7373u8Nlevvtt4f/WLDDeOVjLHPOjdztQZJqa6PLMuz6ECuozexvzew3Zva6mf3IzCYlVoJCytDJdtFFF2ndunXD/b3vv/++enp6dPDgQV155ZVavXq1XnnlFUnS1KlT9eGHHx71c+3atUsnnXSSampq9MMf/lADAwMFH/fiiy/Wfffdp/3790uSNm3apN27d2vRokV69NFHNTAwoPfee0/PPffcUZcF2Sh0muXqq6Xp0wnsXGlulrq7ozf5wIHosrs78a8zY3Z9mNkpkv5a0mx332Nm/ybpy5IeTrQkI5VhndPZs2frm9/8ppYuXaqDBw9q4sSJ+va3v60rrrhieKPbu+66S1I0HO6GG27Qcccdp/Xr15f8XCtXrtSVV16pxx57TBdccMHwhgVz587VhAkTNG/ePK1YsUK33HKLuru7tWDBArm76uvr9eSTT+qKK67Qs88+qzlz5uj000/X+eeff9SvG9koNiCgr4/JJyjdmMucDgb1y5LmSfq9pCclrXX3p4v9G5Y5TQ/1GqaamqhxVcysWVHDa6T29vKtwY7wjWuZU3ffLumfJW2V9J6kXYVC2sxazKzDzDp6e3vHW2agoo11OmX0efEyjUhFTowZ1GZ2gqTLJJ0q6WRJk83s6tHHuXuruze5e1N9fX3yJQUqSKHTLCONDnKmfSckp2dw45xMXCLpbXfvdff9kn4s6TNH82RjdbOgNNRnuIYGBBTawL7QeXGmfScgx19L4gT1VkkLzazOojFjF0naWOoTTZo0SX19fYRLQtxdfX19mjSpvANw0pDTRpCam6WdO6W2tqhP2qz4yC2mfScgx19LYu2ZaGb/KOnPJB2Q9CtJ17n7/xU7vtDJxP3792vbtm3au3fv+EqMYZMmTdKMGTM0ceLErIty1EYvESlFLc6Eh6EGj3pIQLEzuGbS4MiukB3pZGJqm9sChTQ0RN9QRys0KiLvGPUxThX+YWJzWwSLvtlDhuZOHDxYljkT+Zfj1agIamSKvlkkZuSU7iOdEKhABDUyleNGELKQ068lBDUyleNGEJCYMJc5RVVhSU7gyGhRA0DgCGoAycrrDKYM0fUBIDmjZ+4MTeOW6N8aB1rUAJKT42ncWSKoASSHGUxlQVADSA4zmMqCoAaQHGYwlQVBDSA5zGAqC0Z9AEgWM5gSR4saFYmhuuHivUkeQY2KUwk7LlVrWFXCe1OJ2DgAFSf09eGrebeW0N+bkLHDC3Il9B2XqjmsQn9vQsYOL8iV0IfqBjHnI6O+l9Dfm0pFUKPihD5UN/OwyrCjOPT3plIR1Kg4oQ/VzTysMlxvI/T3plLRRw2UQaY7itNRXJGO1EfNhBegDDKd8zFzZuGzmXQUVyy6PoC8ybzvBUkjqIG8oaM4dwhqIBQrV0oTJkThOmFCdP1oNTdHg7YPHowuCemKRh81EIKVK6X77jt0fWDg0PV7782mTAgGLWogBK2tpd2OqkJQAyEYGCjtdlQVghool6Fp3EN9zmbFp3PX1hZ+jGK3o6oQ1EA5rFwpfeUrh8YzD7WMi03nbmkp/DjFbkdVIaiBpLW3S/ff
X3h2oFR4Ove990o33nioBV1bG13nRCLEFHIgecXWOR2J6dwYhWVOgaTEWT40znqmTOdGCQhqIK64y4eOFcJM50aJgg7qat13LmRV/Z7EXT600FobQ5jOjaPh7kf8kXSGpK4RP7+X9DdH+jeNjY0+Xm1t7nV17lHTJfqpq4tuRzaq5j1pa3OfNcvdLLoceoFmh7/4oR+z+I8BFCGpw4tkakknE82sVtJ2SX/i7kXPliRxMrGa950LVVW8J0famXbVqiqoAGQlyZOJF0n63ZFCOilB7DuHw1TFe3Kk7g2WD0VGSg3qL0v6UaE7zKzFzDrMrKO3t3fcBct83zl8TFW8J0f6a8TyochI7KA2s2MkfVHSY4Xud/dWd29y96b6+vpxF4zGS3iq4j0Z668Ry4ciA6W0qC+R9Iq7/2+5CjMSjZfwVMV7UhV/jVBpYp9MNLNHJf3M3X8w1rHMTERFy3RnWlSrI51MjBXUZlYn6R1Jp7n7rrGOJ6gBoDTjHvXh7v3uPi1OSAOpqOqZNwhFWh9DtuJC5Rk91nloKrdEFwVSk+bHkNXzUHmqYuYNQpf0x5DV85AvVTHzBqFL82NIUKPyVMXMG4QuzY8hQY3slXpGhrHOCECaH0OCGtmKu8bzSFUx8wahS/NjyMlEZIsTg4AkTiYiZJwYRILyOryeoEa2ODGIhBxNL1qlIKiRLU4MIiFxd0qrRAQ1ssWJQSQkz71oBDWOTpKdgazxjATkuReNoEZp2tul6dOlq6/OZ2cgKlaee9EIasQ3dLamr+/j9+WlMxAVK8+9aIyjRnzFxjwPMYu6LwCUjHHUSMZYZ2Xy0BkIBIigRnxHCuK8dAYCASKoEV+hszWSNG1afjoDgQAR1Iiv0NmatjZp505CGigjtuJCaZqbCWUgZbSoASBwBHWe5HXpsIBQxcgCXR95wc7cZUcVIytMeMkLFuAvO6oY5cSEl2qQ56XDAkEVIysEdV7keemwQFDFyApBnaFET0zleemwQFDFyApBnZHEtw3K89JhgaCKkRVOJmaEE1MARuJkYoC2bpWuUrveVoMGVKO31aCr1M6JKQAfwzjqjNz0qXbd1deiyYoG5TaoR99Xi6Z/SpL4Lg3gEFrUGfmWVg2H9JDJ6te3xC4pAA5HUGdkyvuF+ziK3Q6gehHUWWFQLoCYCOqsMCgXQEwEdVYYlAsgplijPszsk5IelHS2JJf0F+6+vozlqg4swg8ghrjD8+6R9FN3/5KZHSOpwMZ5AIByGDOozewTkhZJWiFJ7r5P0r7yFgsAMCROH/Vpknol/cDMfmVmD5rZ5NEHmVmLmXWYWUdvb2/iBQWAahUnqCdIWiDpPnc/R9JuSbeNPsjdW929yd2b6uvrEy4mAFSvOEG9TdI2d//F4PV1ioIbAJCCMYPa3f9H0jtmdsbgTRdJeqOspQIADIs76uNmSe2DIz62SPpq+YoEABgpVlC7e5ekguukAgDKi5mJABA4ghoAAkdQA0DgCGoACFzYQd3eHu0CW1MTXR71Ft1AfHzsEJpw90xsb5daWqT+we2qenqi6xIrzqFs+NghROG0qEc3Y2655dD/liH9/dIq9hRE+axaxccO4QmjRV2oGVPMVvYURPkU+3jxsUOWwmhRF2rGFMOegigjtrJEiMII6rjNFfYURJmxlSVCFEZQF2uuTJvGnoIlYLTC+LGVJUJk7p74gzY1NXlHR0f8fzC6j1qKmjH8D4mNKgQqm5l1unvBNZXCaFHTjBk3RisA+RVGixrjVlMjFXorzaSDB9MvD4DShN+ixrgxWgHIL4I6JxitAOQXQZ0TdPMD+RXGzEQkormZYAbyiBY1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAEjqDGUWGTAiA9TCFHyQrtRdzSEv3OFHYgebSoUTI2KQDSRVCjZMX2Io67RzGA0hDUKBmbFADpIqhRMjYpANJFUKNkbFIApItRHzgqbFIApIcWNQAE
LlaL2sy6JX0oaUDSgWJbmgMAkldK18cF7r6zbCUBABRE1wcABC5uULukp82s08xaCh1gZi1m1mFmHb29vcmVEACqXNyg/qy7L5B0iaS/MrNFow9w91Z3b3L3pvr6+kQLCQDVLFZQu/u7g5c7JD0h6bxyFgoAcMiYQW1mk81s6tDvkpZKer3cBQMAROKM+jhR0hNmNnT8I+7+07KWCgAwbMygdvctkualUBYAQAEMzwOAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBI6gBIHAENQAEjqAGgMAR1AAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABI6gBoDAEdQAEDiCGgACR1ADQOAIagAIHEENAIEjqAEgcAQ1AASOoAaAwBHUABA4ghoAAkdQA0DgCGoACBxBDQCBCyao29ulhgappia6bG/PukQAEIYJWRdAikK5pUXq74+u9/RE1yWpuTm7cgFACIJoUa9adSikh/T3R7cDQLULIqi3bi3tdgCoJkEE9cyZpd0OANUkiKC+806pru7w2+rqotsBoNrFDmozqzWzX5nZfyRdiOZmqbVVmjVLMosuW1s5kQgAUmmjPm6RtFHSJ8pRkOZmghkAConVojazGZKWS3qwvMUBAIwWt+vjO5L+TtLBYgeYWYuZdZhZR29vbxJlAwAoRlCb2aWSdrh755GOc/dWd29y96b6+vrECggA1S5Oi/qzkr5oZt2SHpV0oZm1lbVUAIBhYwa1u9/u7jPcvUHSlyU96+5Xl71kAABJZVrro7Ozc6eZ9ZTwT6ZL2lmOslQY6iFCPVAHQ6qpHmYVu8PcPc2CFC6EWYe7N2VdjqxRDxHqgToYQj1EgpiZCAAojqAGgMCFEtStWRcgENRDhHqgDoZQDwqkjxoAUFwoLWoAQBEENQAELrWgNrNlZvammb1lZrcVuN/MbO3g/a+Z2YK0ypamGPXQPPj6XzOzl8xsXhblLLex6mHEceea2YCZfSnN8qUlTj2Y2WIz6zKz35jZf6ddxjTE+H9xvJn9u5m9OlgPX82inJlx97L/SKqV9DtJp0k6RtKrkmaPOuYLkv5TkklaKOkXaZQtzZ+Y9fAZSScM/n5JtdbDiOOelfSUpC9lXe6MPg+flPSGpJmD1/8g63JnVA9/L+nuwd/rJb0v6Zisy57WT1ot6vMkveXuW9x9n6I1Qy4bdcxlkv7VIy9L+qSZnZRS+dIyZj24+0vu/sHg1ZclzUi5jGmI83mQpJslPS5pR5qFS1GcevhzST92962S5O55rIs49eCSppqZSZqiKKgPpFvM7KQV1KdIemfE9W2Dt5V6TKUr9TX+paJvGXkzZj2Y2SmSrpB0f4rlSlucz8Ppkk4ws+fNrNPMrkmtdOmJUw//IulMSe9K+rWkW9y96LLLeVOWtT4KsAK3jR4XGOeYShf7NZrZBYqC+k/LWqJsxKmH70i61d0HokZULsWphwmSGiVdJOk4SevN7GV331TuwqUoTj1cLKlL0oWS/kjSf5nZi+7++zKXLQhpBfU2SX844voMRX8ZSz2m0sV6jWY2V9FuOpe4e19KZUtTnHpokvToYEhPl/QFMzvg7k+mUsJ0xP1/sdPdd0vabWYvSJonKU9BHacevirpnzzqpH7LzN6W9MeSNqRTxIyldLJggqQtkk7VoZMFZ406ZrkOP5m4IesO/IzqYaaktyR9JuvyZlkPo45/WPk8mRjn83CmpGcGj62T9Lqks7Muewb1cJ+kOwZ/P1HSdknTsy57Wj+ptKjd/YCZ3STpZ4rO8D7k7r8xsxsG779f0Zn9LygKqX5Ff0FzJWY9/IOkaZLuHWxNHvCcrR4Wsx5yL049uPtGM/uppNcUbYX3oLu/nl2pkxfz87Ba0sNm9mtFjblb3b1alj9lCjkAhI6ZiQAQOIIaAAJHUANA4AhqAAgcQQ0AgSOoASBwBDUABO7/ARH9usRSKFMjAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "y_ = linear_model(x_train)\n",
- "plt.plot(x_train.data.numpy(), y_train.data.numpy(), 'bo', label='real')\n",
- "plt.plot(x_train.data.numpy(), y_.data.numpy(), 'ro', label='estimated')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "After 100 updates, the red predictions fit the blue true values reasonably well."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 4. Polynomial Regression"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next we go a step further and try polynomial regression. Consider the following polynomial in $x$:\n",
- "\n",
- "$$\n",
- "\\hat{y} = w_0 + w_1 x + w_2 x^2 + w_3 x^3 \n",
- "$$\n",
- "\n",
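- "Using higher powers of $x$ lets the model fit more complex relationships. Multiple (multivariate) regression has the same form, except that besides $x$ it uses additional variables such as $y$, $z$, and so on. In both cases the loss function is the same as for simple linear regression."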
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "First we define the target function to fit, which is a cubic polynomial."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3\n"
- ]
- }
- ],
- "source": [
- "# Define the target polynomial\n",
- "\n",
- "w_target = np.array([0.5, 3, 2.4])  # target weights\n",
- "b_target = np.array([0.9])  # target bias\n",
- "\n",
- "f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(\n",
- "    b_target[0], w_target[0], w_target[1], w_target[2])  # human-readable form of the function\n",
- "\n",
- "print(f_des)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Plot the curve of the polynomial."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac4f25b0>"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the curve of this function\n",
- "x_sample = np.arange(-3, 3.1, 0.1)\n",
- "y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3\n",
- "\n",
- "plt.plot(x_sample, y_sample, label='real curve')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next we build the dataset, which needs inputs and targets. Because the model is a cubic polynomial, the input features are $x,\\ x^2,\\ x^3$."
- ]
- },
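- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "One way to picture this dataset is in matrix form. The symbols $X$, $\\mathbf{w}$ and the sample count $m$ are introduced here only for illustration and do not appear elsewhere in the notebook. Stacking the powers of each sample $x_j$ as a row turns the cubic polynomial into a linear model in three features:\n",
- "\n",
- "$$\n",
- "X = \\begin{pmatrix} x_1 & x_1^2 & x_1^3 \\\\ \\vdots & \\vdots & \\vdots \\\\ x_m & x_m^2 & x_m^3 \\end{pmatrix}, \\qquad\n",
- "\\hat{\\mathbf{y}} = X \\mathbf{w} + b, \\qquad \\mathbf{w} = (w_1, w_2, w_3)^\\top\n",
- "$$\n",
- "\n",
- "Here $b$ plays the role of $w_0$, so the same gradient descent procedure as before applies, only with a matrix product in place of the scalar product."
- ]
- },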
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Build the data x and y\n",
- "# x is a matrix with columns [x, x^2, x^3]\n",
- "# y holds the function values [y]\n",
- "\n",
- "x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)\n",
- "x_train = torch.from_numpy(x_train).float()  # convert to a float tensor\n",
- "\n",
- "y_train = torch.from_numpy(y_sample).float().unsqueeze(1)  # convert to a float tensor"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "torch.Size([61, 3])\n"
- ]
- }
- ],
- "source": [
- "print(x_train.size())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next, define the parameters to optimize: the weights $w_i$ of the function above, together with the bias $b$."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Define the parameters\n",
- "w = torch.randn((3, 1), dtype=torch.float, requires_grad=True)\n",
- "b = torch.zeros((1), dtype=torch.float, requires_grad=True)\n",
- "\n",
- "# Define the model\n",
- "def multi_linear(x):\n",
- " return torch.mm(x, w) + b\n",
- "\n",
- "def get_loss(y_, y):\n",
- " return torch.mean((y_ - y) ** 2)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can plot the model before any update against the true model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac41e220>"
- ]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the model before any update\n",
- "y_pred = multi_linear(x_train)\n",
- "\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The two curves clearly differ. Let's compute the error between them."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(1144.2654, grad_fn=<MeanBackward0>)\n"
- ]
- }
- ],
- "source": [
- "# Compute the error; it is the same mean squared error used for the univariate linear model, via get_loss defined above\n",
- "loss = get_loss(y_pred, y_train)\n",
- "print(loss)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Compute the gradients automatically via backpropagation\n",
- "loss.backward()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([[ -94.7455],\n",
- " [-139.1247],\n",
- " [-629.8584]])\n",
- "tensor([-25.7413])\n"
- ]
- }
- ],
- "source": [
- "# Inspect the gradients of w and b\n",
- "print(w.grad)\n",
- "print(b.grad)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Update the parameters with one gradient-descent step (learning rate 0.001)\n",
- "w.data = w.data - 0.001 * w.grad.data\n",
- "b.data = b.data - 0.001 * b.grad.data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac392850>"
- ]
- },
- "execution_count": 25,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the model after one update\n",
- "y_pred = multi_linear(x_train)\n",
- "\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Because we have updated only once, the two curves still differ. Next, run 100 iterations."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "epoch 20, Loss: 65.56586\n",
- "epoch 40, Loss: 15.41177\n",
- "epoch 60, Loss: 3.70702\n",
- "epoch 80, Loss: 0.97122\n",
- "epoch 100, Loss: 0.32874\n"
- ]
- }
- ],
- "source": [
- "# Run 100 parameter updates\n",
- "for e in range(100):\n",
- "    y_pred = multi_linear(x_train)\n",
- "    loss = get_loss(y_pred, y_train)\n",
- "\n",
- "    # Reset the gradients accumulated by the previous backward pass\n",
- "    w.grad.data.zero_()\n",
- "    b.grad.data.zero_()\n",
- "    loss.backward()\n",
- "\n",
- "    # Update the parameters\n",
- "    w.data = w.data - 0.001 * w.grad.data\n",
- "    b.data = b.data - 0.001 * b.grad.data\n",
- " if (e + 1) % 20 == 0:\n",
- " print('epoch {}, Loss: {:.5f}'.format(e+1, loss.data.item()))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
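- "After the updates, the loss has dropped to about 0.33, far below the initial value of about 1144. Let's plot the updated curve for comparison."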
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7fb7ac305e80>"
- ]
- },
- "execution_count": 27,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the result after the updates\n",
- "y_pred = multi_linear(x_train)\n",
- "\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve', color='r')\n",
- "plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')\n",
- "plt.legend()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "After 100 updates, the fitted curve and the true curve almost coincide."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": true
- },
- "source": [
- "## Exercise\n",
- "\n",
- "* The example above fits a cubic polynomial. Try fitting it with a quadratic polynomial and see how good the fit can get.\n"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.4"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|