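{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic Regression" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In the previous lesson we studied the simple linear regression model. In this lesson we study our second model, logistic regression.\n", "\n", "Logistic regression is a generalized linear model. It has a lot in common with multiple linear regression, and the two models have essentially the same form. Despite its name, logistic regression is mostly used for classification, most commonly binary classification." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model form\n", "Logistic regression has the same form as linear regression, y = wx + b, where x can be a multi-dimensional feature vector. The only difference is that logistic regression applies the logistic function to y, which turns the output into a probability. The logistic function, also called the sigmoid function, is the core of logistic regression, so we look at it first." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### The sigmoid function\n", "The sigmoid function is very simple. Its formula is\n", "\n", "$$\n", "f(x) = \\frac{1}{1 + e^{-x}}\n", "$$\n", "\n", "and its graph looks like this:\n", "\n", "![](https://ws2.sinaimg.cn/large/006tKfTcly1fmd3dde091g30du060mx0.gif)\n", "\n", "The sigmoid function maps any real number into the interval (0, 1), so its output can be interpreted as a probability. In a binary classification problem, for example, a smaller value means the point is more likely to belong to the first class, and a larger value means it is more likely to belong to the second class." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Logistic regression also works best when the data are (approximately) linearly separable. The sigmoid output equals 0.5 exactly where wx + b = 0, which is a line or, in higher dimensions, a hyperplane, so the dataset should be divisible into two parts by such a boundary, for example\n", "\n", "![](https://ws1.sinaimg.cn/large/006tKfTcly1fmd3gwdueoj30aw0aewex.jpg)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Here the red points and the blue points can be separated almost completely by the green plane." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Regression vs. classification\n", "Logistic regression solves a classification problem, whereas the previous model solved a regression problem. So what is the difference between the two?\n", "\n", "As the figure above suggests, a classification problem assigns each data point to one of several classes. In a 3-class problem, for example, we want to determine which class each data point belongs to, and there are only three possible outcomes, {0, 1, 2}, so the output is discrete.\n", "\n", "A regression problem, by contrast, has a continuous output. In curve fitting, for example, the fitted function can take any real value.\n", "\n", "Telling classification and regression apart is the first step in machine learning and deep learning: for any new problem, we first determine which of the two it is, and only then design the algorithm." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Loss function\n", "In the previous lesson we used a loss to measure the error of a regression model. How do we measure the error of a classifier, and how do we design its loss function?\n", "\n", "Logistic regression uses the sigmoid function to map its output into (0, 1). For any input, we denote the output of the sigmoid by $\\hat{y}$ and read it as the probability that the point belongs to the second class; the probability that it belongs to the first class is then $1-\\hat{y}$. If the point belongs to the second class, we want $\\hat{y}$ to be as large as possible, that is, as close to 1 as possible. If it belongs to the first class, we want $1-\\hat{y}$ to be as large as possible, that is, $\\hat{y}$ should be as small and as close to 0 as possible. This motivates the following loss function\n", "\n", "$$\n", "loss = -\\left(y \\log(\\hat{y}) + (1 - y) \\log(1 - \\hat{y})\\right)\n", "$$\n", "\n", "Here y is the true label, which can only take the values {0, 1}, while $\\hat{y}$ is the prediction of logistic regression, a number between 0 and 1. If y is 0, the point belongs to the first class and we want $\\hat{y}$ to be small; the loss above becomes\n", "\n", "$$\n", "loss = -\\log(1 - \\hat{y})\n", "$$\n", "\n", "During training we minimize the loss. Because log is monotonically increasing, minimizing this loss means minimizing $\\hat{y}$, which is exactly what we want.\n", "\n", "If y is 1, the point belongs to the second class and we want $\\hat{y}$ to be large; the loss becomes\n", "\n", "$$\n", "loss = -\\log(\\hat{y})\n", "$$\n", "\n", "Minimizing this loss means maximizing $\\hat{y}$, which again is what we want.\n", "\n", "This argument shows that the loss function is a sensible choice. In practice the per-sample loss is averaged over all N training points, as the `binary_loss` function later in this notebook does with `.mean()`." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's work through a concrete example of logistic regression." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.autograd import Variable\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Set the random seed\n", "torch.manual_seed(2017)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We read the data from data.txt; each line holds two features and a label, separated by commas. If you are curious, open data.txt to have a look.\n", "\n", "After reading the points, we color them red or blue according to their label and plot them." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Read the points from data.txt (each line: feature 1, feature 2, label)\n", "with open('./data.txt', 'r') as f:\n", "    data_list = [i.split('\\n')[0].split(',') for i in f.readlines()]\n", "    data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]\n", "\n", "# Scale each feature by dividing it by its maximum value\n", "x0_max = max([i[0] for i in data])\n", "x1_max = max([i[1] for i in data])\n", "data = [(i[0]/x0_max, i[1]/x1_max, i[2]) for i in data]\n", "\n", "x0 = list(filter(lambda x: x[-1] == 0.0, data))  # points of the first class (label 0)\n", "x1 = list(filter(lambda x: x[-1] == 1.0, data))  # points of the second class (label 1)\n", "\n", "plot_x0 = [i[0] for i in x0]\n", "plot_y0 = [i[1] for i in x0]\n", "plot_x1 = [i[0] for i in x1]\n", "plot_y1 = [i[1] for i in x1]\n", "\n", "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n", "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n", "plt.legend(loc='best')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next we convert the data to a NumPy array and then to Tensors, in preparation for training." ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "np_data = np.array(data, dtype='float32')  # convert to a NumPy array\n", "x_data = torch.from_numpy(np_data[:, 0:2])  # convert to a Tensor of shape [100, 2]\n", "y_data = torch.from_numpy(np_data[:, -1]).unsqueeze(1)  # convert to a Tensor of shape [100, 1]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now let's implement the sigmoid function ourselves. Its formula is\n", "\n", "$$\n", "f(x) = \\frac{1}{1 + e^{-x}}\n", "$$" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Define the sigmoid function (NumPy version)\n", "def sigmoid(x):\n", "    return 1 / (1 + np.exp(-x))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Plotting the sigmoid function shows that the larger the input, the closer the output is to 1, and the smaller the input, the closer the output is to 0." ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png":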
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAHB9JREFUeJzt3XmYVNWd//H3V1YXIiDIjuAIRiZjXFqNOv7UURSIgsYNonFDSYg4cVxGHR00ap4kkp+JTjSKW9xZ4q9bRHhwHxNXlggqiDauoCwqImKgafj+/jjVWjbVdHV3VZ2qW5/X89ynqu493fXt28WnL+fee465OyIikizbxC5ARERyT+EuIpJACncRkQRSuIuIJJDCXUQkgRTuIiIJpHAXEUkghbuISAIp3EVEEqh1rDfu0qWL9+vXL9bbi4iUpLlz537i7l0baxct3Pv168ecOXNivb2ISEkys/ezaaduGRGRBFK4i4gkkMJdRCSBFO4iIgmkcBcRSaBGw93M7jKzlWb2egPbzcxuMrNqM1tgZvvkvkwREWmKbI7c/wwM2cr2ocCA1DIG+FPLyxIRkZZo9Dp3d3/OzPptpckI4F4P8/W9ZGYdzayHu3+coxpFJKncYcOGzMv69VBTA7W1sGlT5mVr2zZtgs2bw3vULXXv2dR1Tf26+j9jfcceC/vtl9t9WU8ubmLqBXyY9nppat0W4W5mYwhH9/Tt2zcHby0i0dTWwqefwiefwKpV4bHu+erVsHbtlsuXX4bHr74K4b1xY+yfojDMvv26Z8+SCPesuftEYCJARUWFZuYWKWbusHQpLFoEb70FH3wQlg8/DI8ffRSOjDPZYQfo0OGbZYcdoHfvb15vtx20axeW9u2/eZ6+tG8PbdtC69bQqtWWS0Pr05dttgnBmr5A89Y15euKQC7CfRnQJ+1179Q6ESkVGzbAggUweza88gq8/jq8+SasW/dNm3btoE8f6NsXjjgiPO/eHbp2hS5dvnncaacQyhJVLsJ9GjDOzCYBBwBr1N8uUuRqauDFF+GJJ+Cpp2DevLAOYOed4fvfh9GjYY89wrL77tCtW1EdmcrWNRruZvYQcBjQxcyWAlcBbQDc/VZgBjAMqAa+As7KV7Ei0gJr18Kjj8LUqSHU160LXRf77Qe/+AXsv39Y+vRRiCdANlfLjGpkuwPn5awiEcmdzZvh8cdh4kSYMSN0v/TqBWeeCYMHw2GHwY47xq5S8iDakL8ikkdffAG33hqWd98N/eE//SmcfDIceGA40SiJpnAXSZLPPoObboIbb4TPP4dDD4Vf/xqOOy6cEJWyoXAXSYKaGrj5ZvjlL2HNmhDmV14J++4buzKJROEuUuqefBLGjYPFi+Hoo+H662HPPWNXJZGp402kVK1dG/rRBw8OJ06nT4eZMxXsAujIXaQ0vfwyjBwJ778Pl1wC11wT7ugUSdGRu0gpcQ9XwBxySHj9t7+FbhgFu9SjcBcpFRs3wrnnwtix4fb/uXPhoINiVyVFSuEuUgrWrYMRI+DOO+GKK0L/eufOsauSIqY+d5Fi99lnMHQozJkDt90GY8bErkhKgMJdpJitWQNHHRVGaXz44XD9ukgWFO4ixWrt2nDEvmABVFbCD38YuyIpIQp3kWK0cSMcf3wYW33qVAW7NJnCXaTYuIcrYp56Cu65J4S8SBPpahmRYvO734WrYq68Ek4/PXY1UqIU7iLFZNYsuPRSOOWUMAiYSDMp3EWKxbJlcNpp8L3vwV13acx1aRF9ekSKQW0tjBoF//hHOIG63XaxK5ISpxOqIsXg2mvhr3+F++8Pk1GLtJCO3EVimzcPfvUrOOMMOPXU2NVIQijcRWKqqQmTVXfrBr//fexqJEHULSMS03XXwWuvhYHAOnWKXY0kiI7cRWJZuDBMXv2Tn+gOVMk5hbtIDO5w/vnQoQPccEPsaiSB1C0jEsPUqfD003DLLdClS+xqJIF05C5SaF9
+CRddBHvtpbHZJW905C5SaNdfD0uXwqRJ0KpV7GokoXTkLlJIK1aEPvaTT4aDD45djSSYwl2kkK67DtavD3ekiuSRwl2kUN55J8yBOno0DBwYuxpJuKzC3cyGmNliM6s2s8sybO9rZs+Y2d/NbIGZDct9qSIlbvz40Md+1VWxK5Ey0Gi4m1kr4GZgKDAIGGVmg+o1uxKY4u57AyOBW3JdqEhJW7wYHnwwXNves2fsaqQMZHPkvj9Q7e7vuHsNMAkYUa+NA99JPd8R+Ch3JYokwG9+A+3ahUsgRQogm0shewEfpr1eChxQr83VwONmdj6wPXBkTqoTSYL33w9D+Y4dGwYIEymAXJ1QHQX82d17A8OA+8xsi+9tZmPMbI6ZzVm1alWO3lqkyE2YAGZwySWxK5Eykk24LwP6pL3unVqXbjQwBcDdXwTaA1vcU+3uE929wt0runbt2ryKRUrJ8uVwxx1hous+fRpvL5Ij2YT7bGCAmfU3s7aEE6bT6rX5ADgCwMz2IIS7Ds1FbrwRNm6Ey7a4yEwkrxoNd3evBcYBs4BFhKti3jCza8xseKrZRcC5ZjYfeAg40909X0WLlISvvgrXtR9/POy2W+xqpMxkNbaMu88AZtRbNz7t+UJA91KLpLvvPli9Gi64IHYlUoZ0h6pIPriHLpl99tEYMhKFRoUUyYcnnoBFi+Dee8OVMiIFpiN3kXz4wx+ge/cw+qNIBAp3kVyrroaZM8NNS+3axa5GypTCXSTXbr89DBB27rmxK5EypnAXyaWaGrj7bjj2WOjRI3Y1UsYU7iK59MgjsGqV5kaV6BTuIrk0cSL07QtHHRW7EilzCneRXFmyBJ58Es45RxNfS3QKd5FcueMO2GYbOPvs2JWIKNxFcqK2NpxIPeYY6NUrdjUiCneRnHj8cVixAs46K3YlIoDCXSQ37r0XdtoJhmlueCkOCneRllqzBqqqYORIaNs2djUigMJdpOX+8hfYsAF+8pPYlYh8TeEu0lL33gsDB8L++8euRORrCneRlnjvPXjuuTBHqob2lSKicBdpifvvD4+nnhq3DpF6FO4izeUeumQOPRT69Ytdjci3KNxFmmvePHj7bR21S1FSuIs01+TJ0Lo1nHBC7EpEtqBwF2kOd5gyBQYPhs6dY1cjsgWFu0hzvPwyvP8+nHJK7EpEMlK4izTHlCnhbtQRI2JXIpKRwl2kqTZvDuF+9NHQsWPsakQyUriLNNULL8CyZeqSkaKmcBdpqsmToX17GD48diUiDVK4izTFpk1hoLBhw6BDh9jViDRI4S7SFH/9Kyxfri4ZKXoKd5GmmDoVtt0WfvjD2JWIbJXCXSRbmzfDI4/AkCGw/faxqxHZqqzC3cyGmNliM6s2s8saaHOymS00szfM7MHclilSBObODVfJHH987EpEGtW6sQZm1gq4GRgMLAVmm9k0d1+Y1mYAcDlwsLuvNrOd81WwSDRVVdCqlbpkpCRkc+S+P1Dt7u+4ew0wCah/W965wM3uvhrA3VfmtkyRIlBVFYb31VgyUgKyCfdewIdpr5em1qUbCAw0s+fN7CUzG5LpG5nZGDObY2ZzVq1a1byKRWJ46y1YuBCOOy52JSJZydUJ1dbAAOAwYBRwu5ltcV+2u0909wp3r+jatWuO3lqkAKqqwqPGkpESkU24LwP6pL3unVqXbikwzd03uvu7wFuEsBdJhqoq2Gcf6Ns3diUiWckm3GcDA8ysv5m1BUYC0+q1qSIctWNmXQjdNO/ksE6ReD7+GF56SVfJSElpNNzdvRYYB8wCFgFT3P0NM7vGzOoG15gFfGpmC4FngEvc/dN8FS1SUI8+GibnUH+7lBBz9yhvXFFR4XPmzIny3iJNMmxYOKH69ttgFrsaKXNmNtfdKxprpztURbbmiy/gqafCUbuCXUqIwl1ka2bOhJoadclIyVG4i2xNVRXsvDMceGDsSkSaROEu0pANG+Cxx8KkHK1axa5GpEkU7iINefZZWLt
WXTJSkhTuIg2pqgpD+x5xROxKRJpM4S6SSd3Y7UOHhvlSRUqMwl0kk1deCXemqktGSpTCXSSTqipo3Vpjt0vJUriLZFJVBYcfDh23GNxUpCQo3EXqe/NNWLxYXTJS0hTuIvVVVobH4cO33k6kiCncReqrqoL99oPevWNXItJsCneRdMuWhStlNHa7lDiFu0i6aal5aNTfLiVO4S6SrqoKBg6E7343diUiLaJwF6nz+efw9NMau10SQeEuUmfGDKitVZeMJILCXaROVRV07w4HHBC7EpEWU7iLAKxfH2ZdGjECttE/Cyl9+hSLQOhr//JLdclIYijcRSDcldqhQxhPRiQBFO4imzaF69uHDYN27WJXI5ITCneRF1+ElSt1V6okisJdpLIS2rYNsy6JJITCXcqbewj3I4+E73wndjUiOaNwl/L22mvw7ru6SkYSR+Eu5a2yMgw1oLHbJWEU7lLeKivh4IOhW7fYlYjklMJdyte778L8+bpKRhIpq3A3syFmttjMqs3ssq20O8HM3MwqcleiSJ5UVYVH9bdLAjUa7mbWCrgZGAoMAkaZ2aAM7ToAvwBeznWRInlRWQl77gm77hq7EpGcy+bIfX+g2t3fcfcaYBIwIkO7a4HfAutzWJ9IfqxcCX/7m7pkJLGyCfdewIdpr5em1n3NzPYB+rj7YzmsTSR/Hn00XOOuLhlJqBafUDWzbYAbgIuyaDvGzOaY2ZxVq1a19K1Fmq+yEvr1g+9/P3YlInmRTbgvA/qkve6dWlenA/A94Fkzew/4ATAt00lVd5/o7hXuXtG1a9fmVy3SEmvXwpNPhi4ZTacnCZVNuM8GBphZfzNrC4wEptVtdPc17t7F3fu5ez/gJWC4u8/JS8UiLTV9OmzYAD/6UexKRPKm0XB391pgHDALWARMcfc3zOwaM9NtfVJ6pkyBnj3hoINiVyKSN62zaeTuM4AZ9daNb6DtYS0vSyRP1q4N0+n99KeaTk8STZ9uKS91XTInnRS7EpG8UrhLeVGXjJQJhbuUj7oumRNPVJeMJJ4+4VI+Hn1UXTJSNhTuUj6mTlWXjJQNhbuUB3XJSJnRp1zKg7pkpMwo3KU8TJ6sLhkpKwp3Sb5PPw1dMqNGqUtGyoY+6ZJ8U6fCxo1w6qmxKxEpGIW7JN8DD8CgQbDXXrErESkYhbsk23vvhRmXTjtNw/tKWVG4S7I9+GB4/PGP49YhUmAKd0kud7j/fjjkENhll9jViBSUwl2S69VXYdEinUiVsqRwl+S6/35o00Y3LklZUrhLMm3cGML9mGOgc+fY1YgUnMJdkmn6dFi5EkaPjl2JSBQKd0mmO+8Mww0cfXTsSkSiULhL8ixbFoYbOPNMaJ3VNMEiiaNwl+S55x7YvBnOOit2JSLRKNwlWTZvhrvugkMPhd12i12NSDQKd0mW556DJUt0IlXKnsJdkuXWW6FjRzjhhNiViESlcJfk+OgjePhhOPts2G672NWIRKVwl+SYOBE2bYKxY2NXIhKdwl2SoaYGbrsNhg7ViVQRFO6SFJWVsHw5jBsXuxKRoqBwl2T44x/hn/5Jd6SKpCjcpfTNnh1mWzrvPE2ALZKifwlS+iZMgB13hHPOiV2JSNHIKtzNbIiZLTazajO7LMP2C81soZktMLOnzEzT3khhVFeHyx/HjoUOHWJXI1I0Gg13M2sF3AwMBQYBo8xsUL1mfwcq3H1P4C/A9bkuVCSjG24Ig4P9+7/HrkSkqGRz5L4/UO3u77h7DTAJGJHewN2fcfevUi9fAnrntkyRDFatgrvvhtNPhx49YlcjUlSyCfdewIdpr5em1jVkNDAz0wYzG2Nmc8xszqpVq7KvUiSTP/wBNmyAiy6KXYlI0cnpCVUzOw2oACZk2u7uE929wt0runbtmsu3lnLzySdw001w8snw3e/Grkak6GQzk8EyoE/a696pdd9iZkcCVwCHuvuG3JQn0oDf/Q7WrYOrropdiUhRyubIfTYwwMz6m1l
bYCQwLb2Bme0N3AYMd/eVuS9TJM3KlfA//wM//jHssUfsakSKUqPh7u61wDhgFrAImOLub5jZNWY2PNVsArADMNXMXjWzaQ18O5GWmzAB1q+H8eNjVyJStLKaYNLdZwAz6q0bn/b8yBzXJZLZBx+EoQZOOw0GDoxdjUjR0h2qUlr+67/C43XXxa1DpMgp3KV0vPIKPPBAuPSxT5/G24uUMYW7lAZ3uPBC6NYNLr00djUiRS+rPneR6CZNguefh9tv1xgyIlnQkbsUv9Wr4T/+Ayoq4KyzYlcjUhJ05C7F7/LLwzgyM2dCq1axqxEpCTpyl+L2wgthbtQLLoC9945djUjJULhL8Vq3Ds48E/r2hV/+MnY1IiVF3TJSvC6+OEzG8dRTsMMOsasRKSk6cpfi9NhjcOut4fLHww+PXY1IyVG4S/FZujRcFfMv/wK/+lXsakRKksJdisuGDXDSSfCPf8DkydCuXeyKREqS+tyluFx4Ibz0EkydquF8RVpAR+5SPG65JSwXXwwnnhi7GpGSpnCX4vDII3D++XDssfDrX8euRqTkKdwlvuefh1GjYN994aGHoLV6C0VaSuEucb3wAgwdCr17w/TpsP32sSsSSQSFu8Tz4oswZAh07w7PPAM77xy7IpHEULhLHNOnw5FHfhPsvXrFrkgkURTuUnh/+hOMGBEudXzuOQW7SB4o3KVw1q+HsWPh5z+HYcPgf/83HLmLSM4p3KUwliyBgw4K48VccglUVurkqUge6Zozya/Nm0OgX3optGkDjz4KxxwTuyqRxNORu+TPwoVw2GFw3nlw4IHw978r2EUKROEuubdiBfzsZ2FUx9deg7vvhlmzYJddYlcmUjbULSO589FH8Pvfh6thNmyAcePgv/8bunSJXZlI2VG4S8u4w9y5oV/9vvugthZOOSVMizdgQOzqRMqWwl2aZ8WKMCzvHXfA/Pmw7bZwzjlw0UWw666xqxMpewp3yY47vPVWuLO0sjKMCeMeBvu65ZYw8FfHjrGrFJEUhbtkVlsLixeH8V+eeQaefTb0qQPstRdcfTUcf3w4aSoiRSercDezIcCNQCvgDnf/Tb3t7YB7gX2BT4FT3P293JYqebF5cwjtt98OR+bz58O8ebBgQZjqDqBbtzBJ9eGHw+DB0L9/3JpFpFGNhruZtQJuBgYDS4HZZjbN3RemNRsNrHb33cxsJPBb4JR8FCxNsHEjfP45LF8eArz+smQJVFd/E+IA3/kO7L13uJRx771hv/1g993BLN7PISJNls2R+/5Atbu/A2Bmk4ARQHq4jwCuTj3/C/BHMzN39xzWWrrcQzdHbW0I3Lrn9V/X31ZTA199FcI3/bH+ui+/DCG+evW3l3XrMtfTuTP06BFOfA4eDLvtFq5sGTAA+vSBbXT7g0ipyybcewEfpr1eChzQUBt3rzWzNcBOwCe5KPJb7roLJkwIgRne8NtLtusK9fW1taHrIx+23TYs228fTmZ26hQCu1Onb5aOHcPgXD17hqVHD2jfPj/1iEjRKOgJVTMbA4wB6Nu3b/O+SZcu4SReXTeB2ZZLpvXZrst129atw9KmTebnW3vdpg1st10I8PTH7baDdu10hC0iDcom3JcBfdJe906ty9RmqZm1BnYknFj9FnefCEwEqKioaF6XzfDhYRERkQZlc+g3GxhgZv3NrC0wEphWr8004IzU8xOBp9XfLiIST6NH7qk+9HHALMKlkHe5+xtmdg0wx92nAXcC95lZNfAZ4Q+AiIhEklWfu7vPAGbUWzc+7fl64KTcliYiIs2lM3IiIgmkcBcRSSCFu4hIAincRUQSSOEuIpJAFutydDNbBbzfzC/vQj6GNmg51dU0qqvpirU21dU0LalrF3fv2lijaOHeEmY2x90rYtdRn+pqGtXVdMVam+pqmkLUpW4ZEZEEUriLiCRQqYb7xNgFNEB1NY3qarpirU11NU3e6yrJPncREdm6Uj1yFxGRrSjacDezk8zsDTP
bbGYV9bZdbmbVZrbYzI5u4Ov7m9nLqXaTU8MV57rGyWb2amp5z8xebaDde2b2WqrdnFzXkeH9rjazZWm1DWug3ZDUPqw2s8sKUNcEM3vTzBaYWaWZdWygXUH2V2M/v5m1S/2Oq1OfpX75qiXtPfuY2TNmtjD1+f9FhjaHmdmatN/v+EzfKw+1bfX3YsFNqf21wMz2KUBNu6fth1fN7Aszu6Bem4LtLzO7y8xWmtnraes6m9kTZvZ26rFTA197RqrN22Z2RqY2TeLuRbkAewC7A88CFWnrBwHzgXZAf2AJ0CrD108BRqae3wqMzXO9/xcY38C294AuBdx3VwMXN9KmVWrf7Qq0Te3TQXmu6yigder5b4Hfxtpf2fz8wM+BW1PPRwKTC/C76wHsk3reAXgrQ12HAdML9XnK9vcCDANmAgb8AHi5wPW1ApYTrgOPsr+A/wPsA7yetu564LLU88syfe6BzsA7qcdOqeedWlJL0R65u/sid1+cYdMIYJK7b3D3d4FqwiTeXzMzA/6NMFk3wD3AcfmqNfV+JwMP5es98uDric/dvQaom/g8b9z9cXevTb18iTCrVyzZ/PwjCJ8dCJ+lI1K/67xx94/dfV7q+VpgEWGO4lIwArjXg5eAjmbWo4DvfwSwxN2be3Nki7n7c4Q5LdKlf44ayqKjgSfc/TN3Xw08AQxpSS1FG+5bkWnC7vof/p2Az9OCJFObXDoEWOHubzew3YHHzWxuah7ZQhiX+q/xXQ38NzCb/ZhPZxOO8jIpxP7K5uf/1sTvQN3E7wWR6gbaG3g5w+YDzWy+mc00s38uUEmN/V5if6ZG0vABVoz9Vaebu3+cer4c6JahTc73XUEnyK7PzJ4EumfYdIW7P1LoejLJssZRbP2o/V/dfZmZ7Qw8YWZvpv7C56Uu4E/AtYR/jNcSuozObsn75aKuuv1lZlcAtcADDXybnO+vUmNmOwAPAxe4+xf1Ns8jdD18mTqfUgUMKEBZRft7SZ1TGw5cnmFzrP21BXd3MyvIJYpRw93dj2zGl2UzYfenhP8Stk4dcWVqk5MaLUwI/iNg3618j2Wpx5VmVknoEmjRP4ps952Z3Q5Mz7Apm/2Y87rM7EzgGOAIT3U2ZvgeOd9fGeRs4vdcM7M2hGB/wN3/X/3t6WHv7jPM7BYz6+LueR1DJYvfS14+U1kaCsxz9xX1N8TaX2lWmFkPd/841U21MkObZYRzA3V6E843NlspdstMA0amrmToT/gL/Ep6g1RoPEOYrBvC5N35+p/AkcCb7r4000Yz297MOtQ9J5xUfD1T21yp1895fAPvl83E57muawjwn8Bwd/+qgTaF2l9FOfF7qk//TmCRu9/QQJvudX3/ZrY/4d9xXv/oZPl7mQacnrpq5gfAmrTuiHxr8H/PMfZXPemfo4ayaBZwlJl1SnWjHpVa13yFOIPcnIUQSkuBDcAKYFbatisIVzosBoamrZ8B9Ew935UQ+tXAVKBdnur8M/Czeut6AjPS6pifWt4gdE/ke9/dB7wGLEh9sHrUryv1ehjhaowlBaqrmtCv+GpqubV+XYXcX5l+fuAawh8fgPapz0516rO0awH20b8SutMWpO2nYcDP6j5nwLjUvplPODF9UAHqyvh7qVeXATen9udrpF3llufatieE9Y5p66LsL8IfmI+Bjan8Gk04T/MU8DbwJNA51bYCuCPta89OfdaqgbNaWovuUBURSaBS7JYREZFGKNxFRBJI4S4ikkAKdxGRBFK4i4gkkMJdRCSBFO4iIgmkcBcRSaD/DydAb7nqWwBcAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the sigmoid curve\n", "\n", "plot_x = np.arange(-10, 10.01, 0.01)\n", "plot_y = sigmoid(plot_x)\n", "\n", "plt.plot(plot_x, plot_y, 'r')" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Wrap the data in Variable. Since PyTorch 0.4, Variable(t) simply returns a Tensor,\n", "# so this step is only kept for compatibility with older versions.\n", "x_data = Variable(x_data)\n", "y_data = Variable(y_data)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In PyTorch we do not need to write the sigmoid function ourselves. PyTorch already provides common functions like this one, implemented in low-level C++; they are convenient to use, faster than our own versions, and more numerically stable.\n", "\n", "The model below uses `torch.sigmoid`; many other such functions are available in `torch.nn.functional`, which is imported as follows" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# Define the logistic regression model\n", "w = Variable(torch.randn(2, 1), requires_grad=True)\n", "b = Variable(torch.zeros(1), requires_grad=True)\n", "\n", "def logistic_regression(x):\n", "    return torch.sigmoid(torch.mm(x, w) + b)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Before updating the parameters, let's plot the decision boundary given by the initial parameters." ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the decision boundary before the parameters are updated.\n", "# The boundary w0 * x_0 + w1 * x_1 + b0 = 0 gives x_1 = (-w0 * x_0 - b0) / w1\n", "w0 = w[0, 0].item()\n", "w1 = w[1, 0].item()\n", "b0 = b[0].item()\n", "\n", "plot_x = np.arange(0.2, 1, 0.01)\n", "plot_y = (-w0 * plot_x - b0) / w1\n", "\n", "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n", "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n", "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n", "plt.legend(loc='best')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the initial boundary separates the two classes essentially at random. Let's compute the loss, using the formula\n", "\n", "$$\n", "loss = -\\left(y \\log(\\hat{y}) + (1 - y) \\log(1 - \\hat{y})\\right)\n", "$$\n", "\n", "averaged over all samples." ] },
{ "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "# Compute the loss: binary cross-entropy averaged over all samples\n", "def binary_loss(y_pred, y):\n", "    log_likelihood = (y * y_pred.clamp(1e-12).log() + (1 - y) * (1 - y_pred).clamp(1e-12).log()).mean()\n", "    return -log_likelihood" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Note the use of `.clamp` here. Read its [documentation](http://pytorch.org/docs/0.3.0/torch.html?highlight=clamp#torch.clamp), then think about whether it is really necessary here and what would happen without it.\n", "\n", "**Hint: look at the graph of the log function.**" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.7911, grad_fn=)\n" ] } ], "source": [ "y_pred = logistic_regression(x_data)\n", "loss = binary_loss(y_pred, y_data)\n", "print(loss)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Once we have the loss, we again update the parameters by gradient descent. Autograd gives us the derivatives of the loss with respect to the parameters directly; if you are interested, try deriving the gradient formulas by hand." ] },
{ "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.7801, grad_fn=)\n" ] } ], "source": [ "# Compute the gradients with autograd and take one gradient-descent step (learning rate 0.1)\n", "loss.backward()\n", "w.data = w.data - 0.1 * w.grad.data\n", "b.data = b.data - 0.1 * b.grad.data\n", "\n", "# Loss after one update\n", "y_pred = logistic_regression(x_data)\n", "loss = binary_loss(y_pred, y_data)\n", "print(loss)" ] }, {
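"cell_type": "markdown", "metadata": {}, "source": [ "To check a hand derivation against autograd, the cell below compares the gradients computed by `backward()` with the closed-form gradients of the mean binary cross-entropy. This is a minimal sketch, assuming the same model and mean-reduced loss as `binary_loss` (without the clamp): with $\\hat{y} = \\sigma(Xw + b)$ over $N$ samples,\n", "\n", "$$\n", "\\frac{\\partial loss}{\\partial w} = \\frac{1}{N} X^T (\\hat{y} - y), \\qquad \\frac{\\partial loss}{\\partial b} = \\frac{1}{N} \\sum_{i=1}^{N} (\\hat{y}_i - y_i)\n", "$$\n", "\n", "The four data points and the names `X_chk`, `y_chk`, `w_chk` and `b_chk` are illustrative assumptions, not part of the original lesson; they are kept separate so that the notebook's own `w`, `b` and data are left untouched." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Illustrative data (assumption): 4 samples with 2 features, labels in {0, 1}\n", "X_chk = torch.tensor([[0.5, 1.0], [1.5, -0.5], [-1.0, 2.0], [0.2, 0.3]])\n", "y_chk = torch.tensor([[1.0], [0.0], [1.0], [0.0]])\n", "w_chk = torch.tensor([[0.3], [-0.2]], requires_grad=True)\n", "b_chk = torch.tensor([0.1], requires_grad=True)\n", "\n", "# Forward pass and mean binary cross-entropy, as in binary_loss\n", "y_hat = torch.sigmoid(torch.mm(X_chk, w_chk) + b_chk)\n", "loss_chk = -(y_chk * y_hat.log() + (1 - y_chk) * (1 - y_hat).log()).mean()\n", "loss_chk.backward()\n", "\n", "# Closed-form gradients of the mean loss\n", "with torch.no_grad():\n", "    err = y_hat - y_chk\n", "    grad_w = torch.mm(X_chk.t(), err) / X_chk.shape[0]\n", "    grad_b = err.mean(dim=0)\n", "\n", "print(torch.allclose(w_chk.grad, grad_w, atol=1e-6))  # should print True\n", "print(torch.allclose(b_chk.grad, grad_b, atol=1e-6))  # should print True" ] }, {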
"cell_type": "markdown", "metadata": {}, "source": [ "Updating the parameters this way is tedious and repetitive: with many parameters, say 100, we would need 100 lines of update code. We could wrap the updates in a function, but PyTorch already provides this functionality in its optimizers, `torch.optim`.\n", "\n", "Together with `torch.optim` we use another data type, `nn.Parameter`. It is essentially the same as a Variable, except that `nn.Parameter` requires gradients by default, whereas a Variable does not.\n", "\n", "`torch.optim.SGD` updates parameters by gradient descent. `torch.optim` also provides many other optimization algorithms, which we will cover in more detail later in this chapter.\n", "\n", "Once the parameters w and b are passed to `torch.optim.SGD` together with a learning rate, calling `optimizer.step()` updates them. Below, we pass the parameters to the optimizer with a learning rate of 1.0." ] },
{ "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# Update the parameters with torch.optim\n", "from torch import nn\n", "w = nn.Parameter(torch.randn(2, 1))\n", "b = nn.Parameter(torch.zeros(1))\n", "\n", "def logistic_regression(x):\n", "    return torch.sigmoid(torch.mm(x, w) + b)\n", "\n", "optimizer = torch.optim.SGD([w, b], lr=1.)" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 200, Loss: 0.39010, Acc: 0.00000\n", "epoch: 400, Loss: 0.32184, Acc: 0.00000\n", "epoch: 600, Loss: 0.28917, Acc: 0.00000\n", "epoch: 800, Loss: 0.26983, Acc: 0.00000\n", "epoch: 1000, Loss: 0.25700, Acc: 0.00000\n", "\n", "During Time: 0.248 s\n" ] } ], "source": [ "# Run 1000 updates\n", "import time\n", "\n", "start = time.time()\n", "for e in range(1000):\n", "    # Forward pass\n", "    y_pred = logistic_regression(x_data)\n", "    loss = binary_loss(y_pred, y_data)  # compute the loss\n", "    # Backward pass\n", "    optimizer.zero_grad()  # reset the gradients to zero\n", "    loss.backward()\n", "    optimizer.step()  # update the parameters\n", "    # Accuracy: predict class 1 when the predicted probability is at least 0.5\n", "    mask = y_pred.ge(0.5).float()\n", "    acc = (mask == y_data).float().mean().item()\n", "    if (e + 1) % 200 == 0:\n", "        print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e + 1, loss.item(), acc))\n", "during = time.time() - start\n", "print()\n", "print('During Time: {:.3f} s'.format(during))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "With an optimizer, updating the parameters is very simple: call **`optimizer.zero_grad()`** to reset the gradients to zero before backpropagation, then call **`optimizer.step()`** to update the parameters.\n", "\n", "After 1000 updates the loss has also dropped considerably, to about 0.26.\n", "\n", "Note: the accuracy of 0.00000 in the recorded output above does not reflect the model. It comes from an earlier version of the accuracy line, which divided an integer count of correct predictions by the number of samples using integer division; the accuracy line in the cell above computes a floating-point mean instead." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now let's plot the result after the updates." ] },
{ "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the decision boundary after the updates\n", "w0 = w[0, 0].item()\n", "w1 = w[1, 0].item()\n", "b0 = b[0].item()\n", "\n", "plot_x = np.arange(0.2, 1, 0.01)\n", "plot_y = (-w0 * plot_x - b0) / w1\n", "\n", "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n", "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n", "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n", "plt.legend(loc='best')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "After the updates, the model separates the two classes reasonably well." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "So far we have used a loss we wrote ourselves, but PyTorch already provides the common losses: the loss for linear regression is `nn.MSELoss()`, and the binary classification loss for logistic regression is `nn.BCEWithLogitsLoss()`. For more losses, see the [documentation](http://pytorch.org/docs/0.3.0/nn.html#loss-functions).\n", "\n", "The built-in losses have two advantages. First, they are convenient and save us from reinventing the wheel. Second, they are implemented in low-level C++, so they are generally faster and more numerically stable than our own versions.\n", "\n", "In addition, for numerical stability `nn.BCEWithLogitsLoss()` combines the sigmoid and the loss into a single operation, so with this loss the model must output the raw scores (logits) and must not apply the sigmoid itself." ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Use the built-in loss\n", "criterion = nn.BCEWithLogitsLoss()  # sigmoid and loss combined in one layer for better numerical stability\n", "\n", "w = nn.Parameter(torch.randn(2, 1))\n", "b = nn.Parameter(torch.zeros(1))\n", "\n", "def logistic_reg(x):\n", "    return torch.mm(x, w) + b  # raw scores (logits), no sigmoid\n", "\n", "optimizer = torch.optim.SGD([w, b], 1.)" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " 0.6363\n", "[torch.FloatTensor of size 1]\n", "\n" ] } ], "source": [ "y_pred = logistic_reg(x_data)\n", "loss = criterion(y_pred, y_data)\n", "print(loss.data)" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 200, Loss: 0.39538, Acc: 0.88000\n", "epoch: 400, Loss: 0.32407, Acc: 0.87000\n", "epoch: 600, Loss: 0.29039, Acc: 0.87000\n", "epoch: 800, Loss: 0.27061, Acc: 0.87000\n", "epoch: 1000, Loss: 0.25753, Acc: 0.88000\n", "\n", "During Time: 0.527 s\n" ] } ], "source": [ "# Run another 1000 updates\n", "start = time.time()\n", "for e in range(1000):\n", "    # Forward pass\n", "    y_pred = logistic_reg(x_data)\n", "    loss = criterion(y_pred, y_data)\n", "    # Backward pass\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "    # Accuracy: y_pred holds logits, so apply the sigmoid before thresholding at 0.5\n", "    mask = torch.sigmoid(y_pred).ge(0.5).float()\n", "    acc = (mask == y_data).float().mean().item()\n", "    if (e + 1) % 200 == 0:\n", "        print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e + 1, loss.item(), acc))\n", "\n", "during = time.time() - start\n", "print()\n", "print('During Time: {:.3f} s'.format(during))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In the recorded runs above, the built-in loss was not actually faster on this small problem (0.527 s versus 0.248 s), and the two cells were evidently executed in different sessions, so the timings are not directly comparable. The main benefit here is numerical stability: according to the PyTorch documentation, `nn.BCEWithLogitsLoss()` combines the sigmoid and the binary cross-entropy into one layer and uses the log-sum-exp trick, which is more stable than a plain sigmoid followed by a log. For larger networks the built-in losses are generally the better choice for stability and speed, and they also spare us from reinventing the wheel." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In the next lesson we introduce `Sequential` and `Module`, PyTorch's building blocks for models, which make building models more convenient." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }