2-logistic-regression.ipynb 97 kB


  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Logistic Regression Model"
  8. ]
  9. },
  10. {
  11. "cell_type": "markdown",
  12. "metadata": {},
  13. "source": [
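  14. "In the previous lesson we studied the simple linear regression model. In this lesson we turn to our second model, the Logistic regression model.\n",
  15. "\n",
  16. "Logistic regression is a generalized linear model. It has much in common with multiple linear regression, and the two share essentially the same model form. Although it is called regression, it is mostly used for classification problems, most commonly binary classification."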
  17. ]
  18. },
  19. {
  20. "cell_type": "markdown",
  21. "metadata": {},
  22. "source": [
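  23. "## Model Form\n",
  24. "Logistic regression has the same model form as linear regression, y = wx + b, where x can be a multi-dimensional feature vector. The only difference is that Logistic regression applies a logistic function to y, turning the result into a probability. The logistic function, also called the Sigmoid function, is the core of Logistic regression, so we introduce it next."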
  25. ]
  26. },
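{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the model form concrete, here is a minimal NumPy sketch of a forward pass. The parameters `w`, `b` and the three sample points are made-up values for illustration only. The linear part `x @ w + b` is computed exactly as in linear regression, and the Sigmoid function introduced below then maps each score to a probability in (0, 1)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1 / (1 + np.exp(-z))\n",
"\n",
"# hypothetical parameters for a model with 2 input features\n",
"w = np.array([[2.0], [-1.0]])  # shape [2, 1]\n",
"b = 0.5\n",
"\n",
"# three sample points, one per row, shape [3, 2]\n",
"x = np.array([[0.0, 0.0],\n",
"              [1.0, 0.0],\n",
"              [0.0, 3.0]])\n",
"\n",
"z = x @ w + b       # linear part, same as linear regression; scores 0.5, 2.5, -2.5\n",
"y_hat = sigmoid(z)  # probability of belonging to the second class, shape [3, 1]\n",
"print(z.ravel())\n",
"print(y_hat.ravel())"
]
},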
  27. {
  28. "cell_type": "markdown",
  29. "metadata": {},
  30. "source": [
  31. "### Sigmoid Function\n",
  32. "The Sigmoid function is very simple. Its formula is\n",
  33. "\n",
  34. "$$\n",
  35. "f(x) = \\frac{1}{1 + e^{-x}}\n",
  36. "$$\n",
  37. "\n",
  38. "The graph of the Sigmoid function is shown below\n",
  39. "\n",
  40. "![](https://ws2.sinaimg.cn/large/006tKfTcly1fmd3dde091g30du060mx0.gif)\n",
  41. "\n",
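  42. "As the graph shows, the Sigmoid function takes values between 0 and 1 (never reaching either end), so any input passed through it becomes a value in (0, 1). This value can be intuitively interpreted as a probability: in a binary classification problem, a smaller value means the point more likely belongs to the first class, and a larger value means it more likely belongs to the second class."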
  43. ]
  44. },
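{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick numerical check of the properties described above, as a minimal NumPy sketch with arbitrary sample inputs: every output lies in (0, 1), an input of 0 maps to exactly 0.5, and the function satisfies the symmetry $f(-x) = 1 - f(x)$, so flipping the sign of the score swaps the probabilities of the two classes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
"    # f(x) = 1 / (1 + e^{-x})\n",
"    return 1 / (1 + np.exp(-x))\n",
"\n",
"xs = np.array([-10.0, -1.0, 0.0, 1.0, 10.0])\n",
"ys = sigmoid(xs)\n",
"print(ys)                                 # all values strictly between 0 and 1\n",
"print(sigmoid(0.0))                       # 0.5, the midpoint between the two classes\n",
"print(np.allclose(sigmoid(-xs), 1 - ys))  # symmetry f(-x) = 1 - f(x)"
]
},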
  45. {
  46. "cell_type": "markdown",
  47. "metadata": {},
  48. "source": [
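  49. "Another point to keep in mind: the decision boundary of Logistic regression is linear, so the model works well when the data are (close to) linearly separable, that is, when the data set can be split into two parts by a line or plane in the feature space, for example\n",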
  50. "\n",
  51. "![](https://ws1.sinaimg.cn/large/006tKfTcly1fmd3gwdueoj30aw0aewex.jpg)"
  52. ]
  53. },
  54. {
  55. "cell_type": "markdown",
  56. "metadata": {},
  57. "source": [
  58. "As we can see, the red points and the blue points above can be almost completely separated by the green plane."
  59. ]
  60. },
  61. {
  62. "cell_type": "markdown",
  63. "metadata": {},
  64. "source": [
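  65. "## Regression vs. Classification\n",
  66. "Logistic regression solves a classification problem, whereas the previous model was a regression model. So what is the difference between regression and classification?\n",
  67. "\n",
  68. "As the figure above suggests, a classification problem assigns each data point to one of several classes. In a 3-class problem, for instance, we want to find which class each data point belongs to, so the result can only be one of three values, {0, 1, 2}; the problem is discrete.\n",
  69. "\n",
  70. "A regression problem, by contrast, is continuous. In curve fitting, for example, the fitted function can output any value, and that output is continuous.\n",
  71. "\n",
  72. "Telling classification and regression apart is the first step in machine learning and deep learning: for any new problem, we first decide whether it is a classification or a regression problem, and only then design the algorithm."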
  73. ]
  74. },
  75. {
  76. "cell_type": "markdown",
  77. "metadata": {},
  78. "source": [
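  79. "## Loss Function\n",
  80. "For the regression problem in the previous lesson we had a loss that measured the error. How do we measure the error, and design a loss function, for a classification problem?\n",
  81. "\n",
  82. "Logistic regression uses the Sigmoid function to map the result into (0, 1). For any input data point, we denote the output after Sigmoid as $\\hat{y}$, the probability that the point belongs to the second class; the probability that it belongs to the first class is then $1-\\hat{y}$. If the point belongs to the second class, we want $\\hat{y}$ to be as large as possible, i.e. as close to 1 as possible; if it belongs to the first class, we want $1-\\hat{y}$ to be as large as possible, i.e. $\\hat{y}$ as small as possible, as close to 0 as possible. We can therefore design the loss function as\n",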
  83. "\n",
  84. "$$\n",
  85. "loss = -(y * log(\\hat{y}) + (1 - y) * log(1 - \\hat{y}))\n",
  86. "$$\n",
  87. "\n",
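  88. "where y is the true label and can only take the values {0, 1}, while $\\hat{y}$ is the prediction of the Logistic regression model, a number between 0 and 1. If y is 0, the point belongs to the first class and we want $\\hat{y}$ to be as small as possible; the loss function above becomes\n",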
  89. "\n",
  90. "$$\n",
  91. "loss = - (log(1 - \\hat{y}))\n",
  92. "$$\n",
  93. "\n",
  94. "During training we minimize the loss function. Because log is monotonically increasing, minimizing this loss means minimizing $\\hat{y}$, which matches what we want.\n",
  95. "\n",
  96. "If instead y is 1, the point belongs to the second class and we want $\\hat{y}$ to be as large as possible; the loss function above becomes\n",
  97. "\n",
  98. "$$\n",
  99. "loss = -(log(\\hat{y}))\n",
  100. "$$\n",
  101. "\n",
  102. "Minimizing this loss means maximizing $\\hat{y}$, which again matches what we want.\n",
  103. "\n",
  104. "The argument above shows that this construction of the loss function is reasonable."
  105. ]
  106. },
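{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small numerical check of the two cases above, as a NumPy sketch with illustrative predictions: the loss is small when $\\hat{y}$ agrees with the label and grows quickly as $\\hat{y}$ moves toward the wrong end. For y = 0 the losses for $\\hat{y}$ = 0.1, 0.5, 0.9 are about 0.105, 0.693 and 2.303, and for y = 1 the order is reversed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def bce(y_hat, y):\n",
"    # loss = -(y * log(y_hat) + (1 - y) * log(1 - y_hat))\n",
"    return -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))\n",
"\n",
"for y in (0.0, 1.0):\n",
"    for y_hat in (0.1, 0.5, 0.9):\n",
"        print('y = {}, y_hat = {}, loss = {:.3f}'.format(y, y_hat, bce(y_hat, y)))"
]
},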
  107. {
  108. "cell_type": "markdown",
  109. "metadata": {},
  110. "source": [
  111. "Next, let's work through a concrete example of Logistic regression."
  112. ]
  113. },
  114. {
  115. "cell_type": "code",
  116. "execution_count": 2,
  117. "metadata": {},
  118. "outputs": [],
  119. "source": [
  120. "import torch\n",
  121. "from torch.autograd import Variable  # since PyTorch 0.4, Variable(t) simply returns a Tensor\n",
  122. "import numpy as np\n",
  123. "import matplotlib.pyplot as plt\n",
  124. "%matplotlib inline"
  125. ]
  126. },
  127. {
  128. "cell_type": "code",
  129. "execution_count": 2,
  130. "metadata": {},
  131. "outputs": [
  132. {
  133. "data": {
  134. "text/plain": [
  135. "<torch._C.Generator at 0x7f089c0fccb0>"
  136. ]
  137. },
  138. "execution_count": 2,
  139. "metadata": {},
  140. "output_type": "execute_result"
  141. }
  142. ],
  143. "source": [
  144. "# set the random seed\n",
  145. "torch.manual_seed(2017)"
  146. ]
  147. },
  148. {
  149. "cell_type": "markdown",
  150. "metadata": {},
  151. "source": [
  152. "We read the data from data.txt; if you are curious, open data.txt to look at it.\n",
  153. "\n",
  154. "After reading the data points, we split them by label into red and blue points and plot them."
  155. ]
  156. },
  157. {
  158. "cell_type": "code",
  159. "execution_count": 3,
  160. "metadata": {},
  161. "outputs": [
  162. {
  163. "data": {
  164. "text/plain": [
  165. "<matplotlib.legend.Legend at 0x7fb05e2923c8>"
  166. ]
  167. },
  168. "execution_count": 3,
  169. "metadata": {},
  170. "output_type": "execute_result"
  171. },
  172. {
  173. "data": {
  174. "image/png": "\n",
  175. "text/plain": [
  176. "<Figure size 432x288 with 1 Axes>"
  177. ]
  178. },
  179. "metadata": {
  180. "needs_background": "light"
  181. },
  182. "output_type": "display_data"
  183. }
  184. ],
  185. "source": [
  186. "# read the points from data.txt\n",
  187. "with open('./data.txt', 'r') as f:\n",
  188. "    data_list = [i.strip().split(',') for i in f.readlines()]\n",
  189. "    data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]\n",
  190. "\n",
  191. "# normalize: divide each feature by its maximum value\n",
  192. "x0_max = max([i[0] for i in data])\n",
  193. "x1_max = max([i[1] for i in data])\n",
  194. "data = [(i[0]/x0_max, i[1]/x1_max, i[2]) for i in data]\n",
  195. "\n",
  196. "x0 = list(filter(lambda x: x[-1] == 0.0, data)) # points of the first class\n",
  197. "x1 = list(filter(lambda x: x[-1] == 1.0, data)) # points of the second class\n",
  198. "\n",
  199. "plot_x0 = [i[0] for i in x0]\n",
  200. "plot_y0 = [i[1] for i in x0]\n",
  201. "plot_x1 = [i[0] for i in x1]\n",
  202. "plot_y1 = [i[1] for i in x1]\n",
  203. "\n",
  204. "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n",
  205. "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n",
  206. "plt.legend(loc='best')"
  207. ]
  208. },
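{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same loading, normalization and splitting can also be written with NumPy alone. A minimal sketch, assuming (as the parsing code above implies) that each line of data.txt holds three comma-separated numbers `x0,x1,label`. Here an in-memory `io.StringIO` with made-up values stands in for the real file, so the cell runs without data.txt."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import numpy as np\n",
"\n",
"# hypothetical stand-in for data.txt: columns x0, x1, label\n",
"fake_file = io.StringIO('1.0,2.0,0\\n2.0,4.0,0\\n4.0,1.0,1\\n3.0,3.0,1\\n')\n",
"data = np.loadtxt(fake_file, delimiter=',', dtype=np.float32)\n",
"\n",
"# divide each feature column by its maximum, like the normalization above\n",
"data[:, :2] /= data[:, :2].max(axis=0)\n",
"\n",
"points0 = data[data[:, 2] == 0]  # first class\n",
"points1 = data[data[:, 2] == 1]  # second class\n",
"print(data.shape, points0.shape, points1.shape)"
]
},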
  209. {
  210. "cell_type": "markdown",
  211. "metadata": {},
  212. "source": [
  213. "Next we convert the data to a NumPy array and then to Tensors, in preparation for training."
  214. ]
  215. },
  216. {
  217. "cell_type": "code",
  218. "execution_count": 4,
  219. "metadata": {},
  220. "outputs": [],
  221. "source": [
  222. "np_data = np.array(data, dtype='float32') # convert to a numpy array\n",
  223. "x_data = torch.from_numpy(np_data[:, 0:2]) # convert to a Tensor of size [100, 2]\n",
  224. "y_data = torch.from_numpy(np_data[:, -1]).unsqueeze(1) # convert to a Tensor of size [100, 1]"
  225. ]
  226. },
  227. {
  228. "cell_type": "markdown",
  229. "metadata": {},
  230. "source": [
  231. "Now let's implement the Sigmoid function ourselves. Its formula is\n",
  232. "\n",
  233. "$$\n",
  234. "f(x) = \\frac{1}{1 + e^{-x}}\n",
  235. "$$"
  236. ]
  237. },
  238. {
  239. "cell_type": "code",
  240. "execution_count": 5,
  241. "metadata": {},
  242. "outputs": [],
  243. "source": [
  244. "# define the sigmoid function\n",
  245. "def sigmoid(x):\n",
  246. "    return 1 / (1 + np.exp(-x))"
  247. ]
  248. },
  249. {
  250. "cell_type": "markdown",
  251. "metadata": {},
  252. "source": [
  253. "Plot the Sigmoid function: the larger the input, the closer the output is to 1; the smaller the input, the closer the output is to 0."
  254. ]
  255. },
  256. {
  257. "cell_type": "code",
  258. "execution_count": 6,
  259. "metadata": {},
  260. "outputs": [
  261. {
  262. "data": {
  263. "text/plain": [
  264. "[<matplotlib.lines.Line2D at 0x7fb05e21e0b8>]"
  265. ]
  266. },
  267. "execution_count": 6,
  268. "metadata": {},
  269. "output_type": "execute_result"
  270. },
  271. {
  272. "data": {
  273. "image/png": "\n",
  274. "text/plain": [
  275. "<Figure size 432x288 with 1 Axes>"
  276. ]
  277. },
  278. "metadata": {
  279. "needs_background": "light"
  280. },
  281. "output_type": "display_data"
  282. }
  283. ],
  284. "source": [
  285. "# plot the sigmoid function\n",
  286. "\n",
  287. "plot_x = np.arange(-10, 10.01, 0.01)\n",
  288. "plot_y = sigmoid(plot_x)\n",
  289. "\n",
  290. "plt.plot(plot_x, plot_y, 'r')"
  291. ]
  292. },
  293. {
  294. "cell_type": "code",
  295. "execution_count": 7,
  296. "metadata": {},
  297. "outputs": [],
  298. "source": [
  299. "x_data = Variable(x_data)\n",
  300. "y_data = Variable(y_data)"
  301. ]
  302. },
  303. {
  304. "cell_type": "markdown",
  305. "metadata": {},
  306. "source": [
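  307. "In PyTorch we do not need to write the Sigmoid function ourselves. PyTorch already implements common functions like this in low-level C++ code, for example `torch.sigmoid`; they are convenient to use, faster than our own implementation, and more numerically stable.\n",
  308. "\n",
  309. "Many more such functions are available through `torch.nn.functional`, imported as shown below; the model itself calls `torch.sigmoid` directly."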
  310. ]
  311. },
  312. {
  313. "cell_type": "code",
  314. "execution_count": 21,
  315. "metadata": {},
  316. "outputs": [],
  317. "source": [
  318. "import torch.nn.functional as F"
  319. ]
  320. },
  321. {
  322. "cell_type": "code",
  323. "execution_count": 9,
  324. "metadata": {},
  325. "outputs": [],
  326. "source": [
  327. "# define the logistic regression model\n",
  328. "w = Variable(torch.randn(2, 1), requires_grad=True)\n",
  329. "b = Variable(torch.zeros(1), requires_grad=True)\n",
  330. "\n",
  331. "def logistic_regression(x):\n",
  332. "    return torch.sigmoid(torch.mm(x, w) + b)"
  333. ]
  334. },
  335. {
  336. "cell_type": "markdown",
  337. "metadata": {},
  338. "source": [
  339. "Before any updates, we can plot the current classification result."
  340. ]
  341. },
  342. {
  343. "cell_type": "code",
  344. "execution_count": 10,
  345. "metadata": {},
  346. "outputs": [
  347. {
  348. "data": {
  349. "text/plain": [
  350. "<matplotlib.legend.Legend at 0x7fb05e19dbe0>"
  351. ]
  352. },
  353. "execution_count": 10,
  354. "metadata": {},
  355. "output_type": "execute_result"
  356. },
  357. {
  358. "data": {
  359. "image/png": "\n",
  360. "text/plain": [
  361. "<Figure size 432x288 with 1 Axes>"
  362. ]
  363. },
  364. "metadata": {
  365. "needs_background": "light"
  366. },
  367. "output_type": "display_data"
  368. }
  369. ],
  370. "source": [
  371. "# plot the boundary w0 * x + w1 * y + b0 = 0 before the update (weights are random, so the line may fall outside the plotted range)\n",
  372. "w0 = w[0, 0].item()\n",
  373. "w1 = w[1, 0].item()\n",
  374. "b0 = b[0].item()\n",
  375. "\n",
  376. "plot_x = np.arange(0.2, 1, 0.01)\n",
  377. "plot_y = (-w0 * plot_x - b0) / w1\n",
  378. "\n",
  379. "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n",
  380. "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n",
  381. "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n",
  382. "plt.legend(loc='best')"
  383. ]
  384. },
  385. {
  386. "cell_type": "markdown",
  387. "metadata": {},
  388. "source": [
  389. "可以看到分类效果基本是混乱的,我们来计算一下 loss,公式如下\n",
  390. "\n",
  391. "$$\n",
  392. "loss = -(y * log(\\hat{y}) + (1 - y) * log(1 - \\hat{y}))\n",
  393. "$$"
  394. ]
  395. },
  396. {
  397. "cell_type": "code",
  398. "execution_count": 11,
  399. "metadata": {},
  400. "outputs": [],
  401. "source": [
  402. "# compute the loss: mean binary cross-entropy over all points\n",
  403. "def binary_loss(y_pred, y):\n",
  404. "    log_likelihood = (y * y_pred.clamp(1e-12).log() + (1 - y) * (1 - y_pred).clamp(1e-12).log()).mean()\n",
  405. "    return -log_likelihood"
  406. ]
  407. },
  408. {
  409. "cell_type": "markdown",
  410. "metadata": {},
  411. "source": [
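  412. "Note the use of `.clamp` here. Read its [documentation](http://pytorch.org/docs/0.3.0/torch.html?highlight=clamp#torch.clamp) and think about whether this function is really necessary, and what would happen without it.\n",
  413. "\n",
  414. "**Hint: look at the graph of the log function**"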
  415. ]
  416. },
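{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check your answer to the question above, here is a minimal sketch with hand-picked predictions that sit exactly at 0 or 1. Without `clamp`, $\\log(0) = -\\infty$ makes a confidently wrong prediction cost `inf`, and because $0 \\cdot (-\\infty)$ is `nan`, even perfectly correct predictions produce `nan`. With `clamp(1e-12)` every term stays finite."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# predictions exactly at 0 or 1, which float32 Sigmoid can produce for scores of large magnitude\n",
"y_pred = torch.tensor([0.0, 1.0, 0.0])\n",
"y = torch.tensor([0.0, 1.0, 1.0])  # the first two predictions are right, the last is confidently wrong\n",
"\n",
"# without clamp: log(0) = -inf, and 0 * (-inf) = nan\n",
"raw = -(y * y_pred.log() + (1 - y) * (1 - y_pred).log())\n",
"# with clamp, as in binary_loss: log is never evaluated at 0\n",
"safe = -(y * y_pred.clamp(1e-12).log() + (1 - y) * (1 - y_pred).clamp(1e-12).log())\n",
"\n",
"print(raw)   # nan for the two correct predictions, inf for the wrong one\n",
"print(safe)  # finite: about 0 for the correct ones, about 27.63 for the wrong one"
]
},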
  417. {
  418. "cell_type": "code",
  419. "execution_count": 12,
  420. "metadata": {},
  421. "outputs": [
  422. {
  423. "name": "stdout",
  424. "output_type": "stream",
  425. "text": [
  426. "tensor(0.8986, grad_fn=<NegBackward>)\n"
  427. ]
  428. }
  429. ],
  430. "source": [
  431. "y_pred = logistic_regression(x_data)\n",
  432. "loss = binary_loss(y_pred, y_data)\n",
  433. "print(loss)"
  434. ]
  435. },
  436. {
  437. "cell_type": "markdown",
  438. "metadata": {},
  439. "source": [
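  440. "With the loss in hand, we again update the parameters with gradient descent. Autograd gives us the derivatives of the parameters directly; interested readers can derive the gradient formulas by hand."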
  441. ]
  442. },
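{
"cell_type": "markdown",
"metadata": {},
"source": [
"For readers who want to check a hand derivation: for $\\hat{y} = f(xw + b)$ and the mean loss over $N$ points, the chain rule gives $\\partial loss / \\partial w = x^T (\\hat{y} - y) / N$ and $\\partial loss / \\partial b = \\mathrm{mean}(\\hat{y} - y)$. The minimal sketch below uses a random toy batch rather than data.txt and compares these closed-form gradients with what autograd computes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"# hypothetical toy batch: 5 points with 2 features and 0/1 labels\n",
"x = torch.randn(5, 2)\n",
"y = torch.tensor([[0.0], [1.0], [1.0], [0.0], [1.0]])\n",
"\n",
"w = torch.randn(2, 1, requires_grad=True)\n",
"b = torch.zeros(1, requires_grad=True)\n",
"\n",
"# forward pass and mean binary cross-entropy, then autograd\n",
"y_pred = torch.sigmoid(torch.mm(x, w) + b)\n",
"loss = -(y * y_pred.log() + (1 - y) * (1 - y_pred).log()).mean()\n",
"loss.backward()\n",
"\n",
"# closed-form gradients of the same loss\n",
"with torch.no_grad():\n",
"    err = y_pred - y                            # shape [5, 1]\n",
"    grad_w = torch.mm(x.t(), err) / x.shape[0]  # shape [2, 1]\n",
"    grad_b = err.mean(dim=0)                    # shape [1]\n",
"\n",
"print(torch.allclose(w.grad, grad_w, atol=1e-6))\n",
"print(torch.allclose(b.grad, grad_b, atol=1e-6))"
]
},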
  443. {
  444. "cell_type": "code",
  445. "execution_count": 18,
  446. "metadata": {},
  447. "outputs": [
  448. {
  449. "name": "stdout",
  450. "output_type": "stream",
  451. "text": [
  452. "tensor(0.7402, grad_fn=<NegBackward>)\n",
  453. "tensor(0.7348, grad_fn=<NegBackward>)\n",
  454. "tensor(0.7299, grad_fn=<NegBackward>)\n",
  455. "tensor(0.7254, grad_fn=<NegBackward>)\n",
  456. "tensor(0.7212, grad_fn=<NegBackward>)\n",
  457. "tensor(0.7175, grad_fn=<NegBackward>)\n",
  458. "tensor(0.7140, grad_fn=<NegBackward>)\n",
  459. "tensor(0.7108, grad_fn=<NegBackward>)\n",
  460. "tensor(0.7079, grad_fn=<NegBackward>)\n",
  461. "tensor(0.7052, grad_fn=<NegBackward>)\n",
  462. "tensor(0.7028, grad_fn=<NegBackward>)\n",
  463. "tensor(0.7005, grad_fn=<NegBackward>)\n",
  464. "tensor(0.6984, grad_fn=<NegBackward>)\n",
  465. "tensor(0.6965, grad_fn=<NegBackward>)\n",
  466. "tensor(0.6947, grad_fn=<NegBackward>)\n",
  467. "tensor(0.6931, grad_fn=<NegBackward>)\n",
  468. "tensor(0.6916, grad_fn=<NegBackward>)\n",
  469. "tensor(0.6901, grad_fn=<NegBackward>)\n",
  470. "tensor(0.6888, grad_fn=<NegBackward>)\n",
  471. "tensor(0.6876, grad_fn=<NegBackward>)\n",
  472. "tensor(0.6865, grad_fn=<NegBackward>)\n",
  473. "tensor(0.6855, grad_fn=<NegBackward>)\n",
  474. "tensor(0.6845, grad_fn=<NegBackward>)\n",
  475. "tensor(0.6836, grad_fn=<NegBackward>)\n",
  476. "tensor(0.6827, grad_fn=<NegBackward>)\n",
  477. "tensor(0.6819, grad_fn=<NegBackward>)\n",
  478. "tensor(0.6811, grad_fn=<NegBackward>)\n",
  479. "tensor(0.6804, grad_fn=<NegBackward>)\n",
  480. "tensor(0.6797, grad_fn=<NegBackward>)\n",
  481. "tensor(0.6791, grad_fn=<NegBackward>)\n",
  482. "tensor(0.6785, grad_fn=<NegBackward>)\n",
  483. "tensor(0.6779, grad_fn=<NegBackward>)\n",
  484. "tensor(0.6773, grad_fn=<NegBackward>)\n",
  485. "tensor(0.6768, grad_fn=<NegBackward>)\n",
  486. "tensor(0.6763, grad_fn=<NegBackward>)\n",
  487. "tensor(0.6758, grad_fn=<NegBackward>)\n",
  488. "tensor(0.6753, grad_fn=<NegBackward>)\n",
  489. "tensor(0.6749, grad_fn=<NegBackward>)\n",
  490. "tensor(0.6745, grad_fn=<NegBackward>)\n",
  491. "tensor(0.6740, grad_fn=<NegBackward>)\n",
  492. "tensor(0.6736, grad_fn=<NegBackward>)\n",
  493. "tensor(0.6732, grad_fn=<NegBackward>)\n",
  494. "tensor(0.6728, grad_fn=<NegBackward>)\n",
  495. "tensor(0.6725, grad_fn=<NegBackward>)\n",
  496. "tensor(0.6721, grad_fn=<NegBackward>)\n",
  497. "tensor(0.6718, grad_fn=<NegBackward>)\n",
  498. "tensor(0.6714, grad_fn=<NegBackward>)\n",
  499. "tensor(0.6711, grad_fn=<NegBackward>)\n",
  500. "tensor(0.6707, grad_fn=<NegBackward>)\n",
  501. "tensor(0.6704, grad_fn=<NegBackward>)\n",
  502. "tensor(0.6701, grad_fn=<NegBackward>)\n",
  503. "tensor(0.6698, grad_fn=<NegBackward>)\n",
  504. "tensor(0.6694, grad_fn=<NegBackward>)\n",
  505. "tensor(0.6691, grad_fn=<NegBackward>)\n",
  506. "tensor(0.6688, grad_fn=<NegBackward>)\n",
  507. "tensor(0.6685, grad_fn=<NegBackward>)\n",
  508. "tensor(0.6682, grad_fn=<NegBackward>)\n",
  509. "tensor(0.6679, grad_fn=<NegBackward>)\n",
  510. "tensor(0.6676, grad_fn=<NegBackward>)\n",
  511. "tensor(0.6673, grad_fn=<NegBackward>)\n",
  512. "tensor(0.6671, grad_fn=<NegBackward>)\n",
  513. "tensor(0.6668, grad_fn=<NegBackward>)\n",
  514. "tensor(0.6665, grad_fn=<NegBackward>)\n",
  515. "tensor(0.6662, grad_fn=<NegBackward>)\n",
  516. "tensor(0.6659, grad_fn=<NegBackward>)\n",
  517. "tensor(0.6656, grad_fn=<NegBackward>)\n",
  518. "tensor(0.6654, grad_fn=<NegBackward>)\n",
  519. "tensor(0.6651, grad_fn=<NegBackward>)\n",
  520. "tensor(0.6648, grad_fn=<NegBackward>)\n",
  521. "tensor(0.6645, grad_fn=<NegBackward>)\n",
  522. "tensor(0.6643, grad_fn=<NegBackward>)\n",
  523. "tensor(0.6640, grad_fn=<NegBackward>)\n",
  524. "tensor(0.6637, grad_fn=<NegBackward>)\n",
  525. "tensor(0.6634, grad_fn=<NegBackward>)\n",
  526. "tensor(0.6632, grad_fn=<NegBackward>)\n",
  527. "tensor(0.6629, grad_fn=<NegBackward>)\n",
  528. "tensor(0.6626, grad_fn=<NegBackward>)\n",
  529. "tensor(0.6624, grad_fn=<NegBackward>)\n",
  530. "tensor(0.6621, grad_fn=<NegBackward>)\n",
  531. "tensor(0.6618, grad_fn=<NegBackward>)\n",
  532. "tensor(0.6616, grad_fn=<NegBackward>)\n",
  533. "tensor(0.6613, grad_fn=<NegBackward>)\n",
  534. "tensor(0.6610, grad_fn=<NegBackward>)\n",
  535. "tensor(0.6608, grad_fn=<NegBackward>)\n",
  536. "tensor(0.6605, grad_fn=<NegBackward>)\n",
  537. "tensor(0.6603, grad_fn=<NegBackward>)\n",
  538. "tensor(0.6600, grad_fn=<NegBackward>)\n",
  539. "tensor(0.6597, grad_fn=<NegBackward>)\n",
  540. "tensor(0.6595, grad_fn=<NegBackward>)\n",
  541. "tensor(0.6592, grad_fn=<NegBackward>)\n",
  542. "tensor(0.6589, grad_fn=<NegBackward>)\n",
  543. "tensor(0.6587, grad_fn=<NegBackward>)\n",
  544. "tensor(0.6584, grad_fn=<NegBackward>)\n",
  545. "tensor(0.6582, grad_fn=<NegBackward>)\n",
  546. "tensor(0.6579, grad_fn=<NegBackward>)\n",
  547. "tensor(0.6576, grad_fn=<NegBackward>)\n",
  548. "tensor(0.6574, grad_fn=<NegBackward>)\n",
  549. "tensor(0.6571, grad_fn=<NegBackward>)\n",
  550. "tensor(0.6569, grad_fn=<NegBackward>)\n",
  551. "tensor(0.6566, grad_fn=<NegBackward>)\n"
  552. ]
  553. }
  554. ],
  555. "source": [
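  556. "# compute the gradients with autograd and update the parameters\n",
  557. "for i in range(10):\n",
  558. "    # compute the gradients of the current loss\n",
  559. "    loss.backward()\n",
  560. "    w.data = w.data - 0.1 * w.grad.data # gradient descent step, learning rate 0.1\n",
  561. "    b.data = b.data - 0.1 * b.grad.data\n",
  562. "    # reset the gradients, otherwise the next backward() adds to them\n",
  563. "    w.grad.data.zero_()\n",
  564. "    b.grad.data.zero_()\n",
  565. "\n",
  566. "    # compute the loss after this update\n",
  567. "    y_pred = logistic_regression(x_data)\n",
  568. "    loss = binary_loss(y_pred, y_data)\n",
  569. "    print(loss)"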
  570. ]
  571. },
  572. {
  573. "cell_type": "code",
  574. "execution_count": 20,
  575. "metadata": {},
  576. "outputs": [
  577. {
  578. "data": {
  579. "text/plain": [
  580. "<matplotlib.legend.Legend at 0x7fb0598c0cc0>"
  581. ]
  582. },
  583. "execution_count": 20,
  584. "metadata": {},
  585. "output_type": "execute_result"
  586. },
  587. {
  588. "data": {
  589. "image/png": "\n",
  590. "text/plain": [
  591. "<Figure size 432x288 with 1 Axes>"
  592. ]
  593. },
  594. "metadata": {
  595. "needs_background": "light"
  596. },
  597. "output_type": "display_data"
  598. }
  599. ],
  600. "source": [
  601. "# 画出参数更新之前的结果\n",
  602. "w0 = w[0].data[0]\n",
  603. "w1 = w[1].data[0]\n",
  604. "b0 = b.data[0]\n",
  605. "\n",
  606. "plot_x = np.arange(0.2, 1, 0.01)\n",
  607. "plot_y = (-w0.numpy() * plot_x - b0.numpy()) / w1.numpy()\n",
  608. "\n",
  609. "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n",
  610. "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n",
  611. "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n",
  612. "plt.legend(loc='best')"
  613. ]
  614. },
  615. {
  616. "cell_type": "markdown",
  617. "metadata": {},
  618. "source": [
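  619. "The parameter updates above are tedious and repetitive: with many parameters, say 100, we would need 100 lines just to update them. We could wrap the update in a function, and in fact PyTorch already provides this as its optimizers, `torch.optim`.\n",
  620. "\n",
  621. "`torch.optim` works with another data type, `nn.Parameter`. It is essentially the same as a Variable, except that an `nn.Parameter` requires gradients by default, whereas a Variable does not.\n",
  622. "\n",
  623. "`torch.optim.SGD` updates the parameters with gradient descent. PyTorch's optimizers offer many more optimization algorithms, which we will cover in more detail in later lessons of this chapter.\n",
  624. "\n",
  625. "After passing the parameters w and b to `torch.optim.SGD` and specifying the learning rate, we can update the parameters with `optimizer.step()`. Below, for example, we pass the parameters to the optimizer with a learning rate of 1.0."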
  626. ]
  627. },
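{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the two claims above, using a toy loss $\\sum p^2$ (whose gradient is $2p$) in place of the Logistic regression loss. An `nn.Parameter` requires gradients by default while a plain tensor does not, and one `optimizer.step()` of `torch.optim.SGD` with its default settings (no momentum, no weight decay) performs the same update `p - lr * p.grad` that we wrote by hand earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"p = nn.Parameter(torch.tensor([1.0, -2.0]))\n",
"print(p.requires_grad)               # True: nn.Parameter tracks gradients by default\n",
"print(torch.zeros(2).requires_grad)  # False: a plain tensor does not\n",
"\n",
"optimizer = torch.optim.SGD([p], lr=0.1)\n",
"loss = (p ** 2).sum()                # toy loss, gradient is 2 * p\n",
"optimizer.zero_grad()\n",
"loss.backward()\n",
"expected = p.detach() - 0.1 * p.grad  # the hand-written gradient descent step\n",
"optimizer.step()\n",
"print(torch.allclose(p.detach(), expected))"
]
},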
  628. {
  629. "cell_type": "code",
  630. "execution_count": 33,
  631. "metadata": {},
  632. "outputs": [],
  633. "source": [
  634. "# update the parameters with torch.optim\n",
  635. "from torch import nn\n",
  636. "w = nn.Parameter(torch.randn(2, 1))\n",
  637. "b = nn.Parameter(torch.zeros(1))\n",
  638. "\n",
  639. "def logistic_regression(x):\n",
  640. "    return torch.sigmoid(torch.mm(x, w) + b)\n",
  641. "\n",
  642. "optimizer = torch.optim.SGD([w, b], lr=1.)"
  643. ]
  644. },
  645. {
  646. "cell_type": "code",
  647. "execution_count": 34,
  648. "metadata": {},
  649. "outputs": [
  650. {
  651. "name": "stderr",
  652. "output_type": "stream",
  653. "text": [
  654. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:15: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n",
  655. " from ipykernel import kernelapp as app\n",
  656. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/ipykernel_launcher.py:17: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number\n"
  657. ]
  658. },
  659. {
  660. "name": "stdout",
  661. "output_type": "stream",
  662. "text": [
  663. "epoch: 200, Loss: 0.39010, Acc: 0.00000\n",
  664. "epoch: 400, Loss: 0.32184, Acc: 0.00000\n",
  665. "epoch: 600, Loss: 0.28917, Acc: 0.00000\n",
  666. "epoch: 800, Loss: 0.26983, Acc: 0.00000\n",
  667. "epoch: 1000, Loss: 0.25700, Acc: 0.00000\n",
  668. "\n",
  669. "During Time: 0.248 s\n"
  670. ]
  671. }
  672. ],
  673. "source": [
  674. "# run 1000 updates\n",
  675. "import time\n",
  676. "\n",
  677. "start = time.time()\n",
  678. "for e in range(1000):\n",
  679. "    # forward pass\n",
  680. "    y_pred = logistic_regression(x_data)\n",
  681. "    loss = binary_loss(y_pred, y_data) # compute the loss\n",
  682. "    # backward pass\n",
  683. "    optimizer.zero_grad() # reset the gradients to 0 with the optimizer\n",
  684. "    loss.backward()\n",
  685. "    optimizer.step() # update the parameters with the optimizer\n",
  686. "    # compute the accuracy as the fraction of correctly classified points\n",
  687. "    mask = y_pred.ge(0.5).float()\n",
  688. "    acc = (mask == y_data).float().mean().item()\n",
  689. "    if (e + 1) % 200 == 0:\n",
  690. "        print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e + 1, loss.item(), acc))\n",
  691. "during = time.time() - start\n",
  692. "print()\n",
  693. "print('During Time: {:.3f} s'.format(during))"
  694. ]
  695. },
  696. {
  697. "cell_type": "markdown",
  698. "metadata": {},
  699. "source": [
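  700. "As we can see, updating the parameters with an optimizer is very simple: call **`optimizer.zero_grad()`** to reset the gradients to 0 before backpropagation, then call **`optimizer.step()`** to update the parameters.\n",
  701. "\n",
  702. "After 1000 updates the loss has also dropped considerably."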
  703. ]
  704. },
  705. {
  706. "cell_type": "markdown",
  707. "metadata": {},
  708. "source": [
  709. "Next, we plot the result after the updates."
  710. ]
  711. },
  712. {
  713. "cell_type": "code",
  714. "execution_count": 36,
  715. "metadata": {},
  716. "outputs": [
  717. {
  718. "data": {
  719. "text/plain": [
  720. "<matplotlib.legend.Legend at 0x7f083e823400>"
  721. ]
  722. },
  723. "execution_count": 36,
  724. "metadata": {},
  725. "output_type": "execute_result"
  726. },
  727. {
  728. "data": {
  729. "image/png": "\n",
  730. "text/plain": [
  731. "<Figure size 432x288 with 1 Axes>"
  732. ]
  733. },
  734. "metadata": {
  735. "needs_background": "light"
  736. },
  737. "output_type": "display_data"
  738. }
  739. ],
  740. "source": [
  741. "# Plot the result after the updates\n",
  742. "w0 = w[0].item()\n",
  743. "w1 = w[1].item()\n",
  744. "b0 = b.item()\n",
  745. "\n",
  746. "plot_x = np.arange(0.2, 1, 0.01)\n",
  747. "plot_y = (-w0 * plot_x - b0) / w1  # decision boundary: w0*x + w1*y + b = 0\n",
  748. "\n",
  749. "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n",
  750. "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n",
  751. "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n",
  752. "plt.legend(loc='best')"
  753. ]
  754. },
  755. {
  756. "cell_type": "markdown",
  757. "metadata": {},
  758. "source": [
  759. "After the updates, the model can roughly separate the two classes of points."
  760. ]
  761. },
  762. {
  763. "cell_type": "markdown",
  764. "metadata": {},
  765. "source": [
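  766. "Above we used a loss we wrote ourselves. PyTorch already provides many common losses: for example, the loss for linear regression is `nn.MSELoss()`, and the binary-classification loss for logistic regression is `nn.BCEWithLogitsLoss()`. For more losses, see the [documentation](http://pytorch.org/docs/0.3.0/nn.html#loss-functions).\n",
  767. "\n",
  768. "The losses PyTorch implements have two advantages: first, they are convenient, so we don't need to reinvent the wheel; second, they are implemented in the underlying C++ backend, so they are faster and more numerically stable than our own implementations.\n",
  769. "\n",
  770. "In addition, for numerical stability PyTorch fuses the model's Sigmoid operation and the final loss into `nn.BCEWithLogitsLoss()`, so when we use this built-in loss we should not apply Sigmoid to the model output ourselves."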
  771. ]
  772. },
  773. {
  774. "cell_type": "code",
  775. "execution_count": 17,
  776. "metadata": {
  777. "collapsed": true
  778. },
  779. "outputs": [],
  780. "source": [
  781. "# Use the built-in loss\n",
  782. "criterion = nn.BCEWithLogitsLoss()  # fuses sigmoid and the loss into one layer: faster and more numerically stable\n",
  783. "\n",
  784. "w = nn.Parameter(torch.randn(2, 1))\n",
  785. "b = nn.Parameter(torch.zeros(1))\n",
  786. "\n",
  787. "def logistic_reg(x):\n",
  788. "    return torch.mm(x, w) + b  # raw logits: BCEWithLogitsLoss applies the sigmoid itself\n",
  789. "\n",
  790. "optimizer = torch.optim.SGD([w, b], 1.)"
  791. ]
  792. },
  793. {
  794. "cell_type": "code",
  795. "execution_count": 18,
  796. "metadata": {},
  797. "outputs": [
  798. {
  799. "name": "stdout",
  800. "output_type": "stream",
  801. "text": [
  802. "\n",
  803. " 0.6363\n",
  804. "[torch.FloatTensor of size 1]\n",
  805. "\n"
  806. ]
  807. }
  808. ],
  809. "source": [
  810. "y_pred = logistic_reg(x_data)\n",
  811. "loss = criterion(y_pred, y_data)\n",
  812. "print(loss.data)"
  813. ]
  814. },
  815. {
  816. "cell_type": "code",
  817. "execution_count": 19,
  818. "metadata": {},
  819. "outputs": [
  820. {
  821. "name": "stdout",
  822. "output_type": "stream",
  823. "text": [
  824. "epoch: 200, Loss: 0.39538, Acc: 0.88000\n",
  825. "epoch: 400, Loss: 0.32407, Acc: 0.87000\n",
  826. "epoch: 600, Loss: 0.29039, Acc: 0.87000\n",
  827. "epoch: 800, Loss: 0.27061, Acc: 0.87000\n",
  828. "epoch: 1000, Loss: 0.25753, Acc: 0.88000\n",
  829. "\n",
  830. "During Time: 0.527 s\n"
  831. ]
  832. }
  833. ],
  834. "source": [
  835. "# Again run 1000 update steps\n",
  836. "\n",
  837. "start = time.time()\n",
  838. "for e in range(1000):\n",
  839. "    # forward pass\n",
  840. "    y_pred = logistic_reg(x_data)\n",
  841. "    loss = criterion(y_pred, y_data)\n",
  842. "    # backward pass\n",
  843. "    optimizer.zero_grad()\n",
  844. "    loss.backward()\n",
  845. "    optimizer.step()\n",
  846. "    # compute the accuracy: y_pred is a logit, and logit >= 0 is equivalent to sigmoid >= 0.5\n",
  847. "    mask = y_pred.ge(0).float()\n",
  848. "    acc = (mask == y_data).float().mean().item()\n",
  849. "    if (e + 1) % 200 == 0:\n",
  850. "        print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e+1, loss.item(), acc))\n",
  851. "\n",
  852. "during = time.time() - start\n",
  853. "print()\n",
  854. "print('During Time: {:.3f} s'.format(during))"
  855. ]
  856. },
  857. {
  858. "cell_type": "markdown",
  859. "metadata": {},
  860. "source": [
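  861. "With PyTorch's built-in loss, the training loop stays just as simple. In the run recorded above it was not actually faster (0.527 s vs. 0.248 s for the hand-written loss); on a network this small, timing differences are small and noisy. The advantages of the built-in loss, better numerical stability and often better speed, matter more for larger networks, and it also saves us from reinventing the wheel."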
  862. ]
  863. },
  864. {
  865. "cell_type": "markdown",
  866. "metadata": {},
  867. "source": [
  868. "In the next lesson we will introduce `Sequential` and `Module`, PyTorch's building blocks for constructing models, which make building models more convenient."
  869. ]
  870. }
  871. ],
  872. "metadata": {
  873. "kernelspec": {
  874. "display_name": "Python 3",
  875. "language": "python",
  876. "name": "python3"
  877. },
  878. "language_info": {
  879. "codemirror_mode": {
  880. "name": "ipython",
  881. "version": 3
  882. },
  883. "file_extension": ".py",
  884. "mimetype": "text/x-python",
  885. "name": "python",
  886. "nbconvert_exporter": "python",
  887. "pygments_lexer": "ipython3",
  888. "version": "3.5.2"
  889. }
  890. },
  891. "nbformat": 4,
  892. "nbformat_minor": 2
  893. }
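
The two training cells in the notebook above compute accuracy from different kinds of model output. The hand-written model is assumed to apply Sigmoid itself and return probabilities. The model used with `nn.BCEWithLogitsLoss()` returns raw logits. The plain-Python sketch below runs without PyTorch and illustrates two points. For logits, the decision threshold moves from 0.5 to 0, because sigmoid(z) >= 0.5 exactly when z >= 0. Accuracy also needs true (floating-point) division. The `Acc: 0.00000` lines in the first recorded run look like a truncating-division symptom, since older PyTorch versions truncate when an integer tensor is divided by an integer. The function names are hypothetical.

```python
import math

def sigmoid(z):
    """Logistic function; the two branches avoid overflow in math.exp."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def accuracy_from_probs(probs, labels):
    """Model outputs probabilities: predict class 1 when p >= 0.5."""
    correct = sum(1 for p, y in zip(probs, labels) if (p >= 0.5) == (y == 1))
    return correct / len(labels)  # true division, so the result is a float

def accuracy_from_logits(logits, labels):
    """Model outputs raw logits: predict class 1 when z >= 0 (same as sigmoid(z) >= 0.5)."""
    correct = sum(1 for z, y in zip(logits, labels) if (z >= 0) == (y == 1))
    return correct / len(labels)

logits = [2.0, -1.0, 0.3, -3.0]
labels = [1, 0, 1, 0]
print(accuracy_from_logits(logits, labels))                       # 1.0
print(accuracy_from_probs([sigmoid(z) for z in logits], labels))  # 1.0
print(accuracy_from_probs(logits, labels))  # 0.75: logits wrongly thresholded at 0.5
print(3 // 4)  # 0: truncating division turns a count/total ratio into zero
```

The third call shows what goes wrong when logits are compared against 0.5: the logit 0.3 belongs to class 1, but it is classified as class 0.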

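The notebook notes that `nn.BCEWithLogitsLoss()` fuses Sigmoid and the binary cross-entropy loss for numerical stability. The PyTorch documentation attributes this to the log-sum-exp trick. The sketch below is a minimal pure-Python illustration of that idea, not PyTorch's actual source. It uses the standard identity -[y log sigmoid(z) + (1 - y) log(1 - sigmoid(z))] = max(z, 0) - z y + log(1 + e^(-|z|)). The names `bce_naive` and `bce_with_logits` are hypothetical.

```python
import math

def bce_naive(z, y):
    """Sigmoid followed by binary cross-entropy, computed separately (unstable)."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_with_logits(z, y):
    """Same loss computed directly from the logit z:
    max(z, 0) - z*y + log(1 + exp(-|z|)), which never evaluates log(0)."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

# Moderate logit: both formulas agree.
print(bce_naive(0.3, 1), bce_with_logits(0.3, 1))

# Large logit with label 0: sigmoid rounds to exactly 1.0, so log(1 - p) is log(0).
try:
    bce_naive(40.0, 0)
except ValueError as exc:
    print("naive version failed:", exc)
print(bce_with_logits(40.0, 0))  # about 40, the correct (large) loss
```

The fused form works only with the logit, so it never computes a probability that has rounded to exactly 0 or 1. This is why the model passed to `nn.BCEWithLogitsLoss()` should return logits rather than Sigmoid outputs.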
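Machine learning is increasingly applied to fields such as aircraft and robotics. Its goal is to use computers to achieve human-like intelligence, making equipment intelligent and unmanned. This course aims to guide students in mastering the basic knowledge, typical methods and techniques of machine learning, to spark their interest in the field through concrete application cases, and to encourage them to analyze and solve the problems and challenges faced by aircraft and robots from an artificial-intelligence perspective. The main contents include Python programming basics, machine learning models, unsupervised learning, supervised learning, and the fundamentals and implementation of deep learning. Students also learn how to apply machine learning to real problems, improving their overall abilities.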
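
The training cells in the notebook rely on `optimizer.zero_grad()`, `loss.backward()` and `optimizer.step()`. The plain-Python sketch below makes their roles concrete by running the same loop for two-feature logistic regression. `TinySGD`, `train` and `boundary_y` are hypothetical names. The gradient is derived by hand rather than computed by autograd, and the toy data is made up for illustration. Gradients are accumulated with `+=`, mirroring how PyTorch accumulates into `.grad`, which is why they must be reset before every backward pass. `boundary_y` computes the same cutting line that the notebook plots, `(-w0 * x - b) / w1`.

```python
import math

class TinySGD:
    """Hypothetical, minimal stand-in for torch.optim.SGD over scalar parameters."""

    def __init__(self, params, lr):
        self.params = params                   # name -> current value
        self.grads = {k: 0.0 for k in params}  # name -> accumulated gradient
        self.lr = lr

    def zero_grad(self):
        # Gradients are accumulated with +=, so they must be reset every iteration.
        for k in self.grads:
            self.grads[k] = 0.0

    def step(self):
        # Plain gradient descent: p <- p - lr * grad
        for k in self.params:
            self.params[k] -= self.lr * self.grads[k]

def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def train(xs, ys, epochs=1000, lr=1.0):
    opt = TinySGD({"w0": 0.0, "w1": 0.0, "b": 0.0}, lr)
    n = len(ys)
    for _ in range(epochs):
        opt.zero_grad()
        p = opt.params
        for (x0, x1), y in zip(xs, ys):
            # For mean binary cross-entropy, d(loss)/dz = (sigmoid(z) - y) / n
            err = (sigmoid(p["w0"] * x0 + p["w1"] * x1 + p["b"]) - y) / n
            opt.grads["w0"] += err * x0  # the "backward" step: accumulate gradients
            opt.grads["w1"] += err * x1
            opt.grads["b"] += err
        opt.step()
    return opt.params

def boundary_y(params, x0):
    # Points where w0*x0 + w1*x1 + b = 0, i.e. where sigmoid(...) = 0.5
    return (-params["w0"] * x0 - params["b"]) / params["w1"]

xs = [(0.1, 0.2), (0.2, 0.1), (0.3, 0.2), (0.8, 0.9), (0.9, 0.8), (0.7, 0.9)]  # made-up toy data
ys = [0, 0, 0, 1, 1, 1]
params = train(xs, ys)
print(params)
print(boundary_y(params, 0.5))
```

If `opt.zero_grad()` is removed from `train`, each step uses the sum of all past gradients. This is the same mistake the notebook avoids by calling `optimizer.zero_grad()` before `loss.backward()`.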