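{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic Regression Model" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In the previous lesson we studied simple linear regression. In this lesson we introduce our second model, logistic regression.\n", "\n", "Logistic regression is a generalized linear model. It has much in common with multiple linear regression, and the form of the model is essentially the same. Despite its name, it is mostly used for classification, most often binary classification." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model form\n", "Logistic regression has the same form as linear regression, y = wx + b, where x can be a multi-dimensional feature vector. The only difference is that logistic regression applies the logistic function to y, which turns it into a probability. The logistic function, also called the sigmoid function, is the core of logistic regression, so we look at it first." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### The sigmoid function\n", "The sigmoid function is very simple. Its formula is\n", "\n", "$$\n", "f(x) = \\frac{1}{1 + e^{-x}}\n", "$$\n", "\n", "Its graph looks like this\n", "\n", "![](https://ws2.sinaimg.cn/large/006tKfTcly1fmd3dde091g30du060mx0.gif)\n", "\n", "The sigmoid function takes values strictly between 0 and 1, so it maps any input to a number in the interval (0, 1). You can read this number as a probability. In a binary classification problem, for example, a smaller value indicates the first class and a larger value indicates the second class." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Keep in mind that logistic regression learns a linear decision boundary. It therefore works well when the data are (nearly) linearly separable, that is, when a hyperplane in the feature space can split the two classes. For example\n", "\n", "![](https://ws1.sinaimg.cn/large/006tKfTcly1fmd3gwdueoj30aw0aewex.jpg)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the green plane separates almost all of the red points from the blue points above." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Regression vs. classification\n", "Logistic regression solves a classification problem, whereas the previous model solved a regression problem. So what is the difference between the two?\n", "\n", "As the figure above shows, a classification problem assigns each data point to a class. In a 3-class problem, for example, we want to find out which class each data point belongs to. The result can only be one of {0, 1, 2}, so the output is discrete.\n", "\n", "A regression problem has a continuous output instead. In curve fitting, for example, the fitted function can take any real value.\n", "\n", "Telling classification and regression apart is the first step in machine learning and deep learning. For any new problem, we first decide which of the two it is and only then design the algorithm." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Loss function\n", "In the previous lesson we used a loss to measure the error of a regression model. How do we measure the error of a classification model and design a loss function for it?\n", "\n", "Logistic regression uses the sigmoid function to map its output into (0, 1). For any input data point, we denote the sigmoid output by $\\hat{y}$ and interpret it as the probability that the point belongs to the second class. The probability that it belongs to the first class is then $1-\\hat{y}$. If the point belongs to the second class, we want $\\hat{y}$ to be as large as possible, close to 1. If it belongs to the first class, we want $1-\\hat{y}$ to be as large as possible, which means $\\hat{y}$ should be as small as possible, close to 0. This leads to the following loss function\n", "\n", "$$\n", "\\mathrm{loss} = -\\left(y \\log(\\hat{y}) + (1 - y) \\log(1 - \\hat{y})\\right)\n", "$$\n", "\n", "Here y is the true label and can only be 0 or 1, while $\\hat{y}$ is the prediction of logistic regression, a number between 0 and 1. If y is 0, the point belongs to the first class and we want $\\hat{y}$ to be as small as possible. In that case the loss above becomes\n", "\n", "$$\n", "\\mathrm{loss} = -\\log(1 - \\hat{y})\n", "$$\n", "\n", "During training we minimize the loss. Because log is monotonically increasing, minimizing this loss means maximizing $1 - \\hat{y}$, that is, minimizing $\\hat{y}$, which is exactly what we want.\n", "\n", "If y is 1, the point belongs to the second class and we want $\\hat{y}$ to be as large as possible. The loss then becomes\n", "\n", "$$\n", "\\mathrm{loss} = -\\log(\\hat{y})\n", "$$\n", "\n", "Minimizing this loss means maximizing $\\hat{y}$, which again matches our goal.\n", "\n", "Together, these two cases show that this choice of loss function is reasonable." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's work through a concrete example of logistic regression." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.autograd import Variable # deprecated since PyTorch 0.4: Variable(t) now simply returns a Tensor\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Set the random seed\n", "torch.manual_seed(2017)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We read the data from data.txt. If you are curious, open data.txt to take a look.\n", "\n", "After reading the data points, we color them red or blue according to their label and plot them." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 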
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAH7RJREFUeJzt3X2MXNWZ5/HvY2Nj9Y4TwO6RIrdd3Vk5GwxBm7jNZCJlE4VlcZisvUlGI6BhgyYTh2xIRtpMFCIIioxayaxWMwpadiXnZWLcHRDLHyuPhsWMBlC0Uci6Ea82gRgP2O2MlE6b7IYELy959o9bDdXlerlddV/Ouff3kUquun1cdereW8899znnnmvujoiIVMuqsisgIiLZU3AXEakgBXcRkQpScBcRqSAFdxGRClJwFxGpIAV3EZEKUnAXEakgBXcRkQo6p6wP3rhxo4+Pj5f18SIiUXr00Ud/6e6j/cqVFtzHx8eZm5sr6+NFRKJkZi+mKae0jIhIBSm4i4hUkIK7iEgFKbiLiFSQgruISAX1De5m9j0z+4WZPd3l72Zmt5vZMTN70szel301MzQ7C+PjsGpV8u/sbNk1EhHJXJqW+/eBnT3+/lFga/OxB/hvw1crJ7OzsGcPvPgiuCf/7tmjAC8ildM3uLv7D4HTPYrsBu70xCPAeWb2jqwqmKmbb4bf/nb5st/+NlkuIlIhWeTcNwEnW17PN5edxcz2mNmcmc0tLCxk8NErdOLEypaLiESq0A5Vd9/n7pPuPjk62vfq2ext2bKy5RlRmr+etN2lTFkE91PA5pbXY81l4ZmehpGR5ctGRpLlOVGav5603aVsWQT3g8C/b46aeT/wf9z9nzJ43+xNTcG+fdBogFny7759yfKcKM1fnjJbztruUrY0QyHvAn4M/AszmzezT5vZDWZ2Q7PIfcBx4BjwbeA/5FbbLExNwQsvwO9+l/ybY2AHpfmL1BrMN26EP/3T8lrOMWx3pY0qzt1LeWzfvt3roNFwT8LL8kejUXbNqmVmxn1kpPO6LmO9D7rdZ2aSMmbJvzMz+dSv0/oaGcnv8yQ7wJyniLG6QjVnJaT5a6lTGqSTolrOg2z3IvP0ShtVX3WDeyDnnCWk+WspbdDOeWDUmwbZ7kUG3BjSRjIcS1r5xZucnPTcbtax1ARq/aWMjCiqVtj4eNLS7SX0XWDVqqTF3s4s6SLKUrf11WgkXVESLjN71N0n+5WrZstd55y10ykNsmYNbNgQzxlTkZdhKF1YfdUM7jrnrJ1OaZC/+Rv45S8LGxg1tCIDrtKF1VfNtIzOOSVSs7PJCeaJE0mLfXpaAVeWq3daRuecEqmCL8OQCqtmcNc5p4jUXDWDO6gJJEEKZISu1MA5ZVdApC7aR+guXaQEantI9qrbchcJjEboSpEU3EUKEtMIXaWP4qfgLlKQku4Vs2Kai74aFNxFChLLCF2lj6pBwV2kILGM0O2WJuo3d4+ERcFdpEAxjNDtliYyU2omJgruIpHLuvNzejoJ5O3clZqJiYK7SMTy6Pycmuo89TCEObJHOlNwF4nQUmv92mvz6fxsNDovD21kj3SXKrib2U4ze9bMjpnZTR3+3jCzfzCzJ83sYTMby76qEgONj85fa2u9m2Fb2LGM7JHu+gZ3M1sN3AF8FNgGXG1m29qK/WfgTne/BNgLfCPrikr4ND66GGnuFztsCzuWkT3SXZqW+6XAMXc/7u6vAncDu9vKbAMebD5/qMPfpWBltKA1ProY/VrlWbWwYxjZ00+dzyTTBPdNwMmW1/PNZa2eAD7RfP5xYL2ZbRi+ejKIslrQMV1eH7NerXK1sN9S9zPJrDpU/wL4kJk9BnwIOAW80V7IzPaY2ZyZzS0sLGT00dKurBZ0LJfXx65bPnxmJt4Wdh7qfiaZJrifAja3vB5rLnuTu//c3T/h7u8Fbm4u+1X7G7n7PnefdPfJ0dHRIaotvZTVglYnXDF
65cPrnIZo/+7dOpxrcybp7j0fJHO+HwcmgLUkKZiL2spsBFY1n08De/u97/bt210SMzPujYa7WfLvzMxw79douCcnossfjcbwde0n6+8i6c3MuI+MLN/mIyP12AadvrtZeb+DPAFz3ie+evJVUxSCK4HngOeBm5vL9gK7ms//GPhZs8x3gHP7vaeCeyKPH2Sdf+R1VuZBvWzdvnt7gM/zd1BUwybT4J7HQ8E9kdcPUi3o+unWUjUru2b56/bdl35Lef8OimxQpQ3ulpQt3uTkpM/NzZXy2SFZtarzpd5myRA06W52NukcO3Ei6bSdnq53Z2K3PHOjkXS0VlnZ373IzzezR919sl85TT9QMo0wGUzdh7m1WupIfPHFsyf8qkuHdtmd+SEOA1ZwL1nZO2WsQhzmVsZIlfapCNzfCvB1GvNe9hW1QTbS0uRu8ngo5/4W5cdXLrT8clmd2HXuRA2Jcu4tlHOXYZSdY21XZH1a+xq6/XzVZ1O8ovqAlHOXSgstnVVUzrW9r6Eb9dkUL7S5eBTcJUpl51jbFZVzTTMjpPpsBBTcJWIhtZSKOpPodSYQwkFOwnFO2RUQqYKlYJp3znXLlrD6GiRcarmLZKSIM4nQ+hokXAruIhEJra9BwqW0jEhkpqYUzKU/tdxFRCpIwV1EpIIU3EVEKkjBXWqtzrelk2pTh6rU1tKl/EtXfC5NGwzqsJT4qeUutRXitMEiWVFwl9oK8QYLIllRcJfaCvIGCyIZSRXczWynmT1rZsfM7KYOf99iZg+Z2WNm9qSZXZl9VUWypUv5pcr6BnczWw3cAXwU2AZcbWbb2ordAtzj7u8FrgL+a9YVrSwN1yiNLuWXKkvTcr8UOObux939VeBuYHdbGQfe1nz+duDn2VWxwnSX59KFMG2wju+ShzTBfRNwsuX1fHNZq68D15rZPHAf8IVMald1Gq5Rezq+S16y6lC9Gvi+u48BVwIHzOys9zazPWY2Z2ZzCwsLGX30EMpuMmm4RunK3gV0fJe8pAnup4DNLa/HmstafRq4B8DdfwysAza2v5G773P3SXefHB0dHazGWQmhyaThGqUKYRfQ8V3ykia4Hwa2mtmEma0l6TA92FbmBHAZgJldSBLcA2ia9xBCk0nDNUoVwi6g43s4yj6Ly1rf4O7urwM3AoeAZ0hGxRwxs71mtqtZ7EvAZ8zsCeAu4Hr3XvdmD0AITSYN1yhVCLuAju9hCOEsLnPuXspj+/btXqpGwz3ZjssfjUa59ZLChLILzMwkn2mW/DszU+znDyrWencSyr6QBjDnKWJsfa9QzbPJVLXzu4oKpdUcwnDMlapaSzeEs7jMpTkC5PEoveXunk/TY2bGfWRk+eF/ZCSaZk2IrbE86xTi941BTC3dNGL6PqRsudc7uOchpr2kTYjHpRDrVDWDHODMOu/mZnnXNh8x7WcK7r3k2VyLeK8P8bgUYp2qZNCgVsXtEstZXNrgbknZ4k1OTvrc3FzxH9x+hwZIEq1ZjVIZH08SkO0ajSShGrBVq5KfaDuzJB9chhDrVCWD7q55/4ykOzN71N0n+5WrX4dq3oObQ+mlG0CIY65DrFOVDNqRqFG84atfcM+7WzzivT7E41KIdaqSbgfJVav6D/aKcZRPraTJ3eTxKC3nXsVkYYZCzDuGWKeq6JRzb3+E2rFYVyjn3oWShSLLzM4mWckTJ5LW+htvnF0mgi6j2lDOvZs0aRNdhBQNbarhtaZXunVSR30xT03Vr+Xej1r20ei0qdauhfXr4fTpJJ88Pa3NthIRD/aqDbXcBxXCVIGSSqdN9eqrsLhYjUviy6AO7HwVeaap4N6ukpNMVFOaTaLj8spEPNgreEXPx6O0TDudl0aj26ZqpwueJARZhRalZQaV9Xmpevxy02lTdaILniQERScFFNzbZXleWrV5UQPTvqk2bIA1a5aXUb5YQlH01dZKy+RJKZ7CtY7Z1mgZCUlWA/GUlgmBOmcLp0viz6bMYBi
K7qxWcM+TZr0qXdUC20q/T9Uyg7Fvz0IbH2nmKAB2As8Cx4CbOvz9r4HHm4/ngF/1e8/K3qyjVUx3AEghtjleKrb6B/o+VZpKqWrbc1BkdbMOYDXwPPBOYC3wBLCtR/kvAN/r9761CO7u8UXELmL8YVUpsLkP9n0ivnfMWaq2PQeVNrinSctcChxz9+Pu/ipwN7C7R/mrgbtWdPpQZRVJAsd44W7VujwG+T7dMoDu8aU1Qt6evdJFZaWSzklRZhNwsuX1PPAHnQqaWQOYAB4cvmoSkpB/WN1s2dJ5sFKsXR6DfJ/p6bNHaCxZyr9DHG2OULdn+yiY1vUK3f+W9zrPukP1KuBed+8waSiY2R4zmzOzuYWFhYw/WvIUY99w1eZJGeT7tI7Q6CT0s69WoW7PXme1pZ7x9svbAH8IHGp5/VXgq13KPgZ8IE0+qDY594qIMefuXpkujzcN832qkH8PcXv2Wq95rHOyulmHmZ1DMgLmMuAUcBi4xt2PtJV7N3A/MOH93pSaXMRUMbpAKG66pi4fvdYrZL/OM7uIyd1fB24EDgHPAPe4+xEz22tmu1qKXgXcnSawS5wq0jdcW6GmNWLXa72Wus7TNO/zeCgtI1K8ENMaVdBrvWa9ztE9VCtIeZFgaFNIWdKmZdIMhZQQ9BpvpahSKG0KiYHmlolFjFcRVZQ2RRhin2cmbwrusYjxKqKK0qYoX14TolXpgKHgnoUi9ogYryKqKG2K8uVx9lS1GTQV3IdV1B4R+Ti2KrWIIt8UlZDH2VPl0m1phtTk8ajMUMgip6qLdBxbrFe39hLppqiMPH52sVzBi4ZCFmTVqmQfaGeWXO0jujJSMpfVLetaxbKf6jZ7RVECti91QErW8rhlXdXSbQruw6raHpEDHf8kD1lPh1H0PU7zpuA+rKrtEb0M2Cuq45/EokrzJ+kK1SxMTcW9F6QxxGWZS3/W5foixVGHqqQTS2+TSMWpQ1WypV5RkagouEs66hUViYqCu6SjXlGRqCi4Szp1GhXURZWmUJDilLXfKLhLekvjxA4cSF5fd11tolzVJpWSYpS539Q7uKsptnI1jXKVm1RKClHmflPf4F7TIHWWlR7gahrlNFhIBlHmfpMquJvZTjN71syOmdlNXcr8iZkdNbMjZvaDbKuZg5oGqWUGOcDltLeGfhKlwUIyiFL3m37TRgKrgeeBdwJrgSeAbW1ltgKPAec3X/9+v/ctfcrfWOb3zNMg86bmMNdqDFMCx1BHCU8e+w0pp/xN03K/FDjm7sfd/VXgbmB3W5nPAHe4+0vNA8Yvhj3o5E5NscFa4TkMiYzhJEqDhWQQZe43aYL7JuBky+v55rJW7wLeZWY/MrNHzGxnVhXMjcZtD3aAy2FvjSWfXaVJpaQ4Ze03WXWonkOSmvkwcDXwbTM7r72Qme0xszkzm1tYWMjoowekptjgB7iM91adRIlkL01wPwVsbnk91lzWah446O6vufs/As+RBPtl3H2fu0+6++To6Oigdc5O3ZtigRzgdBIlkr00wf0wsNXMJsxsLXAVcLCtzP8gabVjZhtJ0jTHM6yn5CWAA1wgxxiRSukb3N39deBG4BDwDHCPux8xs71mtqtZ7BCwaGZHgYeAL7v7Yl6VluoJ4BgjBQl92GtVaD53ESlMHje2rhvN515lavpI4LrtojEMe60K3WYvNkPc7k6kCL120ViGvVaB0jKx0e3uJHC9dlHQ7jsspWWqSk0fCVyvXVTDXouj4B4bXfEjgeu1i2rYa3EU3GOjpo8Ert8uqmGvxVBwj42aPhI47aJhUIeqiEhE1KEqIlJjCu4iIhWk4C4iUkEK7iIZ0+wQEgIFd8lPDaPcIPccF8mDgntd5R14Q4lyBR9gNDGWhEJDIeuoiHlXQ5gDp4T5ZVetSo5l7cySi3ZEhpV2KKSCex0VEXhDiHIlHGBCOKZJtWmcu3RXxORjIcyBU8Ika5odQkKh4F5HRQTeEKJcCQcYXXovoVB
wr6MiAm/ZUW52Fl5++ezlBRxgNDGWhEDBvY56Bd4sR5eUFeWWOlIX2+7RvmGDmtFSG6mCu5ntNLNnzeyYmd3U4e/Xm9mCmT3efPxZ9lWlluOmc9Mp8IYyfHFYncYjAvze7ymwS230HS1jZquB54DLgXngMHC1ux9tKXM9MOnuN6b94BWPltFt0/NXlaEeIYzUEclJlqNlLgWOuftxd38VuBvYPWwFV0xXh+Sv2yiSTgE/ZCGM1BEpWZrgvgk42fJ6vrms3SfN7Ekzu9fMNnd6IzPbY2ZzZja3sLCwsprq3qH56xb8zOJKzYQwUqemlDkNR1Ydqn8LjLv7JcDfA/s7FXL3fe4+6e6To6OjK/sEtcbyNz2dBPJ27nGdIZU9UqemqtJlUxVpcu5/CHzd3a9ovv4qgLt/o0v51cBpd397r/dVzj1QnYL70nLlq6WHqnTZhC7LnPthYKuZTZjZWuAq4GDbh72j5eUu4JmVVDYVtcaK0Wh0Xq4zpCgVmSZR5jQsfYO7u78O3AgcIgna97j7ETPba2a7msW+aGZHzOwJ4IvA9bnUNs9x00oWJpSvroyi0yTKnAbG3Ut5bN++3YMxM+M+MuKe/AaSx8hIsryOZmbcGw13s+Tfz31u+eu6rpdu2tdXIOun0Vi+Sy89Go18Pk8/o2IAc54ixiq4uxf/K4iJfrG9lbh++h1TzDrv1mbl1UmGlza4a8pf0EUvvaiXrLeS1k+a8QXadNWkKX9XQsnCzmZnu1/ApF6yREm9iGmu6VP3yeCq0AWn4A76FXSy1DTspu4HviUlNQzSHFPqMMAsjyBcmfH6aXI3eTyCyrm7K1nYrls/hHLuy5WUc1c3UX6rPvR1izpUZSjdeuNAgb1dCQ2DvI8pMbR18grCZXREr0Ta4K4OVelMvXHBm51NcuwnTiRZoOnpbFIusVwMntc4iNB3/ep2qGaRZKtCb0ne1A8RvLyu6YtlAta8ujsqs+unad7n8RgoLZPFuajGbacXw7m5ZC7LtESeu1CeP+WQd30qmXNfaZKt0xYKvbdEpGRZ/USKaEe1/sQ3bEgeIQbkLFUzuK+kSdFtz+rWSRhKb4nEIeSm3ZCyCspFtqPqdEJezeC+kr2lW9nVq9Vyr5IygmwNIkmq1dqnUJGjTup0Ql7N4L6SH1WvoXwV/2HWRllBtk6RpJsU677I1RT68MUsVTO4u6dvqfXasyp8Sl0rZQXZOkWSblKs+yKPvXU63lY3uKdVg1Pn2isryIYQScpuoKRc90VVs04/dwV39/J/AJKvsoJs2ZGk7M93D+MA16YuP3cFd6m+ooNcKOPuQgisIRxgaiptcI/vClWRJUVOe9g+VeDiIrzyChw4kP0tH/sJ4WalU1PwqU/B6tXJ69Wrk9chzU9Qc5pbRiSNkCYcCaEusUxAU0Fp55ZJFdzNbCfwLWA18B13/2aXcp8E7gV2uHvPyN0puL/22mvMz89z5syZvnWKwbp16xgbG2PNmjVlV0WGFdLdukIIrCEcYGoqbXA/J8UbrQbuAC4H5oHDZnbQ3Y+2lVsP/Dnwk8GqDPPz86xfv57x8XHMbNC3CYK7s7i4yPz8PBMTE2VXR4a1ZUvnYFbGTUuWAngeU0KmFUJqSHpKk3O/FDjm7sfd/VXgbmB3h3K3AX8JDNzsPnPmDBs2bIg+sAOYGRs2bKjMWUjmYpuZM7SpAvOaEjIt3ZoyeGmC+ybgZMvr+eayN5nZ+4DN7v53w1aoCoF9SZW+S6ZivI9ZHe5ZtxKhHexWKrbGxQCGHi1jZquAvwK+lKLsHjObM7O5hYWFYT9aYhXLhOHtym4thyTmg12MjYsBpAnup4DNLa/HmsuWrAcuBh42sxeA9wMHzeyshL+773P3SXefHB0dHbzWS0o8+u7fv5+tW7eydetW9u/fX9jnVoLytdUQ68FukMZFjC39fgPhSTpdjwMTwFrgCeCiHuUfBib7vW+ni5iOHj2afiR/iRdRLC4u+sTEhC8uLvr
p06d9YmLCT58+3bHsir5TXYRwEY7U10qnrQjsgi2yuojJ3V8HbgQOAc8A97j7ETPba2a78jjgpJLDqf3hw4e55JJLOHPmDL/5zW+46KKLePrpp88qd+jQIS6//HIuuOACzj//fC6//HLuv//+gT+3dmLP10rcVtoZHGkase9QSAB3vw+4r23ZrV3Kfnj4aqWQw6n9jh072LVrF7fccguvvPIK1157LRdffPFZ5U6dOsXmzW9lqsbGxjh16tRZ5aSLEIbySX1NT3e+TqBb4yLSNGKq4B6knMYd33rrrezYsYN169Zx++23D/Ve0sPUlIK5lGOljYuQrnFYgXjnlsnp1H5xcZGXX36ZX//6113HqG/atImTJ98aHTo/P8+mTZs6lhWRAK2kMzjSNGK8wT2noVif/exnue2225iamuIrX/lKxzJXXHEFDzzwAC+99BIvvfQSDzzwAFdcccVQnysigYp02Ge8aRnI/NT+zjvvZM2aNVxzzTW88cYbfOADH+DBBx/kIx/5yLJyF1xwAV/72tfYsWMHkKRyLrjggszqISKBiTCNGNSskM888wwXXnhhKfXJSxW/00BmZ9WBKpKBzCYOExla+yyGS1cEggK8SE4U3Ht46qmnuO6665YtO/fcc/nJTwae+LKeeo0TVnAXyYWCew/vec97ePzxx8uuRvwiHScsErN4R8tIPDQ9rEjhFNwlf5GOExaJmYK75C/SccIiMVPOXYoR4ThhkZhF3XIvc4rlnTt3ct555/Gxj32suA8VEUkp2uBe9s1UvvzlL3PgwIFiPkxEZIWiDe55TLGcdj53gMsuu4z169cP/mEiIjmKNueex9DptPO5i4iELtrgntcUy5rPXUSqINq0TF5Dp9PM5y4iErpog3teQ6fTzOcuEoQyh4tJ8KJNy0D2Q6fTzucO8MEPfpCf/vSnvPzyy4yNjfHd735XN+yQ4mimTekj1XzuZrYT+BawGviOu3+z7e83AJ8H3gBeBva4+9Fe76n53EWGMD7eudOp0UhuGyeVlXY+975pGTNbDdwBfBTYBlxtZtvaiv3A3d/j7v8S+E/AXw1QZxFJSzNtSh9p0jKXAsfc/TiAmd0N7AbebJm7+/9tKf/PgHJu75QxzecuwcpruJhURprgvgk42fJ6HviD9kJm9nngPwJrgbOT1BHSfO4SrOnp5Tl30Eybskxmo2Xc/Q53/+fAV4BbOpUxsz1mNmdmcwsLC93eJ6sqla5K30UCo5k2pY80wf0UsLnl9VhzWTd3A/+u0x/cfZ+7T7r75Ojo6Fl/X7duHYuLi5UIiu7O4uIi69atK7sqUlVTU0nn6e9+l/yrwC4t0qRlDgNbzWyCJKhfBVzTWsDMtrr7z5ov/wj4GQMYGxtjfn6ebq362Kxbt46xsbGyqyEiNdQ3uLv762Z2I3CIZCjk99z9iJntBebc/SBwo5n9a+A14CXgU4NUZs2aNUxMTAzyX0VEpEWqi5jc/T7gvrZlt7Y8//OM6yUiIkOIdvoBERHpTsFdRKSCUk0/kMsHmy0AHa7CSGUj8MsMq5OnmOoKcdU3prqC6punmOoKw9W34e5nDzdsU1pwH4aZzaWZWyEEMdUV4qpvTHUF1TdPMdUViqmv0jIiIhWk4C4iUkGxBvd9ZVdgBWKqK8RV35jqCqpvnmKqKxRQ3yhz7iIi0lusLXcREekh6OBuZjvN7FkzO2ZmN3X4+w1m9pSZPW5m/6vDTUQK06+uLeU+aWZuZqX27KdYt9eb2UJz3T5uZn9WRj2bdem7bs3sT8zsqJkdMbMfFF3Htrr0W7d/3bJenzOzX5VRz2Zd+tV1i5k9ZGaPmdmTZnZlGfVsqU+/+jbM7B+adX3YzEqb3MnMvmdmvzCzp7v83czs9uZ3edLM3pdpBdw9yAfJPDbPA+8kmSP+CWBbW5m3tTzfBdwfal2b5dYDPwQeASYDX7fXA/8lkv1gK/AYcH7z9e+HXN+28l8gma8pyLqS5IY/13y+DXg
h5HUL/HfgU83nHwEOlFjffwW8D3i6y9+vBP4nYMD7gZ9k+fkht9zfvAOUu79KMpXw7tYCHs4doPrWtek24C+BM0VWroO09Q1Bmrp+BrjD3V8CcPdfFFzHVitdt1cDdxVSs7OlqasDb2s+fzvw8wLr1y5NfbcBDzafP9Th74Vx9x8Cp3sU2Q3c6YlHgPPM7B1ZfX7Iwb3THaA2tRcys8+b2fMk9279YkF1a9e3rs1Trs3u/ndFVqyLVOsW+GTzdPFeM9vc4e9FSFPXdwHvMrMfmdkjzRu6lyXtusXMGsAEbwWjoqWp69eBa81snmTywC8UU7WO0tT3CeATzecfB9ab2YYC6jaI1PvKIEIO7ql4ijtAlc3MVpHcNPxLZddlBf4WGHf3S4C/B/aXXJ9eziFJzXyYpCX8bTM7r9QapXMVcK+7v1F2RXq4Gvi+u4+RpBEONPfnUP0F8CEzewz4EMk9KEJev7kJeSNldgeoAvSr63rgYuBhM3uBJL92sMRO1b7r1t0X3f3/NV9+B9heUN3apdkP5oGD7v6au/8j8BxJsC/DSvbbqygvJQPp6vpp4B4Ad/8xsI5kXpQypNlvf+7un3D39wI3N5eV1mHdx0pj3MqU1dmQojPiHOA4yWnrUufJRW1ltrY8/7ckNw8Jsq5t5R+m3A7VNOv2HS3PPw48EnBddwL7m883kpzqbgi1vs1y7wZeoHmtSah1Jenwu775/EKSnHspdU5Z343AqubzaWBvWeu3WYdxuneo/hHLO1T/d6afXeYXT7FiriRphT0P3NxcthfY1Xz+LeAI8DhJ50nXgFp2XdvKlhrcU67bbzTX7RPNdfvugOtqJGmvo8BTwFUhr9vm668D3yyzninX7TbgR8394HHg3wRe3z8muc3ncyRnnOeWWNe7gH8iuUPdPMlZ0A3ADc2/G3BH87s8lXVM0BWqIiIVFHLOXUREBqTgLiJSQQruIiIVpOAuIlJBCu4iIhWk4C4iUkEK7iIiFaTgLiJSQf8fr3/Tekz1AzcAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Read the points from data.txt: each line holds two features and a label, separated by commas\n", "with open('./data.txt', 'r') as f:\n", "    data_list = [i.split('\\n')[0].split(',') for i in f.readlines()]\n", "    data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]\n", "\n", "# Normalize: divide each feature by its maximum value\n", "x0_max = max([i[0] for i in data])\n", "x1_max = max([i[1] for i in data])\n", "data = [(i[0]/x0_max, i[1]/x1_max, i[2]) for i in data]\n", "\n", "x0 = list(filter(lambda x: x[-1] == 0.0, data)) # points of the first class (label 0)\n", "x1 = list(filter(lambda x: x[-1] == 1.0, data)) # points of the second class (label 1)\n", "\n", "plot_x0 = [i[0] for i in x0]\n", "plot_y0 = [i[1] for i in x0]\n", "plot_x1 = [i[0] for i in x1]\n", "plot_y1 = [i[1] for i in x1]\n", "\n", "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n", "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n", "plt.legend(loc='best')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next we convert the data into a NumPy array and then into Tensors to prepare for training." ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "np_data = np.array(data, dtype='float32') # convert to a NumPy array\n", "x_data = torch.from_numpy(np_data[:, 0:2]) # convert to a Tensor of shape [100, 2]\n", "y_data = torch.from_numpy(np_data[:, -1]).unsqueeze(1) # convert to a Tensor of shape [100, 1]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now let's implement the sigmoid function. Its formula is\n", "\n", "$$\n", "f(x) = \\frac{1}{1 + e^{-x}}\n", "$$" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Define the sigmoid function (applied element-wise to NumPy arrays)\n", "def sigmoid(x):\n", "    return 1 / (1 + np.exp(-x))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "When we plot the sigmoid function, we can see that larger inputs give outputs closer to 1 and smaller inputs give outputs closer to 0." ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAHB9JREFUeJzt3XmYVNWd//H3V1YXIiDIjuAIRiZjXFqNOv7UURSIgsYNonFDSYg4cVxGHR00ap4kkp+JTjSKW9xZ4q9bRHhwHxNXlggqiDauoCwqImKgafj+/jjVWjbVdHV3VZ2qW5/X89ynqu493fXt28WnL+fee465OyIikizbxC5ARERyT+EuIpJACncRkQRSuIuIJJDCXUQkgRTuIiIJpHAXEUkghbuISAIp3EVEEqh1rDfu0qWL9+vXL9bbi4iUpLlz537i7l0baxct3Pv168ecOXNivb2ISEkys/ezaaduGRGRBFK4i4gkkMJdRCSBFO4iIgmkcBcRSaBGw93M7jKzlWb2egPbzcxuMrNqM1tgZvvkvkwREWmKbI7c/wwM2cr2ocCA1DIG+FPLyxIRkZZo9Dp3d3/OzPptpckI4F4P8/W9ZGYdzayHu3+coxpFJKncYcOGzMv69VBTA7W1sGlT5mVr2zZtgs2bw3vULXXv2dR1Tf26+j9jfcceC/vtl9t9WU8ubmLqBXyY9nppat0W4W5mYwhH9/Tt2zcHby0i0dTWwqefwiefwKpV4bHu+erVsHbtlsuXX4bHr74K4b1xY+yfojDMvv26Z8+SCPesuftEYCJARUWFZuYWKWbusHQpLFoEb70FH3wQlg8/DI8ffRSOjDPZYQfo0OGbZYcdoHfvb15vtx20axeW9u2/eZ6+tG8PbdtC69bQqtWWS0Pr05dttgnBmr5A89Y15euKQC7CfRnQJ+1179Q6ESkVGzbAggUweza88gq8/jq8+SasW/dNm3btoE8f6NsXjjgiPO/eHbp2hS5dvnncaacQyhJVLsJ9GjDOzCYBBwBr1N8uUuRqauDFF+GJJ+Cpp2DevLAOYOed4fvfh9GjYY89wrL77tCtW1EdmcrWNRruZvYQcBjQxcyWAlcBbQDc/VZgBjAMqAa+As7KV7Ei0gJr18Kjj8LUqSHU160LXRf77Qe/+AXsv39Y+vRRiCdANlfLjGpkuwPn5awiEcmdzZvh8cdh4kSYMSN0v/TqBWeeCYMHw2GHwY47xq5S8iDakL8ikkdffAG33hqWd98N/eE//SmcfDIceGA40SiJpnAXSZLPPoObboIbb4TPP4dDD4Vf/xqOOy6cEJWyoXAXSYKaGrj5ZvjlL2HNmhDmV14J++4buzKJROEuUuqefBLGjYPFi+Hoo+H662HPPWNXJZGp402kVK1dG/rRBw8OJ06nT4eZMxXsAujIXaQ0vfwyjBwJ778Pl1wC11wT7ugUSdGRu0gpcQ9XwBxySHj9t7+FbhgFu9SjcBcpFRs3wrnnwtix4fb/uXPhoINiVyVFSuEuUgrWrYMRI+DOO+GKK0L/eufOsauSIqY+d5Fi99lnMHQozJkDt90GY8bErkhKgMJdpJitWQNHHRVGaXz44XD9ukgWFO4ixWrt2nDEvmABVFbCD38YuyIpIQp3kWK0cSMcf3wYW33qVAW7NJnCXaTYuIcrYp56Cu65J4S8SBPpahmRYvO734WrYq68Ek4/PXY1UqIU7iLFZNYsuPRSOOWUMAiYSDMp3EWKxbJlcNpp8L3vwV13acx1aRF9ekSKQW0tjBoF//hHOIG63XaxK5ISpxOqIsXg2mvhr3+F++8Pk1GLtJCO3EVimzcPfvUrOOMMOPXU2NVIQijcRWKqqQmTVXfrBr//fexqJEHULSMS03XXwWuvhYHAOnWKXY0kiI7cRWJZuDBMXv2Tn+gOVMk5hbtIDO5w/vnQoQPccEPsaiSB1C0jEsPUqfD003DLLdClS+xqJIF05C5SaF9
+CRddBHvtpbHZJW905C5SaNdfD0uXwqRJ0KpV7GokoXTkLlJIK1aEPvaTT4aDD45djSSYwl2kkK67DtavD3ekiuSRwl2kUN55J8yBOno0DBwYuxpJuKzC3cyGmNliM6s2s8sybO9rZs+Y2d/NbIGZDct9qSIlbvz40Md+1VWxK5Ey0Gi4m1kr4GZgKDAIGGVmg+o1uxKY4u57AyOBW3JdqEhJW7wYHnwwXNves2fsaqQMZHPkvj9Q7e7vuHsNMAkYUa+NA99JPd8R+Ch3JYokwG9+A+3ahUsgRQogm0shewEfpr1eChxQr83VwONmdj6wPXBkTqoTSYL33w9D+Y4dGwYIEymAXJ1QHQX82d17A8OA+8xsi+9tZmPMbI6ZzVm1alWO3lqkyE2YAGZwySWxK5Eykk24LwP6pL3unVqXbjQwBcDdXwTaA1vcU+3uE929wt0runbt2ryKRUrJ8uVwxx1hous+fRpvL5Ij2YT7bGCAmfU3s7aEE6bT6rX5ADgCwMz2IIS7Ds1FbrwRNm6Ey7a4yEwkrxoNd3evBcYBs4BFhKti3jCza8xseKrZRcC5ZjYfeAg40909X0WLlISvvgrXtR9/POy2W+xqpMxkNbaMu88AZtRbNz7t+UJA91KLpLvvPli9Gi64IHYlUoZ0h6pIPriHLpl99tEYMhKFRoUUyYcnnoBFi+Dee8OVMiIFpiN3kXz4wx+ge/cw+qNIBAp3kVyrroaZM8NNS+3axa5GypTCXSTXbr89DBB27rmxK5EypnAXyaWaGrj7bjj2WOjRI3Y1UsYU7iK59MgjsGqV5kaV6BTuIrk0cSL07QtHHRW7EilzCneRXFmyBJ58Es45RxNfS3QKd5FcueMO2GYbOPvs2JWIKNxFcqK2NpxIPeYY6NUrdjUiCneRnHj8cVixAs46K3YlIoDCXSQ37r0XdtoJhmlueCkOCneRllqzBqqqYORIaNs2djUigMJdpOX+8hfYsAF+8pPYlYh8TeEu0lL33gsDB8L++8euRORrCneRlnjvPXjuuTBHqob2lSKicBdpifvvD4+nnhq3DpF6FO4izeUeumQOPRT69Ytdjci3KNxFmmvePHj7bR21S1FSuIs01+TJ0Lo1nHBC7EpEtqBwF2kOd5gyBQYPhs6dY1cjsgWFu0hzvPwyvP8+nHJK7EpEMlK4izTHlCnhbtQRI2JXIpKRwl2kqTZvDuF+9NHQsWPsakQyUriLNNULL8CyZeqSkaKmcBdpqsmToX17GD48diUiDVK4izTFpk1hoLBhw6BDh9jViDRI4S7SFH/9Kyxfri4ZKXoKd5GmmDoVtt0WfvjD2JWIbJXCXSRbmzfDI4/AkCGw/faxqxHZqqzC3cyGmNliM6s2s8saaHOymS00szfM7MHclilSBObODVfJHH987EpEGtW6sQZm1gq4GRgMLAVmm9k0d1+Y1mYAcDlwsLuvNrOd81WwSDRVVdCqlbpkpCRkc+S+P1Dt7u+4ew0wCah/W965wM3uvhrA3VfmtkyRIlBVFYb31VgyUgKyCfdewIdpr5em1qUbCAw0s+fN7CUzG5LpG5nZGDObY2ZzVq1a1byKRWJ46y1YuBCOOy52JSJZydUJ1dbAAOAwYBRwu5ltcV+2u0909wp3r+jatWuO3lqkAKqqwqPGkpESkU24LwP6pL3unVqXbikwzd03uvu7wFuEsBdJhqoq2Gcf6Ns3diUiWckm3GcDA8ysv5m1BUYC0+q1qSIctWNmXQjdNO/ksE6ReD7+GF56SVfJSElpNNzdvRYYB8wCFgFT3P0NM7vGzOoG15gFfGpmC4FngEvc/dN8FS1SUI8+GibnUH+7lBBz9yhvXFFR4XPmzIny3iJNMmxYOKH69ttgFrsaKXNmNtfdKxprpztURbbmiy/gqafCUbuCXUqIwl1ka2bOhJoadclIyVG4i2xNVRXsvDMceGDsSkSaROEu0pANG+Cxx8KkHK1axa5GpEkU7iINefZZWLt
WXTJSkhTuIg2pqgpD+x5xROxKRJpM4S6SSd3Y7UOHhvlSRUqMwl0kk1deCXemqktGSpTCXSSTqipo3Vpjt0vJUriLZFJVBYcfDh23GNxUpCQo3EXqe/NNWLxYXTJS0hTuIvVVVobH4cO33k6kiCncReqrqoL99oPevWNXItJsCneRdMuWhStlNHa7lDiFu0i6aal5aNTfLiVO4S6SrqoKBg6E7343diUiLaJwF6nz+efw9NMau10SQeEuUmfGDKitVZeMJILCXaROVRV07w4HHBC7EpEWU7iLAKxfH2ZdGjECttE/Cyl9+hSLQOhr//JLdclIYijcRSDcldqhQxhPRiQBFO4imzaF69uHDYN27WJXI5ITCneRF1+ElSt1V6okisJdpLIS2rYNsy6JJITCXcqbewj3I4+E73wndjUiOaNwl/L22mvw7ru6SkYSR+Eu5a2yMgw1oLHbJWEU7lLeKivh4IOhW7fYlYjklMJdyte778L8+bpKRhIpq3A3syFmttjMqs3ssq20O8HM3MwqcleiSJ5UVYVH9bdLAjUa7mbWCrgZGAoMAkaZ2aAM7ToAvwBeznWRInlRWQl77gm77hq7EpGcy+bIfX+g2t3fcfcaYBIwIkO7a4HfAutzWJ9IfqxcCX/7m7pkJLGyCfdewIdpr5em1n3NzPYB+rj7YzmsTSR/Hn00XOOuLhlJqBafUDWzbYAbgIuyaDvGzOaY2ZxVq1a19K1Fmq+yEvr1g+9/P3YlInmRTbgvA/qkve6dWlenA/A94Fkzew/4ATAt00lVd5/o7hXuXtG1a9fmVy3SEmvXwpNPhi4ZTacnCZVNuM8GBphZfzNrC4wEptVtdPc17t7F3fu5ez/gJWC4u8/JS8UiLTV9OmzYAD/6UexKRPKm0XB391pgHDALWARMcfc3zOwaM9NtfVJ6pkyBnj3hoINiVyKSN62zaeTuM4AZ9daNb6DtYS0vSyRP1q4N0+n99KeaTk8STZ9uKS91XTInnRS7EpG8UrhLeVGXjJQJhbuUj7oumRNPVJeMJJ4+4VI+Hn1UXTJSNhTuUj6mTlWXjJQNhbuUB3XJSJnRp1zKg7pkpMwo3KU8TJ6sLhkpKwp3Sb5PPw1dMqNGqUtGyoY+6ZJ8U6fCxo1w6qmxKxEpGIW7JN8DD8CgQbDXXrErESkYhbsk23vvhRmXTjtNw/tKWVG4S7I9+GB4/PGP49YhUmAKd0kud7j/fjjkENhll9jViBSUwl2S69VXYdEinUiVsqRwl+S6/35o00Y3LklZUrhLMm3cGML9mGOgc+fY1YgUnMJdkmn6dFi5EkaPjl2JSBQKd0mmO+8Mww0cfXTsSkSiULhL8ixbFoYbOPNMaJ3VNMEiiaNwl+S55x7YvBnOOit2JSLRKNwlWTZvhrvugkMPhd12i12NSDQKd0mW556DJUt0IlXKnsJdkuXWW6FjRzjhhNiViESlcJfk+OgjePhhOPts2G672NWIRKVwl+SYOBE2bYKxY2NXIhKdwl2SoaYGbrsNhg7ViVQRFO6SFJWVsHw5jBsXuxKRoqBwl2T44x/hn/5Jd6SKpCjcpfTNnh1mWzrvPE2ALZKifwlS+iZMgB13hHPOiV2JSNHIKtzNbIiZLTazajO7LMP2C81soZktMLOnzEzT3khhVFeHyx/HjoUOHWJXI1I0Gg13M2sF3AwMBQYBo8xsUL1mfwcq3H1P4C/A9bkuVCSjG24Ig4P9+7/HrkSkqGRz5L4/UO3u77h7DTAJGJHewN2fcfevUi9fAnrntkyRDFatgrvvhtNPhx49YlcjUlSyCfdewIdpr5em1jVkNDAz0wYzG2Nmc8xszqpVq7KvUiSTP/wBNmyAiy6KXYlI0cnpCVUzOw2oACZk2u7uE929wt0runbtmsu3lnLzySdw001w8snw3e/Grkak6GQzk8EyoE/a696pdd9iZkcCVwCHuvuG3JQn0oDf/Q7WrYOrropdiUhRyubIfTYwwMz6m1l
bYCQwLb2Bme0N3AYMd/eVuS9TJM3KlfA//wM//jHssUfsakSKUqPh7u61wDhgFrAImOLub5jZNWY2PNVsArADMNXMXjWzaQ18O5GWmzAB1q+H8eNjVyJStLKaYNLdZwAz6q0bn/b8yBzXJZLZBx+EoQZOOw0GDoxdjUjR0h2qUlr+67/C43XXxa1DpMgp3KV0vPIKPPBAuPSxT5/G24uUMYW7lAZ3uPBC6NYNLr00djUiRS+rPneR6CZNguefh9tv1xgyIlnQkbsUv9Wr4T/+Ayoq4KyzYlcjUhJ05C7F7/LLwzgyM2dCq1axqxEpCTpyl+L2wgthbtQLLoC9945djUjJULhL8Vq3Ds48E/r2hV/+MnY1IiVF3TJSvC6+OEzG8dRTsMMOsasRKSk6cpfi9NhjcOut4fLHww+PXY1IyVG4S/FZujRcFfMv/wK/+lXsakRKksJdisuGDXDSSfCPf8DkydCuXeyKREqS+tyluFx4Ibz0EkydquF8RVpAR+5SPG65JSwXXwwnnhi7GpGSpnCX4vDII3D++XDssfDrX8euRqTkKdwlvuefh1GjYN994aGHoLV6C0VaSuEucb3wAgwdCr17w/TpsP32sSsSSQSFu8Tz4oswZAh07w7PPAM77xy7IpHEULhLHNOnw5FHfhPsvXrFrkgkURTuUnh/+hOMGBEudXzuOQW7SB4o3KVw1q+HsWPh5z+HYcPgf/83HLmLSM4p3KUwliyBgw4K48VccglUVurkqUge6Zozya/Nm0OgX3optGkDjz4KxxwTuyqRxNORu+TPwoVw2GFw3nlw4IHw978r2EUKROEuubdiBfzsZ2FUx9deg7vvhlmzYJddYlcmUjbULSO589FH8Pvfh6thNmyAcePgv/8bunSJXZlI2VG4S8u4w9y5oV/9vvugthZOOSVMizdgQOzqRMqWwl2aZ8WKMCzvHXfA/Pmw7bZwzjlw0UWw666xqxMpewp3yY47vPVWuLO0sjKMCeMeBvu65ZYw8FfHjrGrFJEUhbtkVlsLixeH8V+eeQaefTb0qQPstRdcfTUcf3w4aSoiRSercDezIcCNQCvgDnf/Tb3t7YB7gX2BT4FT3P293JYqebF5cwjtt98OR+bz58O8ebBgQZjqDqBbtzBJ9eGHw+DB0L9/3JpFpFGNhruZtQJuBgYDS4HZZjbN3RemNRsNrHb33cxsJPBb4JR8FCxNsHEjfP45LF8eArz+smQJVFd/E+IA3/kO7L13uJRx771hv/1g993BLN7PISJNls2R+/5Atbu/A2Bmk4ARQHq4jwCuTj3/C/BHMzN39xzWWrrcQzdHbW0I3Lrn9V/X31ZTA199FcI3/bH+ui+/DCG+evW3l3XrMtfTuTP06BFOfA4eDLvtFq5sGTAA+vSBbXT7g0ipyybcewEfpr1eChzQUBt3rzWzNcBOwCe5KPJb7roLJkwIgRne8NtLtusK9fW1taHrIx+23TYs228fTmZ26hQCu1Onb5aOHcPgXD17hqVHD2jfPj/1iEjRKOgJVTMbA4wB6Nu3b/O+SZcu4SReXTeB2ZZLpvXZrst129atw9KmTebnW3vdpg1st10I8PTH7baDdu10hC0iDcom3JcBfdJe906ty9RmqZm1BnYknFj9FnefCEwEqKioaF6XzfDhYRERkQZlc+g3GxhgZv3NrC0wEphWr8004IzU8xOBp9XfLiIST6NH7qk+9HHALMKlkHe5+xtmdg0wx92nAXcC95lZNfAZ4Q+AiIhEklWfu7vPAGbUWzc+7fl64KTcliYiIs2lM3IiIgmkcBcRSSCFu4hIAincRUQSSOEuIpJAFutydDNbBbzfzC/vQj6GNmg51dU0qqvpirU21dU0LalrF3fv2lijaOHeEmY2x90rYtdRn+pqGtXVdMVam+pqmkLUpW4ZEZEEUriLiCRQqYb7xNgFNEB1NY3qarpirU11NU3e6yrJPncREdm6Uj1yFxGRrSjacDezk8zsDTP
bbGYV9bZdbmbVZrbYzI5u4Ov7m9nLqXaTU8MV57rGyWb2amp5z8xebaDde2b2WqrdnFzXkeH9rjazZWm1DWug3ZDUPqw2s8sKUNcEM3vTzBaYWaWZdWygXUH2V2M/v5m1S/2Oq1OfpX75qiXtPfuY2TNmtjD1+f9FhjaHmdmatN/v+EzfKw+1bfX3YsFNqf21wMz2KUBNu6fth1fN7Aszu6Bem4LtLzO7y8xWmtnraes6m9kTZvZ26rFTA197RqrN22Z2RqY2TeLuRbkAewC7A88CFWnrBwHzgXZAf2AJ0CrD108BRqae3wqMzXO9/xcY38C294AuBdx3VwMXN9KmVWrf7Qq0Te3TQXmu6yigder5b4Hfxtpf2fz8wM+BW1PPRwKTC/C76wHsk3reAXgrQ12HAdML9XnK9vcCDANmAgb8AHi5wPW1ApYTrgOPsr+A/wPsA7yetu564LLU88syfe6BzsA7qcdOqeedWlJL0R65u/sid1+cYdMIYJK7b3D3d4FqwiTeXzMzA/6NMFk3wD3AcfmqNfV+JwMP5es98uDric/dvQaom/g8b9z9cXevTb18iTCrVyzZ/PwjCJ8dCJ+lI1K/67xx94/dfV7q+VpgEWGO4lIwArjXg5eAjmbWo4DvfwSwxN2be3Nki7n7c4Q5LdKlf44ayqKjgSfc/TN3Xw08AQxpSS1FG+5bkWnC7vof/p2Az9OCJFObXDoEWOHubzew3YHHzWxuah7ZQhiX+q/xXQ38NzCb/ZhPZxOO8jIpxP7K5uf/1sTvQN3E7wWR6gbaG3g5w+YDzWy+mc00s38uUEmN/V5if6ZG0vABVoz9Vaebu3+cer4c6JahTc73XUEnyK7PzJ4EumfYdIW7P1LoejLJssZRbP2o/V/dfZmZ7Qw8YWZvpv7C56Uu4E/AtYR/jNcSuozObsn75aKuuv1lZlcAtcADDXybnO+vUmNmOwAPAxe4+xf1Ns8jdD18mTqfUgUMKEBZRft7SZ1TGw5cnmFzrP21BXd3MyvIJYpRw93dj2zGl2UzYfenhP8Stk4dcWVqk5MaLUwI/iNg3618j2Wpx5VmVknoEmjRP4ps952Z3Q5Mz7Apm/2Y87rM7EzgGOAIT3U2ZvgeOd9fGeRs4vdcM7M2hGB/wN3/X/3t6WHv7jPM7BYz6+LueR1DJYvfS14+U1kaCsxz9xX1N8TaX2lWmFkPd/841U21MkObZYRzA3V6E843NlspdstMA0amrmToT/gL/Ep6g1RoPEOYrBvC5N35+p/AkcCb7r4000Yz297MOtQ9J5xUfD1T21yp1895fAPvl83E57muawjwn8Bwd/+qgTaF2l9FOfF7qk//TmCRu9/QQJvudX3/ZrY/4d9xXv/oZPl7mQacnrpq5gfAmrTuiHxr8H/PMfZXPemfo4ayaBZwlJl1SnWjHpVa13yFOIPcnIUQSkuBDcAKYFbatisIVzosBoamrZ8B9Ew935UQ+tXAVKBdnur8M/Czeut6AjPS6pifWt4gdE/ke9/dB7wGLEh9sHrUryv1ehjhaowlBaqrmtCv+GpqubV+XYXcX5l+fuAawh8fgPapz0516rO0awH20b8SutMWpO2nYcDP6j5nwLjUvplPODF9UAHqyvh7qVeXATen9udrpF3llufatieE9Y5p66LsL8IfmI+Bjan8Gk04T/MU8DbwJNA51bYCuCPta89OfdaqgbNaWovuUBURSaBS7JYREZFGKNxFRBJI4S4ikkAKdxGRBFK4i4gkkMJdRCSBFO4iIgmkcBcRSaD/DydAb7nqWwBcAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the sigmoid function\n", "\n", "plot_x = np.arange(-10, 10.01, 0.01)\n", "plot_y = sigmoid(plot_x)\n", "\n", "plt.plot(plot_x, plot_y, 'r')" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Since PyTorch 0.4, Variable(t) simply returns a Tensor; kept for compatibility with older versions\n", "x_data = Variable(x_data)\n", "y_data = Variable(y_data)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In PyTorch we don't need to write our own sigmoid function. PyTorch ships common functions such as `torch.sigmoid`, implemented in its C++ backend. They are convenient, faster than a hand-written version, and more numerically stable.\n", "\n", "Many other common functions, such as activations and losses, live in `torch.nn.functional`, which we import as follows." ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# Define the logistic regression model: sigmoid(x w + b)\n", "w = Variable(torch.randn(2, 1), requires_grad=True) # weights, shape [2, 1]\n", "b = Variable(torch.zeros(1), requires_grad=True) # bias\n", "\n", "def logistic_regression(x):\n", "    return torch.sigmoid(torch.mm(x, w) + b)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Before updating the parameters, we can plot how the untrained model classifies the data." ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3Xl8VNXd+PHPSUhIWAQJCEpWJlhFRJZAqYLsGaSKfSj2h1vhcaHS2traRX2BvqyV1vrz50+t9leX+oiCrV2ePg9VW3ZERJSgKGXRhyVAACEEBAIBspzfHzMZJmEmcye5+3zfr9e8mJlc5n7nzr3fe+45556jtNYIIYTwlzSnAxBCCGE+Se5CCOFDktyFEMKHJLkLIYQPSXIXQggfkuQuhBA+JMldCCF8SJK7EEL4kCR3IYTwoXZOrbh79+66sLDQqdULIYQnrV+//pDWukei5RxL7oWFhZSVlTm1eiGE8CSl1C4jy0m1jBBC+JAkdyGE8CFJ7kII4UOS3IUQwockuQshhA8lTO5KqZeVUgeVUv+K83ellHpGKbVNKfWpUmqw+WH62IIFUFgIaWmhfxcscDoiIYQPGCm5vwJMbOHv1wB9w4+ZwP9re1gpYsECmDkTdu0CrUP/zpwpCV4I0WYJk7vWehVwuIVFrgde1SFrga5KqQvNCtDXZs+GkyebvnfyZOh9IYRoAzPq3HsDe6JeV4TfO4dSaqZSqkwpVVZZWWnCqj1u9+7k3m8lqfnxLvntRGvZ2qCqtX5Ba12itS7p0SPh3bP+l5+f3PutIDU/3iW/nWgLM5L7XiAv6nVu+D2RyNy50KFD0/c6dAi9bxKp+Wmb5iXn737XvpK02347uYrwGK11wgdQCPwrzt++DvwDUMBw4EMjnzlkyBAttNbz52tdUKC1UqF/58839eOV0jpU7mv6UMrU1fjS/Plad+gQe/s1Pjp0MP0ni0j2t7NyV4q1Laz87iI+oEwbydsJF4A/APuBWkL16bcDdwF3hf+ugOeA7cBGoMTIilMyuVucyGMpKIidIAoKLF+158XbdnZty2R+O6uTr+xH7mFacrfqkXLJ3aGij5S4Wi9eydmuq6Bkfjurk69cAbqH0eQud6jaxaEK1JtvhhdegIICUCr07wsvhN4XLTParm1i+3cTyfx2Vne8sqHtX5hMkrtdbOr2GMvNN0N5OTQ0hP6VxG5MrPbu5kxu/z6H0d/O6uRrQ9u/MJkkd7tI0cdzYpWcZ81y51WQ1clXrgC9R4WqcOxXUlKiU2ompsZOy9FVMx06yBEiTLNgQaiWb/fuUJlh7lzZtfxIKbVea12SaDkpudtFij7CJPH6m0v1m4jm2ByqKenmm+WIE23S/AKw8a5VkF1LNCUldyE8xG13rYLcuepWUnIXwkMc7HQVk1xJuJeU3IXwELd1unLjlYQIkeQuhIe4rb+5264kxFmS3IXwEDM7XZlRV96tW3LvC/tInbsQHmNGpyupK/c/KbkLkUIaS+u33GJOXfnhOBNwxntf2EeSu/Ac6XrXOtEzO8WTbF252xp4xVmS3IVp7Ei6MvVc68Xq2dJcsknZbQ28iaRUwcDIuMBWPFJuPHefs2vceJk0ovUSjU/f2t/LgTloWsUvcxtgcDx3GThMmKKwMPblfkFBaJwTs6SlhQ7L5pQKjaki4ov3G0Hod/LjQGPRg6mlpUF9/bnLmL2PWk0GDhOGmHWZald/Z6njbb14VSizZoWe33qrv6oqmlfhxUrs4N8++ZLcU5iZ9dd2JV2v1fG6Saw+8tOnw7x5/mzDMNLGAObuo66q0zdSd2PFQ+rcnWdm/bWd9ZlW1/F6pQ7ZDH5uwzAyB66Z+6hdxwAyQbZIxOxJj/2QFP3S6GbE/Pnxk54fJr6Od+JKT7dmH7XrRGk0uUu1TAozuyrFD5NFOD0Qll2X9Y1VcvH4oQ0jXhXevHnW7KNuG2dHknsKk/rrczl5gFrdhz/6xDF
9evz6aL/sA3ZPfua6xn4jxXsrHlIt4w5+qEoxk5N10FauO1Z1U7xHqu8DreW2Onfp5y5EFCfnMbeyD39Lfdyjea3Pt9vYMUm59HMXohWcnMfcyst6I9VKfqmOcZKb2p0kuQvRjFMHqJVtIPFOEOnp9p/EhD0kuQvhElZeNdjdc0Q4TybrEMJFzJiII97ngvX1wcI9JLkLkSKsOnEId5JqGSGE8CFJ7kKYyFUDR4mUJtUyQphEJp0WbiIldyFM4vS4NEJEM5TclVITlVKfKaW2KaXuj/H3fKXUCqXUx0qpT5VSk8wPVQh3c9vAUSK1JUzuSql04DngGqAfcKNSql+zxeYAf9JaDwKmAb81O1Ah3M51A0eJlGak5D4M2Ka13qG1PgP8Ebi+2TIaOC/8vAuwz7wQhaWkBdA0To2yKT+hiMVIcu8N7Il6XRF+L9rDwC1KqQrgbeD7pkSXauw+Sq0eYzYFZWeffZ6TY/0t/fITinjMalC9EXhFa50LTAJeU0qd89lKqZlKqTKlVFllZaVJq/YJJ45SaQE0TePPV1V19r2aGuvXKz+hiMdIct8L5EW9zg2/F+124E8AWuv3gSyge/MP0lq/oLUu0VqX9OjRo3UR+5UTR6m0AJrGqSQrP6E1/FDVZSS5rwP6KqWKlFKZhBpMFzZbZjcwDkApdSmh5C5F82Q4cZRKC6BpnEqy8hOazy9VXQmTu9a6DrgbWARsIdQrZpNS6hGl1OTwYj8G7lRKfQL8AZihnZoFxKucOEplnj3TOJVkvfATeq0U7JuqLiPTNVnxkGn2mjF7ji6j8+c5PM+ek6s3c912TbEWb91unSrRye3SWkrFnn5QKacjC8HgNHuS3N3ErKPUI0eU0wnR7HW7OcmaJdnv6OSctK3l9pgluXuJ2VnB7XtnmF8no/ar1pwQ3V4KjsXtZSNJ7l5hxZ7kkSPKyTA9solcpTUnRK+eRN18FWY0ucvAYU6zovXGI10onAzTI5vIVeL1/Nm1K35jqRcafGNx00TXrSXJ3WlW9KHzyBHlZJge2USu0tKJT8fpMmjlvLAiASPFeyseUi0TZtV1q5uvK6O4tbeMRzafrWLVIHqxysXrkDp3j0hU5y5ZxnaxfpLGOvpU/wmid8d4yV3aLawlyd1L4iVwtzfb+1S8iyn5CZryamOpk8woqxlN7iq0rP1KSkp0WVmZI+v2jMLCUEVmcwUFoVYeYYm0tFCaaon8BOdOKwihdgupU4/NrO2llFqvtS5JtJw0qLqZjArlCCM9ZuQnkMbSZNk9rIEkdzczq7+e1wb3cFisnjTNSZfJED90GbSL3WU1Se5uZkZ/Pb8McWej6BIphEql0fzeZVLKAtaw/d4KIxXzVjykQdWgtrbASKtXmzX/CWbN8l4HpmTGkZM2fGuYtW2R3jJCa+3Z++zd2gPUi8kvmZi9WhZw6/7SnJ29ZSS5+50Hj1Y3J1APbs6kYm6p/7pbk6ab9xcrGE3uUufudx68z97NkyV4sQNTMjG3VP/r1uYat+wvLbVVONKOYeQMYMVDSu428so1a5iba5L8XnI3MsSA276rG/aXlq4ezL6yQKplhFe5OYF6sQog2ZgbywJeGV7ADftLSzGYHZ/R5C7VMsJ13FyT5MUbd5KNubHvemNX0Obc1sffDftLS1VfjlXlGTkDWPGQkrtoicdqknzJS1cpTu8vbiy5S3IXiTl95PiY2zet2+NzC6lzl+TuPV4qvnmMbFprOHVCsmt+AKPJXUaFFC2TkSktI5vWfGaOVLlgQag75e7doXaGuXPd0bYio0KmIis603qkY7cXx0PxyKb1FLP6vPthSCZJ7n5h1d7ogZmkvXogemDTeo5ZJ0y33BjVFpLc/cKqvdEN/cwS8OqB6IFN6zlmnTD9cFUlyd0vrNobPdCx26sHogc2reeYdcL0w1WVJHe/sHJvdGJGhiQq0b18IMpkF+Yy64Tpi6sqI11qrHhIV0iT+alfXZLfxU9fXbiHW/v
4I10hU5Bb+24lqxV9BP3y1YVIxGhXSEnuwn3S0kIF8OaUCtVfCJHCpJ+78C4vV6IL4RKS3IX7+KI1q2VevOlKtJ4Tv7ckdzeRIz6ksctDTs7Z97KznYvHZF696Uq0jlO/tyR3t5Aj/lw1NWefV1X5Znt49aYr0TpO/d6GkrtSaqJS6jOl1Dal1P1xlvmWUmqzUmqTUup1c8NMAalyxBu9OjFxe7jtgsirN12J1nHtZB1AOrAd6ANkAp8A/Zot0xf4GDg//PqCRJ8r/dybccNEkFZLpkO6SdvDjX3g3TAtnLCPm6fZGwZs01rv0FqfAf4IXN9smTuB57TWR8InjINtPemknFToIZJMadyk7eHGC6IUaC8WUZz6vY0k997AnqjXFeH3ol0MXKyUek8ptVYpNdGsAFNGKhzxyVyfmrQ93FgFImPKpBbHfu9ERXtgKvBS1OtbgWebLfMm8DcgAygidDLoGuOzZgJlQFl+fn7rrkn8zK33O5sl2etTE7aHVIEIv8HEapm9QF7U69zwe9EqgIVa61qt9U7gc0L18M1PJC9orUu01iU9evQwdvZJJX4fRSrZ0rgJ2yMVLoiEiMVIcl8H9FVKFSmlMoFpwMJmy/wXMBpAKdWdUDXNDhPjFH7gwPWpVIF4j9t6N3mWkeI9MIlQaXw7MDv83iPA5PBzBTwJbAY2AtMSfab0lvEgv1cbCVvF2p3c2LvJbZBRIYWpzJx5WKS8eLtTdnbofrXmZNLws2RUSGGuVgzDK0Q88XaneGRA0LNkVEhhLjf2KRSelexu46fbPewiyV0Ykwo3WQnbxNttcnKkd5NZJLkLY6RPoTBRvN3p6aeld5NZ2jkdgPCIxqNL5rITJki0O8lu1XbSoCqEEB4iDapCeJjcyCPaSpK78C6fZkCZt0WYQZK7sIfZidipDGjDCcWNwxQL75E6d2E9K+5udeKmKpvu0k1LC52vmpMbeQTIHarCTaxIxE5kQJtOKHIzsGiJNKgK97Di7lYnbqqy6S5duaVAmEGSu7CeFYnYiQzYrVvs900+ocgwxcIMktzN4tOeG6aIl4gnTWr9NrM7Ay5YAMeOnft+ZqYlJxS/z9sirCd17maQ4XATW7Cg6e2IkybBvHne2WbxKsJzcuDQIdvDEalLGlTtJC1gyfNaspQuLMIlpEHVTjIcbvLibZuqKndWacmomKaRGkx7SHI3gxz4yWtp27jxbh3pwmIKufvWPpLczSAHfvJa2jZuvOJJoS4sVpas5e5bGxmZaNWKh+8myDZ78uhUmIw6J6fpTMiNj/R0f39vF7N6gmqlYv/kSpnz+akAgxNkS3J3o1SZAj7W92z+8OP3jsXGk3lLqyooiP0zFBSYs26rP9+N6hvq9bFTx0z7PEnuXpZKR0B0pklPT53vHc3Gk3miVVldsk6VckvF0Qr98kcv62l/maZzfp2jf/iPH5r22ZLcvSwVr13nz49fevfz99ba1pN5olXZEYrXahyNxFtTW6MXb1usf7zox7r/b/trHkbzMLrXE730t//2bf3W52+ZFo8kdy9LpZK71omrZ/z6vRvZeDJPtCq/lKzNOoHE3x4NetPBTfrJNU/qifMn6uxHszUPozN/kanHzRunH1/9uN6wf4NuaGgw82tprSW5e5tfjjCj4p3M/P69G7mo5K518onRbSVxMw+feNsr/fw9kdL5V37zFf2Dt3+g3/r8LV19utr079OcJHevaX6EzJrlriPGSvGKk+Dv793IRXXuTn+eGcw6V9bW12qlGmLvmqpeP1/2vC4/Um7FV2iRJHcntLYI48YjxE6pVg0Vi0t6yyTLjT9dW2q5yo+U6+fLntdT3piiu/yqi6bLTtd9P0nudjOaoGMdWW48QuyU6ic3DzOjucDs81rcqpT0cz+7+nS1fuvzt/QP3v6B/spvvhKpaun9f3rr2/7rNn33Y2t0dnaDq3ZNSe52M1qZGSuJxauS8HsvkWhuq7i1i8e/t6FySQvf0Yrzekv
t8x06NOhfPleuf73613rsvLE68xeZmofRWY9m6eBrQf3kmif1poObmjSEuu0nkuRuNyNFmJaKFKlccncjO45oH1yxJPwKCRaw6qJ1/vz4hxVddmoeRvf/bX/940U/1ou3LdY1tTVtW6GNJLnbzche2lLDoccPcl+xK+n6pDquxfNggu9oRS/Q03Wn9cqdKzVxG0Mb9N5je1u/AodJcrebkYTQ0o7utmu/VGZX0nXiZjW797ME39GsTb2tapt+9oNn9XWvX6c7/bJTqO68S7kfzp3nkOTuhEQHjg8uw1OCHffgt9S336rs48T+lyB7tzako6eO6r9t+Zue9eYs3efpPpGG0KKnivRdf79L/23L3/SLr5zw5eEmyd2tpITuflaW3BPdjWtl9nGiGmj+fK0zM5uuLzPznEbVRIdEfUO9Xrd3nX70nUf1yJdH6naPtNM8jO44t6O+7vXr9LMfPKs/P/T5OXeE+vFwM5rcZZo9IZqzck7ceNMLQmiM+LlzrRsj3ompAhcsgH//d6itPfteRgb8x38k/J77ju9j8fbFLN6+mCU7lnDoZGj6xUG9BhEMBAkWB7ky70oy0zOtid2lTJ1DVSk1EXgaSAde0lo/Fme5bwJ/AYZqrVvM3JLchas1n9DbrKTr5FysTsz1m8Q6T9WdYvXu1SzatohF2xex8eBGAHp27ElpoJRgIMiEwAQu6HiBNbF6hGnJXSmVDnwOTAAqgHXAjVrrzc2W6wy8BWQCd0tyF46yKjm3lZOTqVt5RRJPCyczXV/P1kNbWbR9EYu3L2Zl+Upq6mrITM9kRP4ISvuUEiwOMqDnANKUTBrXyGhyT1wpD18DFkW9fgB4IMZyTwFfB1YCJYk+N2Xr3IX13Nxw7XRsdldCx6nnr+zRSec9mRdpCL34Nxfr77/9ff33z/6uj58+bm1MsXioch6zGlSBqYSqYhpf3wo822yZwcBfw8/jJndgJlAGlOXn59uzJUTqcXv/cQ8lkraqe+1VXZfdvsnvUJ2Bvv1b2XrKG1P082XP651HdjobZLInXId/P9uSO6FJtlcChTpBco9+SMldWCYVJztxkd1f7tYvrn9RT/3TVN31sa76xinonV3Q9aCPXNBFb336IV1bX+t0mGclUxhw+spLG0/u7QxU8ewF8qJe54bfa9QZ6A+sVEoB9AIWKqUm6wT17kJYIj8/dr12fr79saSAk7Uneaf8HRZtDzWEbj20FYDenXsz5ZIpBK8Nct788aRld6Mr0NXZcM+1e7fx92fPbtpmAaHXs2e7o00nipHkvg7oq5QqIpTUpwE3Nf5Ra30U6N74Wim1EviJJHbhmLlzYzcczp3rXEw+orVm48GNLNq2iMU7FvPurnc5XX+arHZZjCoYxZ2D7yQYCNKvRz/CBT53S6YwkMyJwGEJk7vWuk4pdTewiFBXyJe11puUUo8QujxYaHWQQiSlsQTlxt4yHnXo5CGWbF8S6dmyv3o/AJf1uIzvDf0eweIgI/NHkp2R7XCkrZBMYcBDV4VyE5MQ4hy19bW8X/F+pM/5R/s/QqPplt2N8X3GEwwEKQ2UkntertOhmsNo11knupM2Y+pNTFaQ5C5M59a+7R6x48iOSDJfvnM5x88cJ12lMzx3eOSO0CEXDiE9Ld3pUJ3l8H4myV2kFheUqLzm+OnjrChfEUno249sB6Cwa2EomQeCjC0aS5esLg5HKqJJchepxck7Pz2iQTfw8f6PWbx9MYu2L2LNnjXUNtTSMaMjY4rGRKpa+nbr642G0BRlNLkb6S0jhPt5qBeDnb6o/iKSzJdsX0LlyUoABvYayL1fu5dgIDT4Vvt27R2OVJhNkrvwBw/1YrDS6brTrN69OpLQPznwCQAXdLyAYHGoqmV8n/H06tTL4UiF1SS5C39I0b7tWms+q/osksxXlq/kZO1JMtIyuCr/Kn417lcEA0Gu6HWFDL6VYiS5C39Iob7tX576kmU7lkXuCN19NFT11LdbX24beBv
B4iCjC0fTKbOTw5EKJ0mDqhAuV99Qz7p96yJ3hH5Q8QH1up7OmZ0Z12dcpGdL0flFTocqbCANqkJ4WMWxikgXxaU7lnLk1BEUiqG9h/LAiAcIFgf5au+vkpGe4XSowqUkuQvhAidrT7Jq16pI6XxzZWgunIs6X8Q3LvkGwUCQcX3G0b1D9wSfJESIJHchHKC1ZlPlpkjpfNWuVZyuP0379PZcXXB1pO78sh6XSZ9z0SqS3IWwSdXJKpbsODv41r7j+wDo16Mf3x36XYKBICMLRtIho4O1gcgwDSlBkrsQFqmtr+WDvR9ESudl+8rQaM7POr/J4Ft5XfISf5hZmg/TsGtX6DVIgvcZ6S0jhIl2HtkZ6aK4fOdyjp0+RppKOzv4ViBIyUUlzg2+JcM0eJ70lhHCBtVnqlmxc0UkoW87vA2Agi4FTLtsGqWBUsb1GUfXLJfMPyTDNKQMSe5CJKFBN7Dhiw2RO0Lf2/0etQ21dMjowOjC0Xx/2PcJBoJcnHOxOxtCZZiGlCHJXYgEDlQfODv41o4lHDxxEIArel7Bj4b/iNJAKSPyR3hj8K02DNNQW1tLRUUFp06dsjBA0SgrK4vc3FwyMlp3L4MkdyGaOV13mvf2vBdJ6Bu+2ABAjw49mBCYEGkI9eTgW20YpqGiooLOnTtTWFjozqsSH9FaU1VVRUVFBUVFrbvzWJK7SHlaaz6v+jzSRXFl+UpO1J6gXVo7rsq7irlj5zKxeCIDew30x+BbN9/cqp4xp06dksRuE6UUOTk5VFZWtvozJLmLlHT01FGW7VwW6aa462ioHrq4WzHTr5hOsDjImMIxdG7f2eFI3UUSu33auq0luYuUUN9QT9m+skjpfG3F2sjgW2OLxnLfVfcRLA7S5/w+TocqTFJeXs6aNWu46aabANiwYQP79u1j0qRJACxcuJDNmzdz//33t3ldM2bM4Nprr2Xq1Knccccd3HvvvfTr16/Nn9sWktyFb1Ucq4jUmy/dsZTDNYdRKIZcNIT7R9xPMBBkeO5wGXzLp8rLy3n99debJPeysrJIcp88eTKTJ082fb0vvfSS6Z/ZGpLchW/U1NawateqSELfVLkJgAs7Xch1F19HMBBkQmCCDL7lYa+++ipPPPEESikGDBjAa6+91qTUDNCpUyeqq6u5//772bJlCwMHDuTGG2/kueeeo6amhtWrV/PAAw9QU1NDWVkZzz77LDNmzOC8886jrKyML774gscff5ypU6fS0NDA3XffzfLly8nLyyMjI4Pbbrstsq5YRo8ezRNPPEFJSQmdOnXinnvu4c033yQ7O5v//u//pmfPnlRWVnLXXXexO3x/wVNPPcVVV11l6raS5C48q3HwrcZkvmrXKk7VnaJ9entGFoxkxsAZBANB+l/QX+qKTfbDf/4w0ovILAN7DeSpiU/F/fumTZt49NFHWbNmDd27d+fw4cMtft5jjz3GE088wZtvvglAz549I8kc4JVXXmmy/P79+1m9ejVbt25l8uTJTJ06lf/8z/+kvLyczZs3c/DgQS699FJuu+02w9/pxIkTDB8+nLlz5/Kzn/2MF198kTlz5nDPPffwox/9iBEjRrB7926CwSBbtmwx/LlGSHIXnlJ1soqlO5ZG6s73Ht8LwCXdL+E7Q75DMBBkVOEo6wffErZbvnw5N9xwA927h668unXrZurnf+Mb3yAtLY1+/fpx4MABAFavXs0NN9xAWloavXr1YsyYMUl9ZmZmJtdeey0AQ4YMYcmSJQAsXbqUzZs3R5Y7duwY1dXVdOpk3uxZktyFq9U11LG2Ym1knPN1e9eh0XTN6tpk8K38LnKHpZ1aKmHbrV27djQ0NADQ0NDAmTNnWvU57dufvQnNrDG3MjIyIleN6enp1NXVAaE4165dS1ZWlinricUHnXaF35R/Wc7zZc8z5Y0p5Dyew8j/GMkvV/+SNJXGQ6MeYs1ta6j8aSV/vuHP3DH4DknsKWLs2LH8+c9/pqqqCiB
SLVNYWMj69euBUA+Y2tpaADp37szx48cj/7/5ayOuuuoq/vrXv9LQ0MCBAwdYuXKlCd8ESktL+c1vfhN5vWGDuVVcICV34QLVZ6pZWb4y0uf8fw7/DwB55+XxrX7fIlgcZFzROM7PPt/hSIWTLrvsMmbPns2oUaNIT09n0KBBvPLKK9x5551cf/31XHHFFUycOJGOHTsCMGDAANLT07niiiuYMWMG06dP57HHHmPgwIE88MADhtb5zW9+k2XLltGvXz/y8vIYPHgwXbp0afN3eeaZZ/je977HgAEDqKur4+qrr+Z3v/tdmz83mgz5K2zXoBv49MCnkWS+evfqJoNvlfYpJVgc5Cs5X5GGUBfZsmULl156qdNh2K6xLryqqophw4bx3nvv0auXPUNPxNrmMuSvcJUD1QcisxAt2b6EAydCDVYDeg7gh8N/SDAQ9M7gWyKlXHvttXz55ZecOXOGBx980LbE3laS3IUlztSfYc2eNZHS+cdffAxA9w7dKQ2UUtqnlNJAKRd2vtDhSIVomVn17HaT5C5MobVm2+FtkUkrVuxcERl868q8K5k7di7BQJBBFw7yx+BbQricJHfRakdPHWX5zuWRhF7+ZTkAgfMDTL9iOqWBUsYUjeG89uc5G6gQKUiSuzCsvqGe9fvXR+4IfX/P+9TrejpldmJc0Th+euVPCQaCBLoFnA5ViJQnyV20aN/xfZF68yU7lkQG3xp84WDuu+o+SgOlXJl3pQy+JYTLGEruSqmJwNNAOvCS1vqxZn+/F7gDqAMqgdu01jEmahRuV1Nbw7u7343cEfqvg/8CoFenXlx38XWUBkqZ0GcCPTr2cDhSIURLEiZ3pVQ68BwwAagA1imlFmqtN0ct9jFQorU+qZSaBTwO/C8rAhbm0lqz5dCWSOn8nV3vcKruFJnpmYzMH8m3x3+bYHGQyy+4XPqci+QsWNCq6fzMMm/ePB599FEA5syZw/Tp021btxsYKbkPA7ZprXcAKKX+CFwPRJK71npF1PJrgVvMDFKY63DNYZbuWBqpO684VgHI4FvCRAsWNJ2Ie9eu0GuwJcEfPnyYn//855SVlaGUYsiQIUyePJnzz0+du5yNJPfewJ6o1xU1tK3aAAALLklEQVTAV1tY/nbgH7H+oJSaCcwEyM+X8UDsUtdQxwcVH0RGUly3bx0NuoEu7bswvs94Hrr6IYLFQRmjRZhn9uyzib3RyZOh99uQ3NetW8ftt9/Ohx9+SH19PcOGDeONN96gf//+TZZbtGgREyZMiIwcOWHCBP75z39y4403tnrdXmNqg6pS6hagBBgV6+9a6xeAFyA0/ICZ6xZN7fpyV6SL4rIdyzh6+ihpKo1hvYfx4NUPEgwEGdp7KO3SpE1dWCA8CYXh9w0aOnQokydPZs6cOdTU1HDLLbeck9gB9u7dS15eXuR1bm4ue/fubdO6vcbIkb0XyIt6nRt+rwml1HhgNjBKa33anPCEUSfOnOCdXe9E6s4/q/oMCA2+dUO/GygNlDK+z3gZfEvYIz8/VBUT6/02euihhxg6dChZWVk888wzbf48vzKS3NcBfZVSRYSS+jTgpugFlFKDgOeBiVrrg6ZHKc6htQ4NvrX97OBbZ+rPkN0um1GFo7ir5C6CgSCXdL9EGkKF/ebObVrnDtChQ+j9NqqqqqK6upra2lpOnToVGQUyWu/evZsMG1BRUcHo0aPbvG4vMTQqpFJqEvAUoa6QL2ut5yqlHgHKtNYLlVJLgcuB/eH/sltr3eLMszIqZPIOnjjIku1LInXnjYNv9b+gP8FAkGAgyMiCkWS1s24CAJG6kh4V0qLeMpMnT2batGns3LmT/fv3R6bNi3b48GGGDBnCRx99BMDgwYNZv3696bM3Wc3yUSG11m8Dbzd776Go5+ONhSqScab+DO/veT9SOv9of2hHzcnOYUJgQmQWoos6X+RwpELEcPPNpveMefXVV8nIyOCmm26ivr6eK6+8kuXLlzN27Ng
my3Xr1o0HH3yQoUOHAqGqHK8l9raS8dxdZtvhbZF68xXlK6g+U027tHYMzx3OxMBEgsVBBl84WAbfErZL1fHcnSTjuXvYsdPHQoNvhe8I3XFkBwBFXYu45fJbCBYHGVs0VgbfEkIkRZK7zRp0Ax/t/yhSOn+/4n3qGurolNmJMYVjuHf4vQSLgxR3K3Y6VCFcb+PGjdx6661N3mvfvj0ffPCBQxG5hyR3G+w7vi/SELpkxxIOnTwEwKBeg/jJ135CsDjIlXlXkpme6XCkQnjL5Zdfbsnk0n4gyd0Cp+pO8e6udyO39288uBGAnh17ck3xNQQDQSYEJnBBxwscjlQI4VeS3E2gtWbroa2RXi3vlL9DTV0NmemZjMgfwa/H/5rSQCkDeg6QhlAhhC0kubfSkZojLNu5LFJ3vudYaPidi3Mu5o7BdxAMBBldOJqOmefeYCGEEFaT5G5QXUMd6/aui5TOP9z7YWTwrXF9xjHn6jmUBkop7FrodKhCCCHJvSV7ju6JJPOlO5by5akvUSiG9h7K7JGzCQaCfDX3qzL4lhAxODycOxMnTmTt2rWMGDGCN998074Vu4RkpSgnzpxg1a5VkYS+9dBWAHp37s2US6ZEBt/K6ZDjcKRCuJvDw7kD8NOf/pSTJ0/y/PPP27NCl0np5K61ZuPBjZF683d3v8uZ+jNktctiVMEoZg6eSWmglH49+sngW0IkwaLh3A2P5w4wbty4JoOHpZqUS+6VJypZsuPs4FtfVH8BhAbfunvo3QSLg4zMH0l2RrbDkQrhXRYN5254PHeRAsm9tr6W9yvej5TOP9r/ERpNt+xuTOhzdvCt3uf1djpUIXzDwuHcZTx3g3yZ3Lcf3h4pmS/fuZzjZ46TrtIZnjucn4/+OcHiIEMuHEJ6WrrToQrhSxYO525oPHfhk+R+/PRxVpSviJTOtx/ZDkBh10JuuvwmSgOljCsaR5esLg5HKkRqaKxXt6K3zHe+8x1+8YtfsHPnTu67776Y47kLjyb3Bt3Ax/s/jtze/96e96hrqKNjRkfGFI3hnq/eQ7A4SN9ufaUhVAiHWDCcu+Hx3AFGjhzJ1q1bqa6uJjc3l9///vcEg0FzA3Ixz43n/vuPfs8Dyx6g8mQlEBp8qzRQSjAQGnyrfbv2ZocqhEDGc3dCSo3nflHniwgWh6aUm9BnAj079XQ6JCGEcB3PJfdr+l7DNX2vcToMIYQLyHju8XkuuQshRCMZzz0+GX9WCGGYU210qait21qSuxDCkKysLKqqqiTB20BrTVVVFVlZWa3+DKmWEUIYkpubS0VFBZWVlU6HkhKysrLIzc1t9f+X5C6EMCQjI4OioiKnwxAGSbWMEEL4kCR3IYTwIUnuQgjhQ44NP6CUqgRiDApqSHfgkInhmEXiSo7ElTy3xiZxJactcRVorXskWsix5N4WSqkyI2Mr2E3iSo7ElTy3xiZxJceOuKRaRgghfEiSuxBC+JBXk/sLTgcQh8SVHIkreW6NTeJKjuVxebLOXQghRMu8WnIXQgjRAlcnd6XURKXUZ0qpbUqp+2P8/V6l1Gal1KdKqWVKqQKXxHWXUmqjUmqDUmq1UqqfG+KKWu6bSimtlLKlF4GB7TVDKVUZ3l4blFJ3uCGu8DLfCu9jm5RSr7shLqXU/43aVp8rpb50SVz5SqkVSqmPw8fkJJfEVRDOD58qpVYqpVo/YEtycb2slDqolPpXnL8rpdQz4bg/VUoNNjUArbUrH0A6sB3oA2QCnwD9mi0zBugQfj4LeMMlcZ0X9Xwy8E83xBVerjOwClgLlLghLmAG8KwL96++wMfA+eHXF7ghrmbLfx942Q1xEapHnhV+3g8od0lcfwamh5+PBV6zaR+7GhgM/CvO3ycB/wAUMBz4wMz1u7nkPgzYprXeobU+A/wRuD56Aa31Cq31yfDLtYAdZ2QjcR2LetkRsKNhI2FcYb8Afg2csiGmZOKym5G47gSe01ofAdBaH3RJXNFuBP7
gkrg0cF74eRdgn0vi6gcsDz9fEePvltBarwIOt7DI9cCrOmQt0FUpdaFZ63dzcu8N7Il6XRF+L57bCZ0FrWYoLqXU95RS24HHgR+4Ia7wZV+e1votG+IxHFfYN8OXpn9RSuW5JK6LgYuVUu8ppdYqpSa6JC4gVN0AFHE2cTkd18PALUqpCuBtQlcVbojrE2BK+Pm/AZ2VUjk2xJZIsjkuKW5O7oYppW4BSoD/7XQsjbTWz2mtA8B9wByn41FKpQFPAj92OpYY/g4Uaq0HAEuAeQ7H06gdoaqZ0YRKyC8qpbo6GlFT04C/aK3rnQ4k7EbgFa11LqEqh9fC+53TfgKMUkp9DIwC9gJu2WaWccOGj2cvEF2Cyw2/14RSajwwG5istT7tlrii/BH4hqURhSSKqzPQH1iplConVMe30IZG1YTbS2tdFfXbvQQMsTgmQ3ERKkkt1FrXaq13Ap8TSvZOx9VoGvZUyYCxuG4H/gSgtX4fyCI0hoqjcWmt92mtp2itBxHKFWitbWmETiDZXJIcOxoWWtkY0Q7YQeiys7Gh5LJmywwi1JjS12Vx9Y16fh1Q5oa4mi2/EnsaVI1srwujnv8bsNYlcU0E5oWfdyd0CZ3jdFzh5S4Bygnfq+KS7fUPYEb4+aWE6twtjc9gXN2BtPDzucAjdmyz8PoKid+g+nWaNqh+aOq67fqSrdwwkwiVlrYDs8PvPUKolA6wFDgAbAg/FrokrqeBTeGYVrSUZO2Mq9mytiR3g9vrV+Ht9Ul4e13ikrgUoaqszcBGYJob4gq/fhh4zI54kthe/YD3wr/jBqDUJXFNBf4nvMxLQHub4voDsB+oJXQVeDtwF3BX1P71XDjujWYfj3KHqhBC+JCb69yFEEK0kiR3IYTwIUnuQgjhQ5LchRDChyS5CyGED0lyF0IIH5LkLoQQPiTJXQghfOj/A112Rtzml8WtAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the decision boundary before the parameters are updated\n", "w0 = w[0].item()\n", "w1 = w[1].item()\n", "b0 = b.item()\n", "\n", "# The boundary is the line w0 * x + w1 * y + b0 = 0, i.e. y = (-w0 * x - b0) / w1\n", "plot_x = np.arange(0.2, 1, 0.01)\n", "plot_y = (-w0 * plot_x - b0) / w1\n", "\n", "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n", "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n", "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n", "plt.legend(loc='best')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the classification is essentially random at this point. Let's compute the loss, using the formula\n", "\n", "$$\n", "loss = -(y * log(\\hat{y}) + (1 - y) * log(1 - \\hat{y}))\n", "$$" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "# Compute the binary cross-entropy loss, averaged over all data points\n", "def binary_loss(y_pred, y):\n", "    log_likelihood = (y * y_pred.clamp(min=1e-12).log() + (1 - y) * (1 - y_pred).clamp(min=1e-12).log()).mean()\n", "    return -log_likelihood" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note the use of `.clamp` here; see its [documentation](http://pytorch.org/docs/0.3.0/torch.html?highlight=clamp#torch.clamp). Think about whether this function is really needed here, and what would happen without it.\n", "\n", "**Hint: look at the graph of the log function**" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.7911, grad_fn=)\n" ] } ], "source": [ "y_pred = logistic_regression(x_data)\n", "loss = binary_loss(y_pred, y_data)\n", "print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the loss computed, we again update the parameters by gradient descent. Autograd gives us the derivatives with respect to the parameters directly; interested readers can try deriving the gradient formulas by hand." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.7801, grad_fn=)\n" ] } ], "source": [ "# Compute gradients with autograd and take one gradient-descent step (learning rate 0.1)\n", "loss.backward()\n", "w.data = w.data - 0.1 * w.grad.data\n", "b.data = b.data - 0.1 * b.grad.data\n", "\n", "# Compute the loss after one update\n", "y_pred = logistic_regression(x_data)\n", "loss = binary_loss(y_pred, y_data)\n", "print(loss)" ] }, { 
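"cell_type": "markdown", "metadata": {}, "source": [ "The cell below is a minimal, self-contained sketch related to the two exercises above (the `.clamp` question and the hand-derived gradient). It uses its own small random tensors (hypothetical data, not `data.txt`) and new names such as `x_demo`, `w_demo` and `b_demo`, so it does not overwrite the notebook's `x_data`, `w` and `b`, and it draws random numbers from a local generator so the notebook's global seed is left untouched.\n", "\n", "For the mean loss with $\\hat{y} = sigmoid(z)$ and $z = xw + b$, the derivative of the loss with respect to each $z_i$ simplifies to $(\\hat{y}_i - y_i) / N$. The gradients are therefore $x^T(\\hat{y} - y) / N$ for $w$ and $\\sum_i (\\hat{y}_i - y_i) / N$ for $b$. The sketch compares these formulas with autograd, then evaluates the loss on two correct predictions that are exactly 0 and 1, with and without `.clamp`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: check the hand-derived gradient of the BCE loss against autograd (hypothetical random data)\n", "import torch\n", "\n", "g = torch.Generator().manual_seed(0)  # local generator, leaves the global seed untouched\n", "x_demo = torch.rand(8, 2, generator=g)                  # 8 hypothetical points with 2 features\n", "y_demo = (torch.rand(8, 1, generator=g) > 0.5).float()  # hypothetical 0/1 labels\n", "w_demo = torch.randn(2, 1, generator=g, requires_grad=True)\n", "b_demo = torch.zeros(1, requires_grad=True)\n", "\n", "y_hat = torch.sigmoid(torch.mm(x_demo, w_demo) + b_demo)\n", "loss_demo = -(y_demo * y_hat.log() + (1 - y_demo) * (1 - y_hat).log()).mean()\n", "loss_demo.backward()\n", "\n", "# Hand-derived gradients of the mean loss: dloss/dz_i = (y_hat_i - y_i) / N\n", "n = x_demo.shape[0]\n", "err = (y_hat - y_demo).detach()\n", "grad_w = torch.mm(x_demo.t(), err) / n\n", "grad_b = err.sum(0) / n\n", "print(torch.allclose(grad_w, w_demo.grad, atol=1e-6))  # True if the formula matches autograd\n", "print(torch.allclose(grad_b, b_demo.grad, atol=1e-6))\n", "\n", "# Two correct predictions exactly at 0 and 1: without clamp, 0 * log(0) = 0 * (-inf) = nan\n", "p = torch.tensor([[0.0], [1.0]])\n", "t = torch.tensor([[0.0], [1.0]])\n", "print(-(t * p.log() + (1 - t) * (1 - p).log()).mean())  # nan\n", "print(-(t * p.clamp(min=1e-12).log() + (1 - t) * (1 - p).clamp(min=1e-12).log()).mean())  # zero loss, no nan" ] }, { 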
"cell_type": "markdown", "metadata": {}, "source": [ "Updating the parameters this way is tedious and repetitive: with many parameters, say 100, we would need 100 lines of update code. We could write a function to do the updates, and in fact PyTorch already provides this: the optimizers in `torch.optim`.\n", "\n", "`torch.optim` works with another data type, `nn.Parameter`. It is essentially the same as a Variable, except that an `nn.Parameter` requires gradients by default, whereas a Variable does not.\n", "\n", "`torch.optim.SGD` updates the parameters by gradient descent; PyTorch's optimizers also implement many other optimization algorithms, which we will cover in more detail in later lessons of this chapter.\n", "\n", "After passing the parameters w and b to `torch.optim.SGD` and specifying the learning rate, we can call `optimizer.step()` to update them. Below, for example, we pass the parameters to the optimizer with a learning rate of 1.0." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# Update the parameters with torch.optim\n", "from torch import nn\n", "w = nn.Parameter(torch.randn(2, 1))\n", "b = nn.Parameter(torch.zeros(1))\n", "\n", "def logistic_regression(x):\n", "    return torch.sigmoid(torch.mm(x, w) + b)\n", "\n", "optimizer = torch.optim.SGD([w, b], lr=1.)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:15: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n", "  from ipykernel import kernelapp as app\n", "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:17: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch: 200, Loss: 0.39010, Acc: 0.00000\n", "epoch: 400, Loss: 0.32184, Acc: 0.00000\n", "epoch: 600, Loss: 0.28917, Acc: 0.00000\n", "epoch: 800, Loss: 0.26983, Acc: 0.00000\n", "epoch: 1000, Loss: 0.25700, Acc: 0.00000\n", "\n", "During Time: 0.248 s\n" ] } ], "source": [ "# Run 1000 updates\n", "import time\n", "\n", "start = time.time()\n", "for e in range(1000):\n", "    # Forward pass\n", "    y_pred = logistic_regression(x_data)\n", "    loss = binary_loss(y_pred, y_data) # compute the loss\n", "    # Backward pass\n", "    optimizer.zero_grad() # reset the gradients to 0 with the optimizer\n", "    loss.backward()\n", "    optimizer.step() # update the parameters with the optimizer\n", "    # Compute the accuracy: predict class 1 when the probability is >= 0.5\n", "    mask = y_pred.ge(0.5).float()\n", "    acc = (mask == y_data).float().mean().item()\n", "    if (e + 1) % 200 == 0:\n", "        print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e+1, loss.item(), acc))\n", "during = time.time() - start\n", "print()\n", "print('During Time: {:.3f} s'.format(during))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the optimizer makes parameter updates very simple: call **`optimizer.zero_grad()`** before backpropagation to reset the gradients to 0, then call **`optimizer.step()`** to update the parameters.\n", "\n", "After 1000 updates, the loss has also dropped to about 0.26." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's plot the result after the updates." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3XlcFdX7B/DPAVncFzAtUSCzEhRQyTVzSUvL8KtpZViaKYqmv+qbZdqCW5nZYinuu5hrX9Pccsl9SdwVy9TcNQhMJUWW+/z+uGCIXJgLM3Nm5j7v1+u+uMvcmecOM8+cOXPmHEFEYIwxZi1usgNgjDGmPk7ujDFmQZzcGWPMgji5M8aYBXFyZ4wxC+LkzhhjFsTJnTHGLIiTO2OMWRAnd8YYs6ASshbs6+tLAQEBshbPGGOmtG/fvr+IqHJh00lL7gEBAYiPj5e1eMYYMyUhxFkl03G1DGOMWRAnd8YYsyBO7owxZkGc3BljzII4uTPGmAVxcmeMMQvi5M4YYxZkuuT+61+/4oNNHyAtM012KIwxZlimS+4rf1uJ0dtGo/6U+th9YbfscBhjzJBMl9wHNxuMNZFrkJqeimYzm+Gdn97BzYybssNijDFDMV1yB4B2D7XD0f5H0ad+H3yx6wuETg7F1rNbZYfFGGOGYcrkDgDlvMphcofJ2PjqRmTZstBidgsMXD0QqempskNjjDHpTJvcc7QObI3D0YcxqOEgTNw7EXUn1cXG0xtlh8UYY1KZPrkDQBnPMhjffjy2vrYVnu6eaDOvDfqu7Itraddkh8YYY1JYIrnneLzG4zjY9yAGNx2M6Qemo86kOljz+xrZYTHGmO4sldwBoKRHSYxtOxa7Xt+Fcl7l8MyCZ9BjeQ+k3EqRHRpjjOnGcsk9R8NqDbE/aj+GNR+GuMNxCI4NxvJfl8sOizHGdGHZ5A4AXiW8MKr1KOztsxdVSldBp0Wd0G1ZNyT9kyQ7NMYY05Slk3uOevfXw94+ezGy1UgsS1iGoNggLDq6CEQkOzTGGNOESyR3APBw98AHT3yA/X33I7BCIF5a9hKeX/w8rqRekR0aY4yprtDkLoSYKYRIFEIcdfC5EEJ8I4Q4KYQ4LISor36Y6qlzXx3sfH0nPmvzGVb/vhpBE4Mw99BceaX4uDggIABwc7P/jYuTEwdjzFKUlNxnA2hXwOftAdTKfkQBmFT8sLRVwq0E3m32Lg71O4TalWujx/Ie6PBdB1y4fkHfQOLigKgo4OxZgMj+NyqKEzxjrNgKTe5EtBVAQe0IOwKYS3a7AVQQQtyvVoBaesT3EWztuRXj243H5jObERwbjGn7pulXih82DLiZp9Ozmzft7zPGWDGoUedeDcD5XK8vZL93DyFElBAiXggRn5RkjBYr7m7uGNRoEA73O4z699dH1I9RaDuvLf64+of2Cz93zrn3mcvhWjtWVLpeUCWiqUQUTkThlStX1nPRhapZqSY2vroRk56dhD0X96DupLqY8MsE2Mim3UJr1HDu/SLiBGFOXGvHikON5H4RQPVcr/2y3zMdN+GGfuH9cKz/MTxe43EMXDMQLWe3xO/Jv2uzwNGjgVKl7n6vVCn7+yrhBKEuPQ+UXGvHioWICn0ACABw1MFnzwJYA0AAaAzgFyXzbNCgARmZzWajWQdmUflPy5P3KG8at2McZWZlqr+g+fOJ/P2JhLD/nT9f1dn7+xPZ0/rdD39/VRdjWbn/PT4+RJ6ed6/HUqVU/5fdIUT+/zshtFleYTTeVJlCAOJJSd4udALgOwCXAWTAXp/+OoB+APplfy4ATARwCsARAOFKFmz05J7j4vWL9NyC5wgxoEbTGtGxxGOyQ3KK0RKEmcyfb0/e+a0/PQ6Uzh6YtUy++a0LLQ9szDHVkrtWD7MkdyJ7KT7ucBxV+qwSeY70pE+2fkIZWRnOz0hC0YdL7kXnaN3pdaB0JqFqnXx5OzIOTu4auHLjCnVZ3IU
QA6o/pT4dunJI+ZclFX24xFV0js569ExwSssDWidfPgM0Dk7uGlpybAnd9/l9VGJECfr454/pdubtwr8ksejDdaVFo6TkbpQDpdbJl0vuxqE0ubtM3zJq6hLUBQn9E/Bi8IsYvmU4wqeGY9+lfQV/SWKb9shI4MwZwGaz/42M1HyRlpBfYyYPD8DHBxAC8PcHpk41xvrUulWtDg27mMo4uReRTykfzO88HyteWoHkW8loNL0Rhm4cirTMtPy/oFObdqaeyEh78vb3/zeZz5oF/PWX8Q6UWiff/NaFUQ5szAElxXstHmaulsnr6q2r1Gt5L0IM6NEJj9LOczvvnYgrv5nGuPrNNYCrZfRTwbsCZnScgXXd1+Fmxk00m9kMb697Gzczct2BwkUfphJHN1Jx9RvLTdgPBPoLDw+n+Ph4KcvW0o3bN/DehvcwKX4SalasiRkRM9AioIXssJhF5NxxnPvO1VKluJzgSoQQ+4govLDpuOSusrJeZRH7bCx+7vEzCISWc1piwKoBuHH7huzQmAUYsUsC7rvImDi5a6RlQEsc7ncYbzZ6E5PiJ6HupLpYf2q97LCYyRmtI1Huu8i4OLlrqLRnaXzV7its77Ud3iW88dT8p9BnRR9cS7smOzRmUkZrdGXEMwlmx8ldB02rN8WBvgfwXrP3MPPgTATHBmPViVWyw2ImZLT25kY7k2D/4uSuk5IeJTGmzRjsfn03KpasiA7fdcAr/3sFKbcKGuSKsbup2ehKjbrySpWce5/ph1vLSHA78zY+2fYJPtn+CXxK+iD22Vh0rt1ZdljMhajV6sbXF0hOvvd9Hx/7zV5MfdxaxsC8SnhheKvh2NtnLx4o+wCeX/w8Xlz6IhL/SZQdGrO4nNJ69+7q1JWnODjxdPQ+0w8nd4nCqoZhT+89GN16NJb/uhzBscFYeHQhZJ1NmQU3vSua3C1bHHG2rtxoF3jZvzi5S+bh7oGhzYdif9R+PFjxQXRb1g2dFnXCpRuXZIdmSNz0rujya9mSl7NJ2WgXeAvjUgUDJX0UaPGwUt8yasnMyqRxO8aR9yhvqjCmAs06MItsNpvssBTTo28T7nq26Arrn76oXR2ZpU8bq3TvBIV9y/AFVQM6kXwCvVf0xrZz29DuoXaY0mEKapQ39nmuXrfFu7nZd8u8hLD3qcIcCwhwXCXj728vbVu5CwNHv9/f394Xj1nwBVUTe9jnYWzuuRnftv8W285uQ53YOpgSP8XQdfF63czCdbxF56gKJTra/vyVV6xXVZG7GsbRgc2ybfKVFO+1eHC1jDKnU05T6zmtCTGg1nNa06mUU6rOX61Tar2GYbPKqbUsef/f0dHWXZ+yBzjXCniYPeuw2Ww0NX4qlf2kLJUaXYrG7x5PWbasYs9XzUSpZ124Wep4zcDK1zBkDJOox7bJyd2Czv19jtrPb0+IATWb0Yx+++u3Ys1PzR3bSiVqVzp4WHng64IuIGvxv9VrH+DkblE2m41mH5hNFcZUIO9R3jR2+1jKzMos0rzU3rGtkBStdJAqzPz5RO7urldy1+q36bU8Tu4Wd+n6Jer4XUdCDKjhtIZ09M+jTs/DyqfkRSV7neh1gCyoPtoqBzO9D9R6nQVxcncBNpuNFh5ZSL5jfclzpCeN2jKK0jPTFX/flUqpSsmsptD6/5H7wOGoxO7ubq3/v55nk1xy5+SuusTURHpxyYuEGFC9yfXowOUDir9rhaoUNcksuWu5bKUtR6xQ1y4L17lzctfM9wnfU5XPq1CJESXow00fUlpGmuyQTEfm2YyWZw1KWo64epWcGozUWoZvYrKQTrU7IWFAAl6u+zJGbh2JBlMbYO/FvbLDMhU1+0t3lpY3aCm5UcfIfcKYRWSk/W5Xm83+V+Ydv5zcLaZSyUqY8585WPXyKvyd9jcaz2iM99a/h1sZt2SHZhqydlAtO+FydIBwd9f/IMb0wcndop6p9QyO9T+GXmG9MHbnWIRNCcOOcztkh8UKoOV
Zg6MDx5w5xihlMvVxcrew8t7lMS1iGta/sh63M2+j+azmeHPtm/gn/R/ZoTEHtDprkFndxOTgXiFdRGp6KoZsGIKJeyfiwYoPYvpz09EqsJXssBhjTuJeIdldyniWwYRnJmBzj80QEGg9tzWif4zG9dvXZYfGGNMAJ3cX0yKgBQ5HH8bbjd/GlH1TUCe2DtadXCc7LMaYyji5u6BSHqXwxdNfYEevHSjtWRrt4tqh1w+9cPXWVdmhmZ5LDePGDI2TuwtrUr0JDvQ9gPcffx9zD81FcGwwVv62UnZYpsXjuzIj4eTu4rxLeOOTJz/Bnt574FvKFxELIxD5fSSSbybLDs109BqNijElFCV3IUQ7IcRvQoiTQogh+XxeQwjxsxDigBDisBDiGfVDZVpq8EADxEfFI6ZFDBYfW4yg2CAsTVgqOyxTcXQXqGWHcWOGVmhyF0K4A5gIoD2AIADdhBBBeSb7AMBiIqoH4CUAsWoHyjSSq5LYs+bD+PjiQ9gXtQ9+5fzQdUlXdF3SFX+m/ik7SlPg8V2ZkSgpuTcEcJKIThNROoCFADrmmYYAlMt+Xh7AJfVCZJpxUEkcsuEI9vTeg0+f/BQrfluB4NhgLDiyALLuiTALLbsPKAhfxGX5UZLcqwE4n+v1hez3cosB0F0IcQHAagADVYnO1ei9lxZQSVzCrQSGPD4EB/seRC2fWoj8PhIdF3bExesXtY3J5EqW/Pe5j4/2d4HyRVzmiFoXVLsBmE1EfgCeATBPCHHPvIUQUUKIeCFEfFJSkkqLtggZe6mCSuLalWtj+2vb8eVTX2LD6Q0Ijg3GrAOzuBSfR86/LznXdehbOvTVxhdxmSNKkvtFANVzvfbLfi+31wEsBgAi2gXAG4Bv3hkR0VQiCiei8MqVKxctYquSsZcqrCR2d3PHW03ewuHowwitGopeK3qhXVw7nP37rHaxmYysJMsXcbVhhaouJcl9L4BaQohAIYQn7BdMV+SZ5hyAJwFACFEb9uTORXNnyNhLnawkfqjSQ/i5x8+Y+MxE7Di3A3Um1cHk+MmwkU27GE1CVpLli7jqs0xVl5IRPWCvajkB4BSAYdnvjQAQkf08CMAOAIcAHATwVGHz5JGY8pA1vlsRh4754+of1HZuW0IMqOXslnQy+aSmYRqdzH+f0cfBNdtQjrIHSS8MeJg9k1F7L9Vhj7LZbDR933Qq92k5KjW6FH2962vKzMpUfTlaUXMVyUyyRk6eZjj45CVzkHQlOLmbkVp7qc571Plr5+mZuGcIMaCmM5rS8aTjToUqIzFpsYqMnGRlMXopOD9Gj5mTu5monRUkbJ02m43mHZpHFcdUJK+RXjRm2xjKyMoo8DsyS3VG34GNytlN1eil4PwY/WyDk7tZaLElSdyjLt+4TJ0XdSbEgMKnhtORP484nFZmgjVj0pGtKJuqWQ+iRj4L4+RuFlps/ZL3KJvNRouPLqbKYyuTxwgPGrF5BKVnpt8zncwEa9akI5Ojdebu7jgJGr0UbEac3M1CiwxnkD0qMTWRui3tRogBhU4KpX2X9t31ucwEa5BVZCqONtXC1qGRS8FmxMndLLTKcAbao5YfX05Vx1Ul9+HuNHTDUErLSLsToswEa6BVZAqONlU++9EXJ3ezKCzDWSQDpdxMoR7/60GIAQVNDKLd53cTkXF/nlHjkim/TZWvW+iPk7uZOMoksou2Glh9YjX5felHbsPd6J1179DN9JuyQ7pHfqs9p0rC1RN97k3V3Z1L7jJwcrcCi171u5Z2jfqu7EuIAdX6phZtO7tNdkh3Kaz6weTHV9VYsOyhOTXOCJUmdx5mz8gs2itUOa9ymNxhMja+uhGZtkw8MesJDFozCKnpqbJDA1D46uVeF+0iI+1dGvv7A0LY/2rdxbGZ6d1nDSd3I1OrVyiDdnHXOrA1DkcfxhsN38C3v3yLkEkh2PTHJtlhKVq9Jj++qiYyEjhzBrDZ7H85sTumd8+hnNyNTI2hfQzexV0
ZzzL4pv032NpzK0q4lcCTc59E35V9cf32dWkx5bfa8+JeF5mz9D4R5+RuZGqc95pkNIfm/s1xsN9BvNPkHUw/MB3BscFYe3KtlFhyr3bAvupz02PoPJkMeqJnerp3z6ykYl6LB19Q1YkJ77PffX43BU0MIsSAei7vSSk3U6TGk/ciWHS0dZtJ8kVS7ai1bsGtZRgRmbbFzey56VTuvhQCssitwjl6a+xe2SERkXmTn9JWGibdXExzX4KerWU4uVudCbNRvjfLeKRS40HfUGJqotTYzJj8nNkETHiiZ8ZNvFg4ubN/maVYk81hO/PyZ6jy2Mq06OgistlsUmIzY/Jz5oBUUBt/o246RjngFrSbqbkLcnJnpuU4gdoofGo4IQbUaWEnunzjsu6xGSWROMOZA1JhXQwYsURshANuQWcPap9ZcHJnplVQAs3IyqDPtn9GXiO9qOKYijT34FxdS/FmrAJw9oCUU8osqARvJEY44BYUg9rxcXJnpqUkgf6a9Cs1ndGUEAN6Nu5ZOn/tvK7xmaiWq8gHJCOUiJUwwgG3oHWl9nrk5M5MTUkCzczKpK93fU0lR5Wkcp+Wo+n7pkurize6ohyQjFAiVkr2AZdL7pzczUn2nlOIk8knqeXsloQYUJu5beiPq3/IDkkxI69aI5SIzYLr3Dm5m49J9vAsWxZN2juJynxShsp8UoYm/jKRsmxZssMqkBlWrZEPPkZjtNYywj6t/sLDwyk+Pl7KspkTAgLs/dHk5e9v7ynKYM5dO4c+K/vgp1M/4Qn/JzAjYgYeqvSQ7LDyZbJVaxpxcfbeNc6ds9/aP3q0tTo0E0LsI6LwwqbjvmVYwUzW7XCN8jWwNnItZkbMxKErhxAyKQRf7voSWbYs2aHdw2Sr1hTU7CfP7H3scHK3Ei22Rt17Oyqa3D89MFDAM+E1JAxIQJsH2+C/P/0Xj896HMeTjssO8y4mWbWmolY/eQbvTFUZJXU3Wjy4zl1lWlXgmqBiuKAQbTYbxR2Oo0qfVSKvkV706bZPKSMrQ3bIRGSKVWs6ajU7NHJLIfAFVRej5dZo8KtqSn76lRtXqMviLoQYUIMpDejQlUOywr2LwVet6ai1Gxi5jb/S5M4XVK3Czc2+/eUlhH2YHAtz5qcvTViK/qv642raVQxrPgxDmw+Fp7unPoEyzeVUp+SumilVyvlhEIx8sZsvqLoaq1XgOnH9wJmf3iWoCxIGJOCF4BcwfMtwhE8Nx75L+1QJmcmn1riuagyCJp2S4r0WD66WUZmVKnCd/C1F/ek//PoD3T/ufnIf7k7vb3ifbmXc0uDHMLMyapUZuM7dBRl1a3RWESpOi/rTr966Sq8tf40QA3p0wqO089xOFX4AY9pRmty5zp0Zj4TrB+tOrkOflX1w4foFvNX4LYxsPRKlPAoZJZsxCbjOnZmXhOsHTz/0NI72P4p+4f3w5e4vETIpBFvObNFseYxpjZM7Mx5JV7PKeZVD7LOx2PTqJhAILee0xIBVA5Canqr6ssx+9yNzjoz/Nyd3Zjw5TR58fP59r2RJ3RbfKrAVDvc7jP9r9H+YFD8JdWLrYMPpDarN3xJ3PzLFZP2/ObkbCRfn7nbr1r/Pk5N1zYClPUvj63ZfY9tr2+BVwgtt57VFnxV9cC3tWrHnrdYt8swcZP2/ObkbhasU55QewAySAZvVaIaDfQ/ivWbvYebBmQiODcaqE6uKNU/uMMy1yPp/K0ruQoh2QojfhBAnhRBDHEzzghAiQQhxTAixQN0wXYBBkpmmnDmAqbhHFPeEqKRHSYxpMwa7X9+NiiUrosN3HfDq/15Fyq0Up2MBrHe/GSuYtP93YW0lAbgDOAXgQQCeAA4BCMozTS0ABwBUzH59X2Hz5XbueRi5Mwu1ONN+XaVOQtS+tystI40+2vQRlRhRgqp8XoW+T/je6XlY6X4zVjjDjsQEoAmAdblevw/g/TzTjAXQW8kCcx6c3PMwcjd0anHmAKbSHqHVaj1
w+QDVm1yPEAN6YckLlJia6NT3rXK/GVNGxkhMSpJ7FwDTc71+BcCEPNMsz07wOwDsBtDOwbyiAMQDiK9Ro0bRf50VuUJxztlMq8IeoeUJUXpmOo3aMoo8R3qS71hf+u7IdzxAN9Oc0uSu1gXVEtlVMy0BdAMwTQhRIe9ERDSViMKJKLxy5coqLdoi1OrxyMicbb8eGWnvgs9ms/8twrrQsr7Tw90Dw54Yhv1R+/FgxQfRbVk3dFrUCZdvXC7+zBkrJiXJ/SKA6rle+2W/l9sFACuIKIOI/gBwAvZkz5yhQjIzNAkHMD3uhwq+Lxg7e+3E520/x7pT6xAUG4Q5B+fknK0yJoWS5L4XQC0hRKAQwhPASwBW5JlmOeyldgghfAE8DOC0inEyq9D5AKbX8cTdzR3vNH0Hh/odQp376qDnDz3x7IJncf7aeXUX5AL4dg91FJrciSgTwBsA1gE4DmAxER0TQowQQkRkT7YOQLIQIgHAzwAGE1GyVkEzSUy61+l5PHnY52Fs6bkF37T7BlvObkFwbDCm7pvKpfh85Lc5ucrtHnrgXiGZMmoNceNCTl89jT4r+2DTH5vQOrA1pj83HYEVA2WHZQiONqeSJe03I+dlhBGQjEJpr5Cc3JkyRh53zMCICNP2T8M7P72DLMrCmCfHYEDDAXATrn1zuKPNyREXGC1SMe7yl6mL75kvEiEEohpE4Vj/Y3jC/wkMWjsILWa3wInkE7JDk8rZzYbv3nUeJ3emDN8zXyzVy1fH6pdXY3bH2TiaeBShk0Mxbuc4ZNmyZIcmhaPNxsfHAmOXGgQnd6aMJUYMlksIgR5hPZDQPwFP13wag9cPRtOZTZGQlCA7NN052pzGj7f+7R66UXKnkxYP7n7AhPieedXYbDb67sh35POZD3mO9KTRW0dTema67LB0xZtT0YDHUGXM+BL/ScTANQOx+Nhi1KtaD7M6zkJo1VDZYTED4wuqjJnAfaXvw6Iui7DshWW4dOMSwqeF4+OfP0Z6Vrrs0JjJcXJnzAA61+6MY/2P4aU6L2HE1hGo2ftD3O9322z3izED4eTOzEvGHbMaLtOnlA/mdZqH/5bZh4txH+PKRS++S5MVGSd3pg+1k6KM+9R1WubSb+uDMu5uSmK1QbmY9viCKtOeFl0XyLhjVqdlurnZjx33EITUtJso7VlatWUx8+ELqsw4tBgfVsYdszot0+F9YeXOImRyCDaf2azq8pg1cXJn2tMiKcq4Y7ZSJV2W6egGn2ExNyEg0GpOK/Rf1R83bt9QdbnMWji5M+1pkYj1vmM2Lg64fv3e9z09VV+moz7oR70ZhMPRh/F247cxOX4y6kyqg59O/aTqspmFKLnTSYuH5e5Q5dvtHHM0Pmx0dPHWmZ7r3NH4rz4+2i2zADvP7aRHJzxKiAH1Wt6Lrt66KiUOpj+oNUC2Vg9LJXdXGNy6uPIm4uhoc60zLUfaLqJbGbdoyPoh5D7cnR744gFa+dtKabEw/ShN7txaRg3c17nzHK0zHx/gr790D6dQBv4fx1+KR68feuFI4hFE1o3E+Hbj4VPKR2pMTDvcWkZP3Ne58xytm+RkY96tY+BeMcMfCEd8VDxiWsRg0bFFCIoNwrKEZbLDcsikozWaDid3NXBf584raN0Y8W4dvUbaLiJPd0983PJj7IvaB79yfuiypAu6LumKxH8SZYd2Fx4jVT+c3NVg4FKdYRW0box6xqPnSNtFFFIlBHt678GnT36KFb+tQNDEICw4sgDOVL9qWbLW4pYH5oCSinktHpa6oEqkfssNV2h94+OT/0VKd3dr/26dJCQmUOPpjQkxoOcWPEcXrl0o9Dtatw0w4HVp0wG3ljExV2l9k9/vzPuw4u/Oj0YH88ysTPpi5xfkPcqbyn9anmbun0nz59scLspRi09/f1XC0Xz+roCTu5m50h6QO6m5u7vO785Nh4P5ib9O0BOzniB07kZunrccLkrrkrWrlFu0xMndzFzx3HX+fMe
ldyv/biLdDuZZtiyqWPVagYvSIxSz1TgaLV6lyZ0vqBqRq7W+yWlC4YhVf3cOnZrSugk3/P1nuQIXpUfbAD2uS6t1UdjUrXuUHAG0eHDJvQCudu7qqLho9d+dQ8dqOEeLqlj1GmVmZRKR8yVVo5Vs1dx9jFhDCq6WMTmj7TFaclQNBVj7d+fQ8WCe36LcPG8ROnejpjOa0q9Jvxo1dMXUTMhGrCFVmty5WkZNxTkXzPtdwPBtqlXjqNrF39/avzuHjjdI5beouTO9MPej9jiedByhk0MxdsdYZNoyFc3PiO3W1azlMnUNqZIjgBYPy5Xci1OEMWLxR0+u/vsN4tL1S/Sfhf8hxIAem/oYHfnzSKHfUaNkq/ZJqpoldyNumuBqGZ0p3aLy25KNWLGnN1eqhjIwm81GC48sJN+xvuQxwoNGbhlJ6ZnpDqcv7qarRfJ0dPuEj0/R5mu0TZOTu96UFGEcbcmO6put3gSQGS9zZEtMTaQXl7xIiAGFTQ6j/Zf25zudouRcwG/Uqlwzf37+N0DLLnWrgZO73pRspY6mcdWbd4xMj6RrxHP+PP53/H9UdVxVch/uTsM2DqO0jLR7pilwVRXyG7W8YGnVE2JO7npTsqMW1CrE4Du5S9Er6Zok+yTfTKYe/+tBiAEFTQyiPRf2KP9yIb9Ry1VgxJYuauDkLkNhpb2CtmSDnp67JL2SrozsU4ztbNWJVeT3pR+5DXejwT8NppvpNwv/UiG/UcvjqEmOnU7j5G5EJjgNZ6RPBysF3bilVfZRYfv7+9bf1GdFH0IM6OFvH6btZ7cX/AUFGVarco1VdzdO7kbFJXTj07LIV1hPmFpmHxV/1/pT6yng6wASMYIGrR5EqbdT859w/nwiT8+7l+fpqdt2b8XdjZM7Y0Ulo64gd/WcVlQ+I7lx+wYNXD2QEAMK/DqQNp7eeO9E8+cTeXjcvTwPD2tkWUmUJndFd6gKIdoJIX4TQpwUQgwpYLrnhRAkhCh08FbGDEvLO0Yd3SYphPZ3Iqt8u2UZzzL4pv032NpzK9zd3PHk3CfR78d+uH77+r8TDRsGZGTc/cWMDB56SQeFJnchhDuAiQDaAwgC0E0IEZTPdGUB/B+APWqCXukgAAATfUlEQVQHyZjTitstoFZdF8q8n12jLh+b+zfHoX6H8N8m/8W0/dNQJ7YO1p5ca/+QB4+XpoSCaRoCOElEpwFACLEQQEcACXmmGwngMwCDixpMRkYGLly4gLS0tKLOgjnB29sbfn5+8PDwkB2KunL6ac3p9CSnn1ZAfl81o0ffHRug33i7Ob992DB7cq1Rw75cFdZJKY9SGPfUOHQJ6oJeP/RC+7j26BnWE9P9qsH9/IV7v2C0zlni4jRZL1IVVm8DoAuA6blevwJgQp5p6gNYlv18M4DwwuabX5376dOnKSkpiWw2mwY1VSw3m81GSUlJdPr0admhqM/obeCseJUvl7SMNBq6YSi5D3en6G7lKcPbS7+LxkVhsmY10KtXSCGEG4AvAfxXwbRRQoh4IUR8UlLSPZ+npaXBx8cHQojihsUKIYSAj4+PNc+SjF4VoMdoFRJ5lfDC6CdH45c+v2DHE/549ZnbSPItBdK4x8sic7ZrS7VGAtGYkuR+EUD1XK/9st/LURZAHQCbhRBnADQGsCK/i6pENJWIwokovHLlyvkujBO7fiy7rk3dT6t11L+/Pvb22Yvag0ag2qAMVPnMF0tWfQ56+WXZod3NmcKAiYZmUpLc9wKoJYQIFEJ4AngJwIqcD4noGhH5ElEAEQUA2A0ggojiNYnYQM6cOYMFCxbceX3w4EGsXr36zusVK1ZgzJgxqiyrZ8+eWLp0KQCgd+/eSEjIe8mD3aHHWHFMEU93T3zY4kPsi9oH/wr+eGHpC+iypAuupF6RHdq/nCkMGLEDewcKTe5ElAngDQDrABwHsJiIjgkhRgghIrQO0MgKS+4REREYMsRhy9Eimz59OoKC7mmwxHLoOPgFU6Z
ulbrY9foujHlyDFadWIXg2GDMPzw/55qdXM4UBoxe5Zebkop5LR75XVBNSEgo9sWG4pozZw7VrVuXQkJCqHv37kRE1KNHD1qyZMmdaUqXLk1ERI0aNaJy5cpRaGgojRkzhqpXr06+vr4UGhpKCxcupFmzZtGAAQPuzGPgwIHUpEkTCgwMvDO/rKwsio6OpkceeYTatGlD7du3v2tZOXLH0KJFC9q7d++dWIYOHUohISHUqFEjunLlChERJSYmUufOnSk8PJzCw8Np+/b8bxM3wjpnruV40nFqMr0JIQbUYUEHunDtguyQlF/kNsDFeii8oKqkKaQUb659EwevHFR1nmFVw/B1u68dfn7s2DGMGjUKO3fuhK+vL1JSUgqc35gxYzBu3Dj8+OOPAIAqVaogPj4eEyZMAADMnj37rukvX76M7du349dff0VERAS6dOmC77//HmfOnEFCQgISExNRu3Zt9OrVS/Fv+ueff9C4cWOMHj0a7777LqZNm4YPPvgA//d//4e33noLjz/+OM6dO4enn34ax48fVzxfxrTyqO+j2PbaNnz7y7cYunEogmKD8OVTX6JXvV7yrgNFRio7s5PZlNVJPIZqLps2bULXrl3h6+sLAKhUqZKq8//Pf/4DNzc3BAUF4c8//wQAbN++HV27doWbmxuqVq2KVq1aOTVPT09PdOjQAQDQoEEDnDlzBgCwYcMGvPHGGwgLC0NERASuX7+O1NRUVX+P4ZikFQMD3N3c8WbjN3Ek+gjqVa2H3it74+n5T+Ps32dlh1YwE1X5GbbkXlAJW28lSpSAzWYDANhsNqSnpxdpPl5eXneek0p1jR4eHndKO+7u7sjMtA9sbLPZsHv3bnh7e6uyHMMz8o1LzKGalWpiU49NmBI/Be9ueBd1JtXBZ20+Q7/wfnATBi17Ki3lS2bQtSdH69atsWTJEiQnJwPAnWqZgIAA7Nu3D4C9BUxGdl8ZZcuWxY0bN+58P+9rJZo1a4Zly5bBZrPhzz//xObNm1X4JcBTTz2Fb7/99s7rgwfVreIyHBO1YmB3cxNuiH4sGkejj6KJXxMMWD0Aree0xsmUk7JDMzVO7rkEBwdj2LBhaNGiBUJDQ/H2228DAPr06YMtW7YgNDQUu3btQunSpQEAISEhcHd3R2hoKL766iu0atUKCQkJCAsLw6JFixQt8/nnn4efnx+CgoLQvXt31K9fH+XLly/2b/nmm28QHx+PkJAQBAUFYfLkycWep6GZqRUDy5d/BX+s674OMyJm4OCVgwiZFIKvdn2FLFuW7NBMSahVPeCs8PBwio+/uyn88ePHUbt2bSnxyJSamooyZcogOTkZDRs2xI4dO1C1alVdlm2ZdR4QYK+Kycvf334XKDOVi9cvou+PfbHq91Vo4tcEMzvOxKO+j8oOyxCEEPuIqNCed7nkbgAdOnRAWFgYmjdvjg8//FC3xG4pfOOSpVQrVw0ru63E/E7z8VvybwibHIYx28cg05YpOzTTMOwFVVeiVj27S9Owx0MmhxACkSGRaPNgGwxYPQDvb3wfSxOWYmbHmQipEiI7PMPjkjuzDot3yOWqqpSpgqUvLMXiLotx7to5hE8Nx/DNw5GeVbRWa66CkztjzBS6BndFwoAEdA3uipgtMXhs2mPYf3m/7LAMi5M7Y8w0fEv5Iq5zHH546Qck/ZOEhtMaYujGoUjLtGDX1cXEyZ0xZjoRj0TgWP9jeDX0VXy6/VPUn1Ifuy/slh2WoXByZ4yZUsWSFTGz40ysjVyL1PRUNJvZDO/89A5uZtws/MsuwNzJXWJfInPmzEGtWrVQq1YtzJkzR7flMlZsFuuD5+mHnsbR/kcRVT8KX+z6AqGTQ7H17FbZYcmnpOtILR7F7vJX4riHycnJFBgYSMnJyZSSkkKBgYGUkpKi+XK1wF3+uhiTjRfqrI2nN1Lg14GEGNCAVQPoxu0bskNSHfQaQ1UaDfoS2bt3L0JCQpCWloZ//vkHwcHBOHr
06D3TrVu3Dm3btkWlSpVQsWJFtG3bFmvXri3ychnTjcX74Gkd2BpHoo9gUMNBiN0bi7qT6mLD6Q2yw5LCvMldg75EHnvsMUREROCDDz7Au+++i+7du6NOnTr3THfx4kVUr/7vsLJ+fn64ePHiPdMxZjgu0AdPac/SGN9+PLa9tg2e7p5oO68tolZG4VraNdmh6cq8yV2jQZA/+ugjrF+/HvHx8Xj33XeLNS/GDMeFBg9vVqMZDvY9iMFNB2PGgRmoM6kOVv++uvAvWoR5k7tGfYkkJycjNTUVN27cQFpa/m1nq1WrhvPnz995feHCBVSrVq1Yy2VMFy7WB09Jj5IY23Ysdr2+C+W8yuHZBc+ix/IeSLlV8ChrlqCkYl6LhypjqCod99AJzz33HMXFxdGoUaPujH+aV3JyMgUEBFBKSgqlpKRQQEAAJScnF3vZMvAFVRekwX5jBmkZafTBxg/Ifbg7Vfm8Cn2f8L3skIoECi+omju5q2zOnDnUuXNnIiLKzMykhg0b0saNG/OddsaMGVSzZk2qWbMmzZw5U88wVSV7nTOmt/2X9lPY5DBCDOiFJS9QYmqi7JCcojS5c3/uLo7XOXNFGVkZGLtjLIZvGY7y3uXxbftv8WLwi/IG6HYC9+fOGGMOeLh7YNgTw3Cg7wEEVghEt2Xd0HlxZ1y+cVl2aKrh5F6AI0eOICws7K5Ho0aNZIfFGFNJ8H3B2Pn6ToxtMxZrfl+D4NhgzD00V7UB7GXiwToKULduXesPLM2YiyvhVgKDmw1GxCMReH3F6+ixvAcWHl2IKR2moHr56oXPwKC45M4YYwAe8X0EW1/bivHtxmPL2S0Ijg3GtH3TTFuK5+TOGGPZ3IQbBjUahCPRRxD+QDiifoxC23lt8cfVP2SH5jRO7owxlseDFR/Ehlc3YPKzk/HLxV9Qd1JdTPhlAmxkkx2aYpzcGWMsH27CDX3D++Jo/6N4vMbjGLhmIFrMboHfk3+XHZoipk7uMrulbteuHSpUqIAOHTrot1DGmO5qlK+BNZFrMKvjLBxNPIqQySEYt3McsmxZskMrkGmTe1wcEBUFnD1r75T67Fn7a70S/ODBgzFv3jx9FsYYk0oIgZ5hPXGs/zE8VfMpDF4/GE1nNkVCUoLs0BwybXLXoltqpf25A8CTTz6JsmXLFn1hjDHTeaDsA1j+4nIs6LwAp1JOod6Uevhk2yfIyMqQHdo9TJvcteiWWml/7owx1yWEQLe63ZAwIAEdH+mIYZuGodH0Rjh05ZDs0O5i2uSuVbfU3J87Y0yJ+0rfh8VdF2Np16W4eOMiwqeF4+OfP0Z6Vrrs0ACYOLlr1S21kv7cGWMsx/NBzyOhfwK61emGEVtHoMHUBoi/FF/4FzVm2uQeGQlMnQr4+wNC2P9OnWp/vzj69u2LkSNHIjIyEu+99546wTLGLM2nlA/mdpqLld1WIuVWChpNb4QhG4YgLVNeAdG0yR2wJ/IzZwCbzf63uIl97ty58PDwwMsvv4whQ4Zg79692LRpU77TNm/eHF27dsXGjRvh5+eHdevWFW/hjDHT6/BwBxzrfwyvhb2Gz3Z8hrDJYdh5fqeUWBT15y6EaAdgPAB3ANOJaEyez98G0BtAJoAkAL2I6GxB8+T+3I2B1zlj2lh/aj16r+yN89fO46MWHyGmZYwq81WtP3chhDuAiQDaAwgC0E0IEZRnsgMAwokoBMBSAGOdD5kxxqyjbc22OBp9FNHh0ahZsabuy1fS5W9DACeJ6DQACCEWAugI4E7rfSL6Odf0uwF0VzNIWY4cOYJXXnnlrve8vLywZ88eSRExxsykrFdZTHx2opRlK0nu1QCcz/X6AoCCRqx4HcCa4gRlFNyfO2PMrFQdrEMI0R1AOIAWDj6PAhAFADUcNEgnIlOMY2gFZu2nmjFWOCWtZS4CyD0ciV/2e3cRQrQBMAxABBHdzm9GRDSViMKJKLxy5cr3fO7t7Y3k5GROOjogIiQnJ8Pb21t
2KIwxDSgpue8FUEsIEQh7Un8JwMu5JxBC1AMwBUA7IkosajB+fn64cOECkpKSijoL5gRvb2/4+fnJDoMxpoFCkzsRZQoh3gCwDvamkDOJ6JgQYgSAeCJaAeBzAGUALMmuUjlHRBHOBuPh4YHAwEBnv8YYYywPRXXuRLQawOo8732U63kbleNijDFWDKa+Q5Uxxlj+OLkzxpgFKep+QJMFC5EEoMAuCgrgC+AvFcNRC8flHI7LeUaNjeNyTnHi8ieie5sb5iEtuReHECJeSd8KeuO4nMNxOc+osXFcztEjLq6WYYwxC+LkzhhjFmTW5D5VdgAOcFzO4bicZ9TYOC7naB6XKevcGWOMFcysJXfGGGMFMHRyF0K0E0L8JoQ4KYQYks/nbwshEoQQh4UQG4UQ/gaJq58Q4ogQ4qAQYns+g5tIiSvXdM8LIUgIoUsrAgXrq6cQIil7fR0UQvQ2QlzZ07yQvY0dE0IsMEJcQoivcq2rE0KIvw0SVw0hxM9CiAPZ++QzBonLPzs/HBZCbBZC6NKhkhBiphAiUQhx1MHnQgjxTXbch4UQ9VUNgIgM+YC9H5tTAB4E4AngEICgPNO0AlAq+3k0gEUGiatcrucRANYaIa7s6coC2Ar7oCrhRogLQE8AEwy4fdWCfZSxitmv7zNCXHmmHwh7f0/S44K9Hjk6+3kQgDMGiWsJgB7Zz1sDmKfTNvYEgPoAjjr4/BnYx74QABoD2KPm8o1ccr8zAhQRpQPIGQHqDiL6mYhuZr/cDXt3xEaI63qul6UB6HFho9C4so0E8BkAvYZlVxqX3pTE1QfARCK6CgBUjB5PVY4rt24AvjNIXASgXPbz8gAuGSSuIAA5I93/nM/nmiCirQBSCpikI4C5ZLcbQAUhxP1qLd/IyT2/EaCqFTC9XiNAKYpLCDFACHEK9vFkBxkhruzTvupEtEqHeBTHle357FPTpUKI6vl8LiOuhwE8LITYIYTYnT1QvBHiAmCvbgAQiH8Tl+y4YgB0F0JcgL2jwYEGiesQgM7ZzzsBKCuE8NEhtsI4m+OcYuTkrliuEaA+lx1LDiKaSEQ1AbwH4APZ8Qgh3AB8CeC/smPJx0oAAWQfYH09gDmS48lRAvaqmZawl5CnCSEqSI3obi8BWEpEWbIDydYNwGwi8oO9ymFe9nYn2zsAWgghDsA+StxFAEZZZ5oxwop3RLURoGTElctCAP/RNCK7wuIqC6AOgM1CiDOw1/Gt0OGiaqHri4iSc/3vpgNooHFMiuKCvSS1gogyiOgPACdgT/ay48rxEvSpkgGUxfU6gMUAQES7AHjD3oeK1LiI6BIRdSaierDnChCRLhehC+FsLnGOHhcWingxogSA07CfduZcKAnOM0092C+m1DJYXLVyPX8O9kFNpMeVZ/rN0OeCqpL1dX+u550A7DZIXO0AzMl+7gv7KbSP7Liyp3sUwBlk36tikPW1BkDP7Oe1Ya9z1zQ+hXH5AnDLfj4awAg91ln28gLg+ILqs7j7guovqi5brx9ZxBXzDOylpVMAhmW/NwL2UjoAbADwJ4CD2Y8VBolrPIBj2TH9XFCS1TOuPNPqktwVrq9Ps9fXoez19ahB4hKwV2UlADgC4CUjxJX9OgbAGD3icWJ9BQHYkf1/PAjgKYPE1QXA79nTTAfgpVNc3wG4DCAD9rPA1wH0A9Av1/Y1MTvuI2rvj3yHKmOMWZCR69wZY4wVESd3xhizIE7ujDFmQZzcGWPMgji5M8aYBXFyZ4wxC+LkzhhjFsTJnTHGLOj/AaaF1CTFqYFCAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the decision boundary after the updates\n", "w0 = w[0].item()\n", "w1 = w[1].item()\n", "b0 = b.item()\n", "\n", "plot_x = np.arange(0.2, 1, 0.01)\n", "plot_y = (-w0 * plot_x - b0) / w1\n", "\n", "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n", "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n", "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n", "plt.legend(loc='best')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After the updates, the model can already roughly separate the two classes of points." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we have used a loss we wrote ourselves, but PyTorch already implements many common losses: the loss used in linear regression is `nn.MSELoss()`, and the binary-classification loss for logistic regression is `nn.BCEWithLogitsLoss()`. For more losses, see the [documentation](http://pytorch.org/docs/0.3.0/nn.html#loss-functions).\n", "\n", "The losses implemented by PyTorch have two advantages: first, they are convenient and save us from reinventing the wheel; second, they are implemented in the underlying C++ code, so they are generally faster and more numerically stable than our own implementations.\n", "\n", "In addition, for numerical stability PyTorch merges the model's Sigmoid operation and the final loss into `nn.BCEWithLogitsLoss()`. When we use this built-in loss, the model should therefore output the raw scores without applying Sigmoid itself." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Use the built-in loss\n", "criterion = nn.BCEWithLogitsLoss() # combines sigmoid and the loss in one step for better numerical stability\n", "\n", "w = nn.Parameter(torch.randn(2, 1))\n", "b = nn.Parameter(torch.zeros(1))\n", "\n", "def logistic_reg(x):\n", "    return torch.mm(x, w) + b  # raw scores (logits), no sigmoid here\n", "\n", "optimizer = torch.optim.SGD([w, b], 1.)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " 0.6363\n", "[torch.FloatTensor of size 1]\n", "\n" ] } ], "source": [ "y_pred = logistic_reg(x_data)\n", "loss = criterion(y_pred, y_data)\n", "print(loss.data)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 200, Loss: 0.39538, Acc: 0.88000\n", "epoch: 400, Loss: 0.32407, Acc: 0.87000\n", "epoch: 600, Loss: 0.29039, Acc: 0.87000\n", "epoch: 800, Loss: 0.27061, Acc: 0.87000\n", "epoch: 1000, Loss: 0.25753, Acc: 0.88000\n", "\n", "During Time: 0.527 s\n" ] } ], "source": [ "# Likewise, run 1000 updates\n", "\n", "start = time.time()\n", "for e in range(1000):\n", "    # Forward pass\n", "    y_pred = logistic_reg(x_data)\n", "    loss = criterion(y_pred, y_data)\n", "    # Backward pass\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "    # Compute the accuracy: y_pred holds logits, and probability >= 0.5 corresponds to logit >= 0\n", "    mask = y_pred.ge(0).float()\n", "    acc = (mask == y_data).float().mean().item()\n", "    if (e + 1) % 200 == 0:\n", "        print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e+1, loss.item(), acc))\n", "\n", "during = time.time() - start\n", "print()\n", "print('During Time: {:.3f} s'.format(during))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With PyTorch's built-in loss, training reaches about the same loss (about 0.26). In the run recorded above it was not actually faster (0.527 s versus 0.248 s): on a model this small, the timing is dominated by fixed overhead and varies between runs. The main benefits of the built-in loss are numerical stability and not having to reinvent the wheel, and these matter more as networks grow larger." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the next lesson we will introduce `Sequential` and `Module`, PyTorch's modules for building models, which make model construction more convenient." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }