- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Logistic Regression Model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
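- "In the previous lesson we studied the simple linear regression model. In this lesson we study our second model: logistic regression.\n",
- "\n",
- "Logistic regression is a generalized linear model. It has much in common with multiple linear regression, and the two models take essentially the same form. Despite the name \"regression\", it is mostly used for classification problems, most commonly binary classification."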
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
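- "## 1. Model Form\n",
- "\n",
- "Logistic regression has the same form as linear regression, $y = wx + b$, where $x$ can be a multidimensional feature vector. The only difference is that logistic regression applies the logistic function to $y$, turning the output into a probability. Folding the bias $b$ into the parameter vector $\\theta$ (by appending a constant feature 1 to $x$), the model can be written as\n",
- "\n",
- "$$\n",
- "h_\\theta(x) = g(\\theta^T x) = \\frac{1}{1+e^{-\\theta^T x}}\n",
- "$$\n",
- "\n",
- "The logistic function $g$ is the core of logistic regression, so let us take a closer look at it. It is also known as the Sigmoid function."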
- ]
- },
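A minimal plain-Python sketch of why $wx + b$ and $\theta^T x$ describe the same model: appending a constant feature 1 to $x$ and stacking $b$ onto $w$ gives $\theta$, and both forms produce the same score before the logistic function is applied. The helper names (`sigmoid`, `linear_score`, `folded_score`) and the sample numbers are illustrative assumptions, not part of the lesson's code.

```python
import math

def sigmoid(z):
    """Logistic function g(z) = 1 / (1 + e^(-z))."""
    return 1.0 / (1.0 + math.exp(-z))

def linear_score(w, b, x):
    """Compute w^T x + b for a feature vector x."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def folded_score(theta, x):
    """Compute theta^T x', where x' is x with a constant 1 appended (bias folded into theta)."""
    x_ext = list(x) + [1.0]
    return sum(ti * xi for ti, xi in zip(theta, x_ext))

# Hypothetical parameters and input
w = [0.5, -1.2]
b = 0.3
x = [2.0, 1.0]
theta = w + [b]  # theta = (w_1, w_2, b)

z1 = linear_score(w, b, x)
z2 = folded_score(theta, x)
print(z1, z2)       # the two scores agree
print(sigmoid(z1))  # h_theta(x): a probability in (0, 1)
```

Folding the bias into $\theta$ is purely a notational convenience; the PyTorch code later in this notebook keeps `w` and `b` as separate tensors.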
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 1.1 The Sigmoid Function\n",
- "The Sigmoid function is very simple. Its formula is\n",
- "\n",
- "$$\n",
- "f(x) = \\frac{1}{1 + e^{-x}}\n",
- "$$\n",
- "\n",
- "The graph of the Sigmoid function is shown below:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "%matplotlib inline\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "\n",
- "plt.figure()\n",
- "plt.axis([-10, 10, 0, 1])\n",
- "plt.grid(True)\n",
- "X = np.arange(-10, 10, 0.1)\n",
- "y = 1 / (1 + np.exp(-X))  # Sigmoid\n",
- "plt.plot(X, y, 'b-')\n",
- "plt.title(\"Logistic function\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
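- "As the plot shows, the output of the Sigmoid function lies between 0 and 1, so any value passed through it becomes a number between 0 and 1. This number can be interpreted intuitively as a probability: in a binary classification problem, a smaller value indicates the first class and a larger value indicates the second class."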
- ]
- },
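To make the probability reading concrete, here is a small sketch that maps raw scores through the Sigmoid. It assigns the first class (label 0) when the output is below 0.5 and the second class (label 1) otherwise. The 0.5 threshold is the usual convention and is assumed here; it corresponds exactly to a raw score of 0. The sketch also checks the symmetry $f(-x) = 1 - f(x)$.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def predict_class(score, threshold=0.5):
    """Return 1 (second class) if sigmoid(score) >= threshold, else 0 (first class)."""
    return 1 if sigmoid(score) >= threshold else 0

scores = [-4.0, -0.5, 0.0, 0.5, 4.0]
probs = [sigmoid(s) for s in scores]
labels = [predict_class(s) for s in scores]

for s, p, c in zip(scores, probs, labels):
    print(f"score={s:+.1f}  prob={p:.3f}  class={c}")

# Symmetry of the logistic function: f(-x) = 1 - f(x)
assert all(abs(sigmoid(-s) - (1 - sigmoid(s))) < 1e-12 for s in scores)
```

Because the Sigmoid is monotonic, thresholding the probability at 0.5 is equivalent to checking the sign of the raw score.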
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
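- "Another point to keep in mind: logistic regression learns a linear decision boundary, so it works well when the data are close to linearly separable, that is, when the dataset can be split into two parts by a hyperplane in some feature space. For example:\n",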
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As you can see, the green points and the blue points above can be almost completely separated by the black plane."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
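- "### 1.2 Loss Function\n",
- "In the previous lesson, for the regression problem we used a loss to measure the error. For a classification problem, how should we measure the error and design the loss function?\n",
- "\n",
- "Logistic regression uses the Sigmoid function to map its output into the range 0 to 1. For any input data point, we denote the Sigmoid output by $\\hat{y}$ and interpret it as the probability that the point belongs to the second class; the probability that it belongs to the first class is then $1-\\hat{y}$. If the point belongs to the second class, we want $\\hat{y}$ to be as large as possible, i.e. as close to 1 as possible. If it belongs to the first class, we want $1-\\hat{y}$ to be as large as possible, i.e. $\\hat{y}$ as small as possible, as close to 0 as possible. This leads to the following loss function, known as the binary cross-entropy:\n",
- "\n",
- "$$\n",
- "\\text{loss} = -\\left(y \\log(\\hat{y}) + (1 - y) \\log(1 - \\hat{y})\\right)\n",
- "$$\n",
- "\n",
- "Here $y$ is the true label, which can only take the values 0 or 1, while $\\hat{y}$ is the output predicted by logistic regression, a number between 0 and 1. If $y = 0$, the data point belongs to the first class and we want $\\hat{y}$ to be as small as possible. The loss above then becomes\n",
- "\n",
- "$$\n",
- "\\text{loss} = -\\log(1 - \\hat{y})\n",
- "$$\n",
- "\n",
- "During training we minimize the loss. Since $\\log$ is monotonically increasing, $-\\log(1 - \\hat{y})$ shrinks as $\\hat{y}$ shrinks, so minimizing the loss means minimizing $\\hat{y}$, which is exactly what we want.\n",
- "\n",
- "If instead $y = 1$, the data point belongs to the second class and we want $\\hat{y}$ to be as large as possible. The loss then becomes\n",
- "\n",
- "$$\n",
- "\\text{loss} = -\\log(\\hat{y})\n",
- "$$\n",
- "\n",
- "Minimizing this loss means maximizing $\\hat{y}$, which again matches what we want.\n",
- "\n",
- "This argument shows that the loss function is constructed sensibly."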
- ]
- },
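A minimal plain-Python sketch of the loss above, averaged over a batch. It clips $\hat{y}$ away from exactly 0 and 1 so that $\log$ never receives 0. The `eps` value is an assumption: it is a common but arbitrary choice. The printed pairs confirm the two special cases derived above.

```python
import math

def bce_loss(y_true, y_pred, eps=1e-12):
    """Mean binary cross-entropy: -(y*log(p) + (1-y)*log(1-p)), averaged over samples."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip so log() never sees 0
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

# y = 0: the loss reduces to -log(1 - y_hat), which is small when y_hat is near 0
print(bce_loss([0], [0.1]), -math.log(0.9))
# y = 1: the loss reduces to -log(y_hat), which is small when y_hat is near 1
print(bce_loss([1], [0.9]), -math.log(0.9))
# A confident wrong prediction is penalized heavily
print(bce_loss([1], [0.01]))
```

Averaging over the batch (rather than summing) keeps the loss scale independent of the number of samples. This is a design choice of the sketch.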
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 1.3 Code Example\n",
- "\n",
- "Let us now work through a concrete example of logistic regression."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<torch._C.Generator at 0x7f36e27d3490>"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import torch\n",
- "from torch.autograd import Variable\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "%matplotlib inline\n",
- "\n",
- "# Set the random seed\n",
- "torch.manual_seed(2021)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We read the data from `data.txt`. After loading the points, we split them into red and blue according to their labels and plot them."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f36e004f310>"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAfqElEQVR4nO3dfYxd9X3n8fd3jI01jRPAnkqRxx5PKrPBEKTiMU3+yCZqltahWbN5qmwMDWpShzQklZomgCAWMhqlrVat2oVdyYlowOOC2PyxclU2FiJB0WYD9RAeDYIYY8w4SJmMm92Q4OWh3/3j3IE7d+7Dufeeh9/vnM9LuvLce4/v+d1zzv2e3/n+Ho65OyIiEr+RsgsgIiLZUEAXEakIBXQRkYpQQBcRqQgFdBGRijirrBWvW7fON23aVNbqRUSi9Mgjj/zc3cfavVdaQN+0aROzs7NlrV5EJEpm9mKn95RyERGpCAV0EZGKUEAXEakIBXQRkYpQQBcRqYieAd3M7jCzn5nZUx3eNzP7ezM7ZmZPmNkl2RczQwcPwqZNMDKS/HvwYNklEhHJRJoa+reB7V3e/yiwufHYA/y34YuVk4MHYc8eePFFcE/+3bNHQV1EKqFnQHf3HwCnuyxyBXCXJx4CzjGzd2dVwEzddBP8+tdLX/v1r5PXRUQil0UOfT3wUtPzucZry5jZHjObNbPZ+fn5DFbdp5Mn+3tdRCQihTaKuvt+d59y96mxsbYjV/O1cWN/r2dEaft60n6XomUR0E8BG5qejzdeC8/0NIyOLn1tdDR5PSdK29eT9ruUIYuAfgj4o0Zvl/cD/8fdX87gc7O3ezfs3w8TE2CW/Lt/f/J6TpS2L0+ZNWTtdylDmm6LdwM/Av6dmc2Z2WfN7Fozu7axyH3AceAY8E3gT3MrbRZ274YTJ+Df/i35N8dgDkrbF6k5gK9bB3/8x+XVkGPY70oJVZC7l/LYunWr18HEhHsSUpY+JibKLlm1zMy4j46239ZlbPdB9/vMTLKMWfLvzEw+5Wu3vUZH81ufZAeY9Q5xVSNFc1ZC2r6W2qU42imqhjzIfi8y766UUDVVN6AHcj1ZQtq+ltIG6pw7NL1lkP1eZJCNISUk/bOkBl+8qakpz+0GF4tVneZfx+ioImmFbdqU1Gi7Cf0QGBlJauatzJImnyx12l4TE0nTkoTLzB5x96l271Wzhq7rydppl+JYuRLWro3nyqjIYRJKBVZTNQO6ridrp12K4x/+AX7+88I6NA2tyCCrVGA1VTPloutJidTBg8mF5MmTSc18elpBVpaqX8pF15MSqYKHSUjFVDOg63pSRGqomgEdVNWRIAXSm1Yq6qyyCyBSF629aRcHDoHqG5KN6tbQRQKj3rSSNwV0kYLE1JtWqaE4KaCLFKSk+6v0TXO5x0sBXaQgsfSmVWooXgroIgWJpTdtpxRQr7lypHwK6CIFiqE3bacUkJnSLqFTQBeJXNYNmNPTSfBu5a60S+gU0EUilkcD5u7d7afxhTB75MjbFNBFIrRYK7/qqnwaMCcm2r8eWo8cWUoBXTKl/sv5a66VdzJsTTqWHjmyVKqAbmbbzexZMztmZje0eX/CzB4wsyfM7EEzG8++qBI69V8uRpr7pw5bk46lR44s1TOgm9kK4Hbgo8AWYJeZbWlZ7D8Dd7n7xcA+4BtZF1T6U0ZNWf2Xi9Gr9p1VTTqGHjm91O2KMU0N/VLgmLsfd/fXgHuAK1qW2QJ8r/H399u8LwUqq6Yc09D2mHWrfasm/bY6XjGmCejrgZeans81Xmv2OPCJxt8fB9aY2drWDzKzPWY2a2az8/Pzg5RXUiirphzL0PbYdcpvz8zEW5POQx2vGLNqFP0L4ENm9ijwIeAU8GbrQu6+392n3H1qbGwso1VLq7JqympIK0a3/HbdUgzNWr97p0bjSl8xunvXB/AB4HDT8xuBG7ss/w5grtfnbt261SUxM+M+MeFulvw7
MzPc501MuCcXmUsfExPDl7WXrL+LpDcz4z46unSfj47WYx+0++5m5f0O8gTMeqf42+mNtxZIboJxHJgEVpGkVy5sWWYdMNL4exrY1+tzFdATefwI6/zDrrMyT+Rl6/TdW4N6nr+DoiozQwX05P9zOfAc8DxwU+O1fcCOxt+fAn7SWOZbwNm9PlMBPZHXj1A15frpVCM1K7tk+ev03Rd/S3n/DoqsRHUL6Ja8X7ypqSmfnZ0tZd0hGRlpP8zaLOkuJp0dPJg0cJ08mTS8Tk/Xu0GwU954YiJpLK2ysr97kes3s0fcfardexopWjL1DBlMHbukdbLYGPjii8sn1apLo3TZDfKhdNlVQC9Z2QdirELsklZGD5PWaQDc3w7qdeqTXvbI1mAqZp1yMXk/lEN/m/Ld/QstX1xWQ3SdG0JDohy6cugyhLJzpq2KLE9z20Gnn6/aYIpXVJuOcuhSOaGlqorKoba2HXSiNpjihTD3jQK6RKnsnGmronKoaWZaVBtMfSmgS7RCqBEtKuqKoVuNP4QTm5TrrLILIFIFiwE07xzqxo1htR1IWFRDF8lIEVcMobUdSFgU0EUiElrbgYRFKReRyOzerQAu7amGLiJSEQroIiIVoYAuIlIRCuhSa3W+ZZtUjxpFpbYWh9EvjrxcnIIX1OgocVINXWorxCl4RYahgC61FcpNCUSyooAutRXMTQlEMpIqoJvZdjN71syOmdkNbd7faGbfN7NHzewJM7s8+6KKZEvD6KVqegZ0M1sB3A58FNgC7DKzLS2L3Qzc6+6/DewE/mvWBa0sdbMojYbRS9WkqaFfChxz9+Pu/hpwD3BFyzIOvLPx97uAn2ZXxArTnY5LF8IUvDqnS1bSBPT1wEtNz+carzW7BbjKzOaA+4AvZVK6qlM3i9rTOV2ylFWj6C7g2+4+DlwOHDCzZZ9tZnvMbNbMZufn5zNa9RDKrhqpm0Xpyj4EdE6XLKUJ6KeADU3PxxuvNfsscC+Au/8IWA2sa/0gd9/v7lPuPjU2NjZYibMSQtVI3SxKFcIhoHO6ZClNQD8CbDazSTNbRdLoeahlmZPARwDM7AKSgB5AFbyLEKpG6mZRqhAOAZ3Tw1H21VoWegZ0d38DuA44DDxD0pvlqJntM7MdjcW+AvyJmT0O3A1c497tnuQBCKFqpG4WpQrhENA5PQwhXK1lwt1LeWzdutVLNTHhnuy7pY+JiXLLJYUJ5RCYmUnWaZb8OzNT7PoHFWu52wnlWEgDmPUOcbW+I0XzrBpV4dqtBkKpHYfQdbJflanRNoRwtZaJTpE+70fpNXT3fKoYMzPuo6NLT/Ojo9FUX0KsdeVZphC/bwxiqtGmEdP3oUsNvd4BPQ8xHRktQjwXhVimqhnkpGbW/jA3y7u0+YjpOFNAb5VntSziIz3Ec1GIZaqSQQNZFfdLLFdr3QK6Je8Xb2pqymdnZ4tfcetdDSBJnGbVu2TTpiSh2GpiIkmQBmxkJPlZtjJL8rtlCLFMVTLo4Zr3z0g6M7NH3H2q3Xv1axTNu/NxKC1tAwixT3SIZaqSQRsD1eM2TPUL6Hk3Z0d8pId4LgqxTFXS6cQ4MtK7k1aMvXMqr1MuJu9HaTn0Kib/MhRiHjHEMlVFuxx66yPUxsG6Qjn0Jkr+iSxx8GCScTx5MqmVv/nm8mUiaAKqDeXQm6VJiWhgUDS0q4bXnDrp1NAc3QCbmqpfDb0X1eCj0W5XrVoFa9bA6dNJfnh6WrutHxF30qoN1dD7EcIUfJJKu1312muwsFCN4ehlUCN0vvK+olRAb1WZSR2qL80u0bm4PxF30gpeEfPfKOXSStec0ei0q1ppEJKEIKvQopRLP7K+5lSrXW7a7ap2NAhJQlDExb8CeqssrzmrNsdoYFp31dq1sHLl0mWU/5VQFDHqWSmXPCl9U7jmPtXq5SIhyaoDnVIuZVEDa+E0HH05Zf3CUESDswJ6njSzVOmqFsz6/T5Vy/rFvj9zr3B0mhMg70dl
b3DRLKZZ81OIbU6Vim3+gb5PlaYuqtr+HBTD3uAC2A48CxwDbmjz/t8CjzUezwG/6PWZtQjo7vFFwQ5i/DFVKZi5D/Z9Ir7fyjJV25+DGiqgAyuA54H3AKuAx4EtXZb/EnBHr8+tTUCviBh/TFUKZu6DfZ9O+21x34V8Qm4V8v7sVm/Luk7XLaCnyaFfChxz9+Pu/hpwD3BFl+V3AXenT/pIDGJs361aE8Yg36dbX/3Y8umh7s9u7RSFt2F0ivSLD+BTwLeanl8N3NZh2QngZWBFh/f3ALPA7MaNG4c7TUmhYqyhx5gm6mbQ77NYQ+xWU49BqPuz228jj98NQ6Zc+gno1wP/pddnulIu0Qn1x9RLRZow3jLM9wk5ZZFWiPuz23bNY5t3C+g9BxaZ2QeAW9z99xvPb2zU7L/RZtlHgS+6+//udWVQi4FFFaNBO3HTOLd8dNuukP02H3Zg0RFgs5lNmtkqYCdwqM1K3gucC/xosGJK6DRoJ26aGjcf3bZr0du8Z0B39zeA64DDwDPAve5+1Mz2mdmOpkV3Avd4ryq/iJRCU+Pmo9t2LXqbay6XmCjnEQztCilLt5TLWUUXRgbUOrPPYv8nUCQpmHaFhEpzucRCt8YLhnZFGGKf1yUPCuixiHFkT0VpV5QvrwE7sZ8kFNCzUMRREOowuRrSrihfHldJVZiZUgF9WEUdBZH3OYu95tMs8l1RCXlcJVUildZpxFHej8qMFC1yTHyIw+RSiHWUaTeR7orKyONnF8tIWoYZKZqXynRbHBlJ9nsr3Wr+LRqhKFnL6nZuzWI5TnULujwpodqTGhEla3kM2KlCKk0BfVhVOApypnOe5CHrqSiqMJJWAX1YVTgK0hqwZVPnPIlF7PMVaaRoFhYnbaiyIYZHLr6tofIi+VKjqKQTS4uRSMWpUVSGp5ZNkeApoEs6atkUCZ4CuqSjlk2R4CmgSzp16s3TQZWmL5DiFHncKKBLeot9ug4cSJ5ffXVtIlsVJm6S4hV93NQ7oKvK1b+aRrZKTNwkhSv6uKlvQK9pYFqm35NaTSObOvnIIIo+buob0GsamJYY5KSW0xEa+sWSOvnIIIo+blIFdDPbbmbPmtkxM7uhwzJ/aGZPm9lRM/vHbIuZA1W5Bjup5XCExnCxpE4+MojCj5tO8+ouPoAVwPPAe4BVwOPAlpZlNgOPAuc2nv9mr88tfT70IucxD9UgE0DnMLl5LLtCc6DLILI+bugyH3qagP4B4HDT8xuBG1uW+Wvgc70+q/lRekCv4l0X+jVoJM34CI3lxgIiIegW0NOkXNYDLzU9n2u81ux84Hwz+6GZPWRm29t9kJntMbNZM5udn59PseocqV/14NeDGU9Jp/y0SDayahQ9iyTt8mFgF/BNMzundSF33+/uU+4+NTY2ltGqhxD7XJnDCuSkpvy0SDbSTJ97CtjQ9Hy88VqzOeBhd38deMHMniMJ8EcyKaXkJ4CpfzW9rkg20tTQjwCbzWzSzFYBO4FDLcv8D5LaOWa2jiQFczy7YkrV1f1iqU5C76Ias541dHd/w8yuAw6T9Hi5w92Pmtk+kuT8ocZ7v2dmTwNvAl9194U8Cy4i8RniPimSQqocurvf5+7nu/tvuft047W9jWBOo/H1z919i7u/z93vybPQtacqjgSu0yGq8Xz50i3oYqMqjgSu2yGq8Xz50i3oYqNbwUnguh2ioMN3WLoFXZWoiiOB63aIqotqvhTQY6NROBK4bodoIEMfKksBPTaq4kjgeh2i6qKaHwX02KiKI4HTIVoeNYqKiEREjaIiIjWggC4iUhEK6CIiFaGALpIxzcwgZVFAl/zUMLLFcH9UqS4F9LrKO9iGEtkKPqlo8ikpk7ot1lHr7EmQjPzIsrNwCHPOFPE9W4yMJOevVmbJQBqRYXXrtqiAXkdFBNsQIlsJJ5UQzmNSbeqHLksVMcFXCHPOlDCRmWZmkDIpoNdREcE2hMhWwklFw96lTArodVREsC07
sh08CK+8svz1Ak4qmnxKyqKAXkfdgm2WvULKimyLjaELLbe1XbtW1WWptFQB3cy2m9mzZnbMzG5o8/41ZjZvZo81Hp/LvqjUsl9zbtoF21C6Gg6rXd9BgHe8Q8FcKq1nLxczWwE8B1wGzAFHgF3u/nTTMtcAU+5+XdoV993LpYQuaLVTlS4aIfSwEcnJsL1cLgWOuftxd38NuAe4IssCpqIRG/nr1PujXZAPWQg9bERKkCagrwdeano+13it1SfN7Akz+46ZbWj3QWa2x8xmzWx2fn6+v5LqXpr56xTwzOJKu4TQw6amlBUtV1aNov8EbHL3i4H7gTvbLeTu+919yt2nxsbG+luDal35m55Ogncr97iuhMruYVNTVWmCiVmaHPoHgFvc/fcbz28EcPdvdFh+BXDa3d/V7XOVQw9Uu4C++Lryz9JFVZpgQjdsDv0IsNnMJs1sFbATONSygnc3Pd0BPDNoYTtSrasYExPtX9eVUJSKTIEoK1q+ngHd3d8ArgMOkwTqe939qJntM7MdjcW+bGZHzexx4MvANbmUNs9+zUr+JZR/royiUyDKigbA3Ut5bN261YMxM+M+OuqeHPfJY3Q0eb2OZmbcJybczZJ/v/CFpc/rul06ad1egWyfiYmlh/TiY2Iin/XpZ1QMYNY7xFUFdPfij/yY6FfaXYnbp9d5xKz9YW1WXplkeN0CuqbPBQ1E6UYtXd2VtH3S9BHQrqsmTZ/bi5J/7R082HlQkVq6EiW1BKYZZ6fmkMHF2qSmgA468ttZrAJ2UveT3aKSKgNpziN16BiWR+CNuj99p1xM3o+gcujuSv616tSuoBz6UiXl0NXsk9+mD33bokZR6VunFjVQMG9VQmUg7/NIDPWbvAJvGY3J/egW0NUoKu2pRS14Bw8mOfOTJ5MMz/R0NumUWAZl59WXIfRDv1qNolkkzWJt8SiS2hWCl9c4u1gmNs2r+SLqQ79T1T3vx0AplyyuM9WvOr0Yrrslc1mmHPI8hPL8KYd86FOZHHq/SbN2eyX0Fg+RkmX1Eymi7tT8E1+7NnmEGISzVJ2A3k/VodPR1KmhL5QWD4lDyFW4IWUViIusO9Xpwrs6Ab2fI6TTsitWqIZeJWUE1hpEj1SbtcdCRfYWqdOFd3UCej8/pG7d7ir+Y6yNsgJrnaJHJym2fZGbKfSuhlmqTkB3T18j63Y0VfhyuVbKCqx1ih6dpNj2RZ5v63SOrVZAT6sGl8W1V1ZgDSF6lF0pSbntiypmnX7u9Qzo7uUf9JKvsgJr2dGj7PW7h3FSa1GXn3t9A7pUW9GBLZQ+ciEE0xBOKjXVLaDHN1JUZFGR0wm2TsG3sACvvgoHDmR/O8ReQrh55+7d8JnPwIoVyfMVK5LnIc0NUEOay0UkjZAm+AihLLFM+FJB3eZyCSqgv/7668zNzXHmzJlSypS11atXMz4+zsqVK8suigwrpLtahRBMQzip1FS3gH5Wyg/YDvwdsAL4lrv/ZYflPgl8B9jm7n1Xv+fm5lizZg2bNm3CzPr970FxdxYWFpibm2NycrLs4siwNm5sH8DKuNHHYtDOY6rFtEJI+8gyPXPoZrYCuB34KLAF2GVmW9ostwb4M+DhQQtz5swZ1q5dG30wBzAz1q5dW5mrjczFNuNlaFPw5TXVYlq6bWOQ0jSKXgocc/fj7v4acA9wRZvlbgX+ChgqglUhmC+q0nfJVIz3+KrD/dz6EdoJrl+xVShSShPQ1wMvNT2fa7z2FjO7BNjg7v/c7YPMbI+ZzZrZ7Pz8fN+FlYqIZcLtVmXXikMS8wkuxgpFSkN3WzSzEeBvgK/0Wtbd97v7lLtPjY2NDbvqyp5lK0/512qI9QQ3SIUikliTJqCfAjY0PR9vvLZoDXAR8KCZnQDeDxwys7atsJkp+Sx75513snnzZjZv3sydd95ZyDorQ/lXKVO/FYqYavSdRhwtPkh6whwHJoFVwOPAhV2WfxCY6vW57UaKPv300+mHS5U4
Wm5hYcEnJyd9YWHBT58+7ZOTk3769Om2y/b1nepCowylTP3GjhBG5jZhmJGi7v4GcB1wGHgGuNfdj5rZPjPbkcdJJpUcLtuPHDnCxRdfzJkzZ/jVr37FhRdeyFNPPbVsucOHD3PZZZdx3nnnce6553LZZZfx3e9+d+D11k7M+VeJX78NuhGlCFP1Q3f3+4D7Wl7b22HZDw9frBRy6Be8bds2duzYwc0338yrr77KVVddxUUXXbRsuVOnTrFhw9tZqPHxcU6dOrVsOeli924FcClHv/34QxqD0EO8c7nk1G1q79693H///czOzvK1r31tqM8SkUD106AbURfNeAN6TpftCwsLvPLKK/zyl7/sOCho/fr1vPTS2z055+bmWL9+fdtlRSRyEaUIg5rL5ZlnnuGCCy4opTyLduzYwc6dO3nhhRd4+eWXue2225Ytc/r0abZu3cqPf/xjAC655BIeeeQRzjvvvGXLhvCdRKQ6us3lEm8NPQd33XUXK1eu5Morr+SGG27gyJEjfO9731u23HnnncfXv/51tm3bxrZt29i7d2/bYC5NIunHKxIz1dBzVsXv1LcQZgcUqQjV0KVcsQ71F4lMqm6LdfXkk09y9dVXL3nt7LPP5uGHB55Qsp4i6scrEjMF9C7e97738dhjj5VdjPhF1I9XJGZKuUj+IurHKxIzBXTJX0T9eEVippSLFEND/UVyF3UNXV2bRUTeFm1AL3uK4u3bt3POOefwsY99rJgVioj0EG1AL7tr81e/+lUOHDhQzMpERFKINqDn0bU57XzoAB/5yEdYs2bN4CsTEclYtI2ieXRtTjsfuohIiKKtoefVtVnzoYtIrKIN6Hl1bU4zH7pIENTNS1pEm3KBfLo2f/7zn+fWW2/lhRde4Prrr287H7pI6VpnsFzs5gXq719jqWroZrbdzJ41s2NmdkOb9681syfN7DEz+19mtiX7ouYv7XzoAB/84Af59Kc/zQMPPMD4+DiHDx8uuLRSa2V385Ig9ZwP3cxWAM8BlwFzwBFgl7s/3bTMO939/zb+3gH8qbtv7/a5mg9dZAgjI8kAjFZmyX0ypbKGnQ/9UuCYux9399eAe4ArmhdYDOYNvwGUc9cMkbro1J1LM1jWWpoc+nrgpabnc8DvtC5kZl8E/hxYBfxuJqUrmeZDl2BNT7e/C5RmsKy1zBpF3f124HYzuxK4GfhM6zJmtgfYA7CxQ03C3TGzrIo1lGHnQy/r9n5SA4sNnzfdlIym27gxCeZqEK21NCmXU8CGpufjjdc6uQf4T+3ecPf97j7l7lNjY2PL3l+9ejULCwuVCITuzsLCAqtXry67KFJVu3fDiRNJzvzECQVzSVVDPwJsNrNJkkC+E7iyeQEz2+zuP2k8/QPgJwxgfHycubk55ufnB/nvwVm9ejXj4+NlF0NEaqJnQHf3N8zsOuAwsAK4w92Pmtk+YNbdDwHXmdl/AF4H/pU26ZY0Vq5cyeTk5CD/VUSk9lLl0N39PuC+ltf2Nv39ZxmXS0RE+hTt0H8REVlKAV1EpCJ6jhTNbcVm80CbCXBTWQf8PMPi5CmmskJc5Y2prKDy5immssJw5Z1w9+XdBCkxoA/DzGY7DX0NTUxlhbjKG1NZQeXNU0xlhfzKq5SLiEhFKKCLiFRErAF9f9kF6ENMZYW4yhtTWUHlzVNMZYWcyhtlDl1ERJaLtYYuIiItFNBFRCoi6IAe063vepW1ablPmpmbWaldrFJs22vMbL6xbR8zs8+VUc5GWXpuWzP7QzN72syOmtk/Fl3GlrL02rZ/27RdnzOzX5RQzMWy9CrrRjP7vpk9amZPmNnlZZSzqTy9yjthZg80yvqgmZU2O56Z3WFmPzOzpzq8b2b2943v8oSZXTL0St09yAfJRGDPA+8huWnG48CWlmXe2fT3DuC7oZa1sdwa4AfAQ8BU4Nv2GuC2SI6DzcCjwLmN578Zcnlblv8SyYR3QZaVpPHuC42/twAnQt62wH8H
PtP4+3eBAyWW998DlwBPdXj/cuB/Aga8H3h42HWGXEOP6dZ3PcvacCvwV8CZIgvXRtryhiBNWf8EuN3d/xXA3X9WcBmb9bttdwF3F1Ky5dKU1YF3Nv5+F/DTAsvXKk15twCLd3b/fpv3C+PuPwBOd1nkCuAuTzwEnGNm7x5mnSEH9Ha3vlvfupCZfdHMngf+GvhyQWVr1bOsjcupDe7+z0UWrINU2xb4ZONS8DtmtqHN+0VIU9bzgfPN7Idm9pCZdb1Bec7SblvMbAKY5O0AVLQ0Zb0FuMrM5khmXP1SMUVrK015Hwc+0fj748AaM1tbQNkGkfpYSSvkgJ6Ku9/u7r8FXE9y67vgmNkI8DfAV8ouSx/+Cdjk7hcD9wN3llyebs4iSbt8mKTG+00zO6fMAqW0E/iOu79ZdkG62AV8293HSVIEBxrHc6j+AviQmT0KfIjkpjwhb99MhbxjMrv1XQF6lXUNcBHwoJmdIMmXHSqxYbTntnX3BXf/f42n3wK2FlS2VmmOgzngkLu/7u4vAM+RBPgy9HPc7qS8dAukK+tngXsB3P1HwGqSiaXKkOa4/am7f8Ldfxu4qfHaLworYX/6jXG9ldVgkKJB4SzgOMkl6WIDyIUty2xu+vs/ktxBKciytiz/IOU2iqbZtu9u+vvjwEMBl3U7cGfj73Ukl7FrQy1vY7n3AidoDO4LtawkjXbXNP6+gCSHXkqZU5Z3HTDS+Hsa2FfW9m2UYROdG0X/gKWNov8y9PrK/LIpNsblJLWt54GbGq/tA3Y0/v474CjwGEkDSMcgWnZZW5YtNaCn3LbfaGzbxxvb9r0Bl9VIUlpPA08CO0Peto3ntwB/WWY5U27bLcAPG8fBY8DvBV7eT5Hc0/g5kivLs0ss693AyyS35pwjudq5Fri28b4Btze+y5NZxAQN/RcRqYiQc+giItIHBXQRkYpQQBcRqQgFdBGRilBAFxGpCAV0EZGKUEAXEamI/w9XgLB6OEkCCAAAAABJRU5ErkJggg==\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Read the points from data.txt\n",
- "with open('./data.txt', 'r') as f:\n",
- " data_list = [i.split('\\n')[0].split(',') for i in f.readlines()]\n",
- " data = [(float(i[0]), float(i[1]), float(i[2])) for i in data_list]\n",
- "\n",
- "# Scale each feature by dividing by its maximum value\n",
- "x0_max = max([i[0] for i in data])\n",
- "x1_max = max([i[1] for i in data])\n",
- "data = [(i[0]/x0_max, i[1]/x1_max, i[2]) for i in data]\n",
- "\n",
- "x0 = list(filter(lambda x: x[-1] == 0.0, data)) # points of the first class (label 0)\n",
- "x1 = list(filter(lambda x: x[-1] == 1.0, data)) # points of the second class (label 1)\n",
- "\n",
- "plot_x0 = [i[0] for i in x0]\n",
- "plot_y0 = [i[1] for i in x0]\n",
- "plot_x1 = [i[0] for i in x1]\n",
- "plot_y1 = [i[1] for i in x1]\n",
- "\n",
- "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n",
- "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n",
- "plt.legend(loc='best')"
- ]
- },
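The loading code above can be sketched a little more defensively. This version strips whitespace, so a trailing newline or a blank last line does not break `float()`. It then performs the same divide-by-maximum scaling. Note that this is max scaling, which maps each non-negative feature into $[0, 1]$, rather than z-score standardization. The file format is assumed to be `x0,x1,label` per line, as the original code implies. The three sample lines stand in for `data.txt` and are made up for illustration.

```python
import io

def load_points(f):
    """Parse lines of the form 'x0,x1,label' into (float, float, float) tuples, skipping blank lines."""
    data = []
    for line in f:
        line = line.strip()
        if not line:
            continue
        x0, x1, label = line.split(',')
        data.append((float(x0), float(x1), float(label)))
    return data

def max_scale(data):
    """Divide each of the two features by its maximum over the dataset."""
    x0_max = max(p[0] for p in data)
    x1_max = max(p[1] for p in data)
    return [(p[0] / x0_max, p[1] / x1_max, p[2]) for p in data]

# Hypothetical stand-in for data.txt (note the trailing blank line)
sample = io.StringIO("34.6,78.0,0\n60.2,86.3,1\n79.0,75.3,1\n\n")
points = max_scale(load_points(sample))

class0 = [p for p in points if p[-1] == 0.0]  # first class
class1 = [p for p in points if p[-1] == 1.0]  # second class
print(len(class0), len(class1))
```

To use it with the real file, pass an open file handle instead of the `StringIO` object, e.g. `with open('./data.txt') as f: points = max_scale(load_points(f))`.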
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next we convert the data to a NumPy array and then to Tensors, in preparation for training."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [],
- "source": [
- "np_data = np.array(data, dtype='float32') # convert to a NumPy array\n",
- "x_data = torch.from_numpy(np_data[:, 0:2]) # convert to a Tensor of shape [100, 2]\n",
- "y_data = torch.from_numpy(np_data[:, 2]).unsqueeze(1) # labels as a Tensor of shape [100, 1]\n",
- "\n",
- "x_data = Variable(x_data)\n",
- "y_data = Variable(y_data)"
- ]
- },
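Why call `unsqueeze(1)` on the labels? The model outputs predictions of shape `[N, 1]`, so the labels are reshaped from `[N]` to `[N, 1]` to match. The NumPy sketch below uses `[:, None]`, NumPy's analogue of `unsqueeze(1)`, to show what goes wrong otherwise: subtracting an `[N]` array from an `[N, 1]` array silently broadcasts to `[N, N]` instead of raising an error. The random array is a stand-in for the real data. Its 100 × 3 shape is an assumption based on the `[100, 2]` comment in the cell above.

```python
import numpy as np

rng = np.random.default_rng(0)
np_data = rng.random((100, 3)).astype('float32')  # stand-in data: columns x0, x1, label

x = np_data[:, 0:2]      # features, shape (100, 2)
y_flat = np_data[:, 2]   # labels, shape (100,)
y_col = y_flat[:, None]  # labels, shape (100, 1), like unsqueeze(1)

pred = np.full((100, 1), 0.5, dtype='float32')  # dummy predictions, shape (100, 1)

print(x.shape, y_flat.shape, y_col.shape)
print((pred - y_col).shape)   # (100, 1): element-wise, as intended
print((pred - y_flat).shape)  # (100, 100): silent broadcasting bug
```

PyTorch follows the same broadcasting rules, so the same mismatch would corrupt an element-wise loss computed by hand on `[N, 1]` predictions and `[N]` labels.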
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
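- "In PyTorch we do not need to write the Sigmoid function ourselves. PyTorch already implements commonly used functions such as `torch.sigmoid` in low-level C++ code, which is not only convenient but also faster and more numerically stable than an implementation we would write ourselves."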
- ]
- },
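To illustrate the stability point, here is a plain-Python sketch of a numerically stable Sigmoid. It is not PyTorch's actual implementation. The naive formula $1/(1+e^{-x})$ fails for large negative $x$, because `math.exp(-x)` raises `OverflowError` once $-x$ exceeds about 709. The stable version only ever exponentiates a non-positive number.

```python
import math

def sigmoid_naive(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_stable(x):
    """Numerically stable logistic function: never calls exp() on a positive argument."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)       # x < 0, so exp(x) lies in (0, 1) and cannot overflow
    return z / (1.0 + z)  # algebraically equal to 1 / (1 + e^(-x))

try:
    sigmoid_naive(-1000.0)
except OverflowError as err:
    print("naive sigmoid failed:", err)

print(sigmoid_stable(-1000.0))  # underflows gracefully to 0.0
print(sigmoid_stable(1000.0))   # 1.0
```

Branching on the sign of `x` is one common way to obtain stability. Library implementations may use other techniques.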
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define the logistic regression model\n",
- "w = Variable(torch.randn(2, 1), requires_grad=True) \n",
- "b = Variable(torch.zeros(1), requires_grad=True)\n",
- "\n",
- "def logistic_regression(x):\n",
- " return torch.sigmoid(torch.mm(x, w) + b)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Before updating the parameters, we can plot how the (untrained) model classifies the data."
- ]
- },
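Before plotting, it helps to know what the classifier's boundary looks like. With two features, the model predicts probability 0.5 exactly where $w_0 x_0 + w_1 x_1 + b = 0$. The boundary is therefore the straight line $x_1 = -(w_0 x_0 + b) / w_1$, assuming $w_1 \neq 0$. The plain-Python sketch below computes points on that line over $[0, 1]$, the range of the max-scaled features. The weight values are hypothetical. In the notebook they would come from `w` and `b` (for example via `w[0, 0].item()`).

```python
import math

def boundary_x1(x0, w0, w1, b):
    """x1 on the decision boundary w0*x0 + w1*x1 + b = 0 (requires w1 != 0)."""
    return -(w0 * x0 + b) / w1

def predict_proba(x0, x1, w0, w1, b):
    """Logistic regression probability for a 2-feature point."""
    return 1.0 / (1.0 + math.exp(-(w0 * x0 + w1 * x1 + b)))

# Hypothetical parameters standing in for the (untrained) w and b
w0, w1, b = 1.5, -0.8, 0.1

xs = [0.0, 0.25, 0.5, 0.75, 1.0]
line = [(x0, boundary_x1(x0, w0, w1, b)) for x0 in xs]
for x0, x1 in line:
    # Every point on the boundary line gets probability 0.5
    print(f"x0={x0:.2f}  x1={x1:.3f}  p={predict_proba(x0, x1, w0, w1, b):.3f}")
```

These points could be drawn with `plt.plot` on top of the red and blue scatter plot to visualize the current decision boundary.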
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f36e0017610>"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAobElEQVR4nO3deXRTZf4G8OdtKUvZKShLaYsjKAUp0KAoIuKCiIiyDSAgDEtJHB2UmRHnoJ45Opyf2yAqmlBQBIriKIwigzAsogMDSgpF9kVpoWW1gKy1pX1/fyTtdMlyk9y1fT7n9LRN03u/ubl58t73vu+NkFKCiIisK8roAoiIKDIMciIii2OQExFZHIOciMjiGORERBZXy4iVNm/eXCYlJRmxaiIiy8rMzPxZStmi8u2GBHlSUhLcbrcRqyYisiwhRI6v29m1QkRkcQxyIiKLY5ATEVkcg5yIyOIY5EREFqdKkAshPhBCnBZC7FZjeTXGkiVAUhIQFeX5vmSJ0RURkQWp1SL/EEB/lZZVMyxZAqSlATk5gJSe72lpqoc53yvCV3nbPfGEftsy1OdN6+eZ+5HJSSlV+QKQBGC3kvumpqbKGicjQ8rERCmF8HyPi5PSE+EVvxITVV1lbGzFxcfGem6nwHxtu8pfWm3LUJ83rZ9n7kfmAcAtfeWvrxvD+QoW5ADSALgBuBMSEnR50KahJBVKv4RQbbWJiZq/V1Rb/radHtvS37rj4iq2BUqDVOvnOdDyK7dPGO7aMjzIy3/VuBa50lRQORmE0Py9otryt+302JZK113aKlb6PJcP3bg4z5eSAA5UT0xM1d8Z5trxF+QctaKHo0eV3S82Fpg5U7XVJiQEvl3Nfk+9+lD1Wo+/bRfu/bRY95UrwIwZwZ9noOopmfx8z5f0np4ZMwZo3tz39vS3/KgooKio4m1FRcDUqcrq15qR5zh05yvdw/kCW+T+hXqsrJJAfZtq9nvq1YeqR19w+RZr7drKWsRqKV1/aWta6RGBku3i75SMksfkb/mBlmM0I85x6NHNBC27VgB8DOAEgCIAuQAmBrp/jQtyA88W+du51OxX1asvXsv1+HqKYmIqdj84HNq9UH2tvzTMlZwbDxQiGRnKQjzQ9vS1fDMHud7nOPR6iWsa5KF+1bggl9J0Z4XU7D/Xqy9ey/Wo8SYRyVMcbP2RBEUop2hC2Z7+3lzi4pQ/bq3ofY5Dr8YMg9zsdA76xLiLbJGXE+mbhMNRdRmhtMiUrN/XLqJkt1EaaqFuz4yMqt1PtWsb3kaRUurfIterMcMgNzO1j8uCvbozMmRGzHgZi0sVV1m7SLc+8nDet9Tu11drWH+gkSNKgyKcNyml2yOUFnmo21PBrmbIgWgkfeTh1MwWeU2XkSFldLR6e4GSV7d3r8vAKJmII1KgWCbiiMyIeyqih6F0548kkNUIBl/rr1276lA6NboulLbIwtkmSsMjWKiV74sP5c032DkDoycShVpvJDWzj7wmU9JsUKtpEB39v/83eIC50ROVAg0iKt8yj4tT9uIP1HWhZR97KE9juGPIy/9/qC1co5/ncERSsx5HJwxyMwp2zBtOp2ugVCn9f4NfYYECSI9D8WCbSMlTomQEUOnj0UooLfJQjpZ8Bb6/g8ZA69arvaDmPhNuzUpCXI0WO4PcjAIlSridrsHeHEr3MgOPeQO1iPUoK9BBi5Kwqlybv6GDDkfgOiINICVPYyhPtZJWd7Cv8oGnR3tB7V1Zi3MVavaeMsjNKJxE8dc0KN/SDvYGUf7+wVJEgyayvx0/3BOOoZbob/3hBFf54YFq1BBOmAdabyjBFOowxWDL1aO9EKhmtd4cg70pB9rGSs5PhIJBbkaB9vRQXoGhNKVC7bTV6JXoK4D8HtaiRPUSfa0/nCALt5tAr96tULoKQh2mWPnL36xQLbvKgtUczu4a6lDSQNtYyQFyKBjkZlU+QUpb4qWn2JUmVKBO2kj2arXT
Jsir2u/qcMRvk0jNEsPpWgh1PcHeMNTuP1a7RR4drc9MV6WU1BzqcxTqPhXo/mq/yTDIzcxfs1LpKyXYsIlwX2lqnq1S0HTOyJAyVlyueBdckhkY5ffMYbASw+ny8DVkTY33RaWDlNSkZh+5Ga9BrmSbhrq7hrrbB3r5+ltW+UFkoWCQm1mkzUqtjtOVDGWMdFmVaszAYxXHtmNUwMcTav+klvOsgtGihaZ23ZEOU4y0pnC2cbCjHK1b5P7q1mI0E4PczCJt+WrVlx2ouRPq8pU+xhD7HdQ6zaCHYAdOZmvtaiVQCzaS3VjNE8hqLCfQ8x0uBrmZqZE4Wp1VUmPsVCjLCHTWs/KQCO/jzYh7SibGXazy0FUfxxzhNlb6NGt9gtBooQ7WMuJloMZytGhIMMi1pMeAYCNFkohKzyDGxf3v8QYbNqBwe6n6Qor0qlgKyzb7rqCGUEfH6DThWHVaPJcMcq1oMSBYy87JcESSiKGM6YuK+t/jDrQNlPa3q/VCUnqUoHBRVd7zy92YGH1M9VZc0PUr+JtaAh2cqXnJIbNQe5syyLUSabPP1zAJszXJIknESAYn+1tHCEcIqryQtBwvWGnbChRr1ioN9DSqOR4/lPWXX5cau77awWm2bi4GuVZCOaOhJLRVavmpLtw9OpQWudLHreVZTF+PM9hZykjWVakZmogjmj20QJst3JEaAYO30rb0dw388uuKZGy6r3p8XRsnkuUZ3aZikGtF6RijQHuZkq/q1FEY6ePWc5ROoGsHhDOOrPxwGh/bIQOjql4nXqXwECjx+zDCOQ0SMPx9bEt/RxtqPVa1h3eabdSTlAxy7SjtP9WiZWoVvo5Egn2ycbDHreZA5FL+niNfV/NSclUsXzUrOPGbgVEyMfqYuofzGRkyUeT43cThhFbA8PexQH9HG2rt6kraBz6X7Wef8X/JiGLD+lkY5Fryt9eUb86E2gpVq5liVhkZwT/aPdRBxL6WF8oyAiWTluPRIqk5hHX7bO2Ly2H3kQcMf4VHG8FeNiE+xNCXHeCB+318OGLY65JBriUlzZlAXTCVdyIzXMRCL/5aqeWHI4a7jFCbeFofSyt5Mw937rbCdVf5VCg8VnaXUN+rAoa/n21Ztn6h/igVJQc8VZYd4Dn3+fhKLxmh9r6hEINcS5EMEK5Joe1PpK3dYE0xo2fIKq1TyxZeoG6jCPh96hR0OWqxuQOdgvC57CAnB8oen69LRoSyb6mEQa41JWFktrFM1UWwlm4orSYtnyO1h1WEum5f5yViYrRbr4JJVFpv7qDLVnoUZpIznwxyqr4CtXTNdn7ByDfzcD+5IxJmb7woPSwwyVhEBjlVX2r0s9cEql98pppQ+mZjgjclf0EuPH/Tl81mk263W/f1UjW2ZAkwYwZw9CiQkADMnAmMHm10VeaSlATk5FS9PTERyM7WuxoKgxAiU0ppq3x7lBHFEKlu9GhPGJWUeL4zxKuaOROIja14W2ys53ayNAY5UU0xejSQnu5pgQvh+Z6ezje9aqCW0QUQkY5Gj2ZwV0NskRMRWRyDnIjI4hjkREQWxyAnIrI4BjkRkcUxyImILE6VIBdC9BdCHBBCHBZCPKfGMomISJmIg1wIEQ3gXQAPAkgGMEoIkRzpcomISBk1WuS3AjgspfxJSlkIYCmAR1RYLhERKaBGkLcBcKzc77ne2yoQQqQJIdxCCPeZM2dUWC0REQE6nuyUUqZLKW1SSluLFi30Wi0RUbWnRpDnAWhb7vd4721ERKQDNYJ8G4D2Qoh2QojaAEYCWKHCcomISIGIr34opbwmhHgSwBoA0QA+kFLuibgyIiJSRJXL2EopVwFYpcayiIgoNJzZSURkcQxyIiKLY5ATEVkcg5yIyOIY5EREFscgJyKyOAY5EZHFMciJiCyOQU5EZHEMciIii2OQExFZHIOciMjiGORERBbHICcisjgGORGRxTHIiYgsjkFORGRxDHIiIotjkBMRWRyDnIjI4hjk
REQWxyAnIrI4BjkRkcUxyImILM5SQf7Od+9g8CeD8e8f/40SWWJ0OUREpmCpIJeQ2HR0Ex7IeAAd3umA1ze/jp+v/Gx0WUREhrJUkP/htj8g95lcLBmyBK0atsKz655Fm1ltMGb5GGw+uhlSSqNLJCLSnTAi/Gw2m3S73REvZ/fp3XC5XVi0cxEuFl7ELdfdAofNgTFdxqBhnYYqVEpEZB5CiEwppa3K7VYO8lKXCi/h410fw+l2YsfJHWhQuwFG3zIaDpsDKS1TVFsPEZGRqnWQl5JSYtvxbXC6nVi6eykKrhXg9vjbYbfZMTx5OOrF1FN9nUREeqkRQV7euavnsHDnQjjdThzMP4hm9ZphfMp42G12tI9rr+m6iYi0UOOCvJSUEl9nfw2n24nP93+OayXXcN8N98Fhc2DQTYNQK6qWLnUQEUWqxgZ5eScunsD7O95HemY6jl04htYNW2NSt0mYnDoZ8Y3ida+HiCgUDPJyrpVcw6pDq+Byu7D68GpEiSg8fNPDsKfacf9v7keUsNSoTCKqIfwFeUSJJYQYLoTYI4QoEUJUWbhZ1YqqhUE3DcKq0atw+A+H8ec7/ozNRzej/5L+nGhERJYTadNzN4AhAL5VoRZD3ND0Bvzfff+HY88cw0dDPkLrhq050YiILCWiIJdS7pNSHlCrGCPVqVUHo24ZhW9/9y12O3YjrXsavjz4Je5ccCdSXClwbnPiwq8XjC6TiKgKdgb70Om6TnhnwDvIm5aH9IHpqBVVC0+segJtZrWBfaUdO0/uNLpEIqIyQU92CiHWAWjp408zpJRfeO+zEcCfpJR+z2AKIdIApAFAQkJCak5OTrg16650opHL7cLHuz9GwbUC9IzvCYfNwYlGRKQbTUetKAny8owetRKJc1fPYdHORXC6nTiQf4ATjYhIN5qMWqmJmtZriqk9p2Lf7/dhw+MbcG+7e/H292+jw5wOuH/x/Vi+bzmKiouMLpOIapCIWuRCiMEA3gHQAsB5AFlSygeC/Z+VW+S+nLh4Ah/s+ADp29Nx9JejnGhERJrghCAdFJcUY9WhVXC6nVh9eDWEEHi4w8Nw2BycaEREEWOQ6+zIuSNIz0zH+zvex5krZ3BD0xswJXUKftf1d2hRv4XR5RGRBTHIDfLrtV/xz/3/hNPtxLc536J2dG0MSx4Gh82BXm17QQhhdIlEZBEMchPYc3oP5mbOxcKdC3Hh1wvofF3nsk80alSnkdHlEZHJMchN5HLhZXy82/OJRttPbEf9mPqeTzTq4UDXll2NLo+ITIpBbkJSSriPu8s+0ejqtavoGd8T9lQ7ftvpt5xoREQVMMhNrnSikSvThf0/70fTuk0xvqtnolGHuA5Gl0dEJsAgtwgpJTZmb4TT7cQ/9/8T10qu4d5298Jus+ORmx5BTHSM0SUSkUEY5BZ08tJJvL/9/bKJRq0atMKk7pMwuftktG3c1ujyiEhnDHILKy4pxleHv4LT7cRXh74qm2hkt9nR7zf9ONGIqIZgkFcTR84dwbzt8/D+jvdx+vJpTjQiqkEY5NVMYXEh/rnPM9Hom5xvyiYa2VPtuDPhTk40IqqGGOTV2N4ze+FyuypMNLKn2jE2ZSwnGhFVIwzyGqB0opHL7ULmicyyiUZ2mx3dWnUzujwiihCDvIZxH3fDuc2Jj3d/jKvXruK2NrfBYXNwohGRhTHIa6hzV89h8Q+L4XQ7K0w0mpI6BTc1v8no8ogoBAzyGk5KiW9yvoHL7cKyfctwreQa7ml3Dxw2BycaEVkEg5zKnLx0Eh/s+ABzM+dyohGRhTDIqYrSiUYutwurDq2CEAIDOwyEw+bgRCMiE2KQU0DZ57PLPtHo9OXTaNekHaakTsGEbhM40YjIJBjkpEjpRCNXpgsbszeidnRtDO04FA6bgxONiAzGIKeQ7Tuzr2yi0S+//oJOLTrBbrNjbJexaFy3sdHlEdU4DHIK2+XCy1i6eymcbmfZRKPHbnkMDpuDE42I
dMQgJ1X4mmhkt9kxotMITjQi0hiDnFR1vuA8Fu1cVDbRqEndJhif4vlEI040ItIGg5w0IaXEtznfwul2Yvm+5SgqKULfpL5w2Bx49OZHOdGISEUMctLcqUun8P6O95GemY6cX3LQskFLTOo2CZNTJyOhcYLR5RFZHoOcdFNcUozVh1fD6XaWTTR6qP1DZRONoqOijS6RyJIY5GSInPM5SM9Mx/wd83H68mkkNUkqm2h0Xf3rjC6PyFIY5GSoyhONYqJiPJ9oZLOjd0JvTjQiUoBBTqbBiUZE4WGQk+lcKbpSNtHIfdxdNtHIbrOje6vuRpdHZDoMcjI193E3XG4XPtr1Ea5eu4pb29wKe6odIzqPQGxMrNHlEZkCg5wsoXSikcvtwr6f95VNNJpim4Kbm99sdHlEhmKQk6VwohFRVQxysqxTl06VfaJR6USjid0mIi01jRONqEbRJMiFEK8DeBhAIYAfAfxOSnk+2P8xyCkcnGhENZ1WQd4PwAYp5TUhxKsAIKWcHuz/GOQUqZzzOZi3fR7mb5+PU5dPcaIR1Qiad60IIQYDGCalHB3svgxyUkthcSE+3/85XG4Xvs7+GjFRMRia7PlEI040oupGjyD/EsAnUsoMP39PA5AGAAkJCak5OTmqrJeo1P6f95dNNDpfcB7JLZJhT7Xj8ZTHOdGIqoWwg1wIsQ5ASx9/miGl/MJ7nxkAbACGSAXvDGyRk5auFF3BJ7s/gdPtxLbj2xAbE4vHOnsmGqW2TjW6PKKwadYiF0KMBzAFwL1SyitK/odBTnqpPNGoR+secNgcnGhElqTVyc7+AGYB6COlPKP0/xjkpLfzBeexeOdiON3OsolG41LGwW6zc6IRWYZWQX4YQB0A+d6btkop7cH+j0FORpFS4j9H/wOn24lle5eVTTSy2+x49OZHUTu6ttElEvnFCUFElZy6dAoLshZgbuZcZJ/PxvX1ry+baJTYJNHo8oiqYJAT+VFcUow1P66B0+3Evw7+C0IIDGg/APZUO/rf2J8Tjcg0GOREClSeaJTYOBFTUqdgYveJnGhEhmOQE4XA30Qje6oddyXexYlGZAgGOVGY9v+8H3Pdc/Hhzg8rTDQamzIWTeo2Mbo8qkEY5EQRKp1o5Mp04fu87xEbE4tRnUfBYXNwohHpgkFOpKLtJ7bD5XZhya4luFJ0BbbWNjhsDozsPJITjUgzDHIiDfxS8AsW/+CZaLT3zF40rtO4bKJRxxYdjS6PqhkGOZGGSicaudwufLb3MxSVFOHupLvLPtGIE41IDQxyIp2cvnwaC3YsgCvTxYlGpCoGOZHOSmQJ1hz2TjQ69C9IKTGg/QA4bA5ONKKwMMiJDHT0l6OYlzkP83fMx8lLJ5HYOBFpqWmY2G0irm9wvdHlkUUwyIlMoKi4CF8c+AJOtxMbjmxATFQMhnQcArvNjj6JfTjRiAJikBOZzIGfD2Bu5lx8mPUhzhWcQ8fmHWG3eT7RiBONyBcGOZFJXS26ik/2eD7R6Pu871GvVj3PRKMeDthaV3nNUg3GICeygMoTjVJbpZZNNKpfu77R5ZHBGOREFlI60cjldmHPmT1oXKcxHk95HA6bgxONajDTB3lRURFyc3NRUFCgez01Ud26dREfH4+YmBijS6EApJTYdHST5xON9i1DYXEh+iT2gcPmwOCOgznRqIYxfZAfOXIEDRs2RFxcHM/ca0xKifz8fFy8eBHt2rUzuhxSqHSi0dzMuThy/giuq39d2USjpCZJRpdHOvAX5FFGFONLQUEBQ1wnQgjExcXx6Mdirqt/HabfOR2H/3AYX43+Cj3je+LVza/ihrduwEMfPYSVB1eiuKTY6DLJALWMLqA8hrh+uK2tK0pEof+N/dH/xv449ssxzNs+D/O2z8PDHz+MhMYJSOuehondJ6Jlg5ZGl0o6MU2LnIhC17ZxW7zU9yUcffooPh3+KW5sdiOe//p5tH2zLUZ8NgIbszfCiO5T0heDPEzZ2dn4
6KOPyn7PysrCqlWryn5fsWIFXnnlFVXWNX78eHz22WcAgEmTJmHv3r2qLJeqj5joGAxLHob1j6/H/t/vx1O3PoW1P65F34V9kfxeMt7a+hbOF5w3ukzSCIM8TMGCfNCgQXjuuedUX+/8+fORnJys+nKp+rip+U2Y9cAs5E3Lw4ePfIjGdRrj6TVPo/XfW2PCFxOwLW+b0SWSykzVR17q6dVPI+tklqrL7NqyK2b3nx3wPosWLcIbb7wBIQS6dOmCxYsXY/z48Rg4cCCGDRsGAGjQoAEuXbqE5557Dvv27UPXrl0xatQovPvuu7h69So2bdqEv/zlL7h69SrcbjfmzJmD8ePHo1GjRnC73Th58iRee+01DBs2DCUlJXjyySexYcMGtG3bFjExMZgwYULZuny5++678cYbb8Bms6FBgwaYOnUqVq5ciXr16uGLL77A9ddfjzNnzsBut+Po0aMAgNmzZ6NXr16qbUuyhnox9TCu6ziM6zoOO07sKJtotCBrAVJbpcJus2NU51GcaFQNsEXutWfPHvztb3/Dhg0bsHPnTrz11lsB7//KK6+gd+/eyMrKwvTp0/HSSy9hxIgRyMrKwogRI6rc/8SJE9i0aRNWrlxZ1lJfvnw5srOzsXfvXixevBhbtmwJqebLly+jZ8+e2LlzJ+666y7MmzcPADB16lQ888wz2LZtG5YtW4ZJkyaFtFyqfrq16oa5D89F3rQ8zHlwDgquFWDyl5PRelZrPLXqKew9w+46KzNlizxYy1kLGzZswPDhw9G8eXMAQLNmzVRd/qOPPoqoqCgkJyfj1KlTAIBNmzZh+PDhiIqKQsuWLdG3b9+Qllm7dm0MHDgQAJCamoq1a9cCANatW1ehH/3ChQu4dOkSGjRooNKjIatqXLcxfn/r7/FEjyew+dhmON1OpG9Px5xtc3BX4l1w2BwY0nEIJxpZjCmD3Exq1aqFkpISAEBJSQkKCwvDWk6dOnXKflZrFEFMTEzZMMLo6Ghcu3YNgKfOrVu3om7duqqsh6ofIQTuTLgTdybcidkPzMaCrAVwuV0YtWwUJxpZELtWvO655x58+umnyM/PBwCcPXsWAJCUlITMzEwAnpEoRUVFAICGDRvi4sWLZf9f+XclevXqhWXLlqGkpASnTp3Cxo0bVXgkQL9+/fDOO++U/Z6VlaXKcql6alG/BZ7t9WzZRKPb42+vMNHoywNfcqKRyTHIvTp16oQZM2agT58+SElJwbRp0wAAkydPxjfffIOUlBRs2bIF9et7Tgx16dIF0dHRSElJwZtvvom+ffti79696Nq1Kz755BNF6xw6dCji4+ORnJyMMWPGoHv37mjcuHHEj+Xtt9+G2+1Gly5dkJycDJfLFfEyqfornWj0+cjPkT01Gy/c9QJ2nNiBQUsH4Ya3b8DMb2fi5KWTRpdJPpjmWiv79u1Dx44176pupX3X+fn5uPXWW7F582a0bKnPjLyaus1JuaLiIqw4sAJOtxPrj6xHrahaGHzzYDhsDtyddDdnCOvM37VW2EdusIEDB+L8+fMoLCzECy+8oFuIEykREx2DoclDMTR5KA7mH4TL7cKHWR/i072f4qa4m2C32TEuZRya1mtqdKk1GlvkNRi3OYXjatFV/GPPP+DKdGFr7lbUq1UPIzuPhMPm+UQjttK1Y/qrHxKRNZRONNoycQt2TNmBx1Mexz/2/AO3zr8Vtnk2zN8+H5cLLxtdZo3CICeisHVt2RWugS4c/+NxvDvgXRQWF3KikQEY5EQUsUZ1GuGJHk/gB/sP+M/v/oOBHQYifXs6Or3XCX0+7IOlu5eisDi8ORgUXERBLoR4WQjxgxAiSwjxbyFEa7UKIyLrKZ1otGTIEuQ+k4tX73sVuRdyMWrZKLR9sy3+su4vyD6fbXSZ1U6kLfLXpZRdpJRdAawE8GLkJSm0ZAmQlARERXm+L1mi26qJKLjSiUaHnjqE1aNX4/b42/Haf1/jRCMNRBTkUsoL5X6t
D0CfITBLlgBpaUBODiCl53tamm5hvnDhQrRv3x7t27fHwoULdVknkVVFiSg8cOMD+Hzk58h5OqfCRKN2b7XD3779GycaRSji4YdCiJkAHgfwC4C+Usozfu6XBiANABISElJzcnIq/D2koXBJSZ7wriwxEcjOVlx7OM6ePQubzQa32w0hBFJTU5GZmYmmTa03jpbDD8koRcVF+PLgl3C6nVj30zpONFIo7OGHQoh1QojdPr4eAQAp5QwpZVsASwA86W85Usp0KaVNSmlr0aJFJI8F8F5nW/HtCmzbtg1dunRBQUEBLl++jE6dOmH37t1V7rdmzRrcf//9aNasGZo2bYr7778fq1evDnu9RDVRTHQMhnQcgrVj1+Lgkwcx9bapWH9kPe5ZdA86vtsRs7fOxrmr54wu0zKCBrmU8j4pZWcfX19UuusSAEO1KbOShITQblegR48eGDRoEJ5//nk8++yzGDNmDDp37lzlfnl5eWjbtm3Z7/Hx8cjLywt7vUQ1Xfu49nij3xvIfSYXCx9diKb1muKZNc+gzaw2mPDFBHyf9z0/dzSISEettC/36yMA9kdWjkIzZwKxsRVvi4313B6BF198EWvXroXb7cazzz4b0bKIKDT1Yurh8ZTHq0w0um3+bbDNs2Fe5jxONPIj0lErr3i7WX4A0A/AVBVqCm70aCA93dMnLoTne3q65/YI5Ofn49KlS7h48SIKCgp83qdNmzY4duxY2e+5ublo06ZNROsloorKTzR6b8B7KCouQtrKtLKJRntO7zG6RFPhtVbKGTRoEEaOHIkjR47gxIkTmDNnTpX7nD17Fqmpqdi+fTsAoHv37sjMzFT9E4X0YIZtTqSElBL/PfZfON1OfLr3UxQWF6J3Qu+yTzSqU6tO8IVUA7z6YRCLFi1CTEwMHnvsMRQXF+OOO+7Ahg0bcM8991S4X7NmzfDCCy+gR48eADzdMVYMcSIrEUKgV0Iv9Erohdn9Z2PBjgVwZbrw2PLH0CK2BSZ0m4ApqVPQrmk7o0s1BFvkNRi3OVlZiSzBup/Wwel2YsWBFZBSov+N/eGwOTCg/QBER0UbXaLq2CInomolSkSh32/6od9v+iH3Qi7mb5+PedvnYdDSQWjbqC3SUtMwsdtEtGrYyuhSNceLZvmxa9cudO3atcLXbbfdZnRZRORDfKN4/PXuvyJ7ajaW/3Y5bm5+M174+gUkzE7A8E+HY/1P66v1EEa2yP245ZZb+KHFRBYTEx2DwR0HY3DHwTiUfwhzM+diQdYCfLb3M3SI6wB7qh3juo5Ds3rV67wWW+REVC1VnmgUVy8O0/49DW1mtcH4z8fju9zvqk0rnUFORNVa6USj/078L7KmZGFcyjgs27cMPd/vidT0VMzLnIdLhZeMLjMiDHIiqjFSWqbANdCFvGl5eG/Ae7hWcg1pK9PQZlYbPLnqSew+XfX6SlZg2SDn5ciJKFyN6jSCo4cDO+07sXnCZgy6aRDmbZ+HW5y3oPeC3vho10f49dqvRpepmCWD3ODLkaN///5o0qQJBg4cqM8KiUgTQgjc0fYOLB68GHnT8vD6/a/jxMUTGL18NOLfjMf0tdPx07mfjC4zKEsG+YwZwJUrFW+7csVzux7+/Oc/Y/HixfqsjIh00Ty2Of50x59w8KmDWDNmDXon9Mbft/wdN759Ix5c8iBWHFhh2k80smSQa3A5csXXIweAe++9Fw0bNgx/ZURkWqUTjZaPWI7sp7PxYp8X8cOpH/DI0kfQ7q12ePmbl3Hi4gmjy6zAkkGuweXIFV+PnIhqjtKJRjlP52D5b5ejY4uOeHHji6abaGTJINfocuS8HjkR+VQrqhYGdxyMNWPW4NBTh/D0bU9jw5ENuG/xfbj53Zsxa8ssnL161rD6LBnkGl2OXNH1yImoZrux2Y14vd/ryJuWh0WPLkLz2Ob447//aOhEI0sGOeAJ7exsoKTE8z3SEAeAKVOm4OWXX8bo0aMxffr0yBdIRNVW
3Vp1MTZlLDZP2IysKVkYnzK+wkSj9Mx03SYaWTbI1Vb+euTPPfcctm3bhg0bNvi8b+/evTF8+HCsX78e8fHxWLNmjc7VEpGZpLRMgXOgE8enHYfzISeKZTGmrJyC1n9vjbnuuZqvn9cjr8G4zYm0IaXEltwtcLldGNV5FB5s/6Aqy+X1yImIdFI60eiOtnfosj4GuR+7du3C2LFjK9xWp04dfPfddwZVRETkm6mCXEoJIYTRZQCo/tcjN8PYVyJSh2lOdtatWxf5+fkMGB1IKZGfn4+6desaXQoRqcA0LfL4+Hjk5ubizJkzRpdSI9StWxfx8fFGl0FEKjBNkMfExKBdu3ZGl0FEZDmm6VohIqLwMMiJiCyOQU5EZHGGzOwUQpwBkBPmvzcH8LOK5aiFdYWGdYWGdYXGrHUBkdWWKKVsUflGQ4I8EkIIt68pqkZjXaFhXaFhXaExa12ANrWxa4WIyOIY5EREFmfFIE83ugA/WFdoWFdoWFdozFoXoEFtlusjJyKiiqzYIicionIY5EREFmfaIBdC9BdCHBBCHBZCPOfj79OEEHuFED8IIdYLIRJNUpddCLFLCJElhNgkhEg2Q13l7jdUCCGFELoMzVKwvcYLIc54t1eWEGKSGery3ue33n1sjxDiIzPUJYR4s9y2OiiEOG+SuhKEEF8LIXZ4X5MDTFJXojcffhBCbBRC6HKlOCHEB0KI00KI3X7+LoQQb3vr/kEI0T2iFUopTfcFIBrAjwBuAFAbwE4AyZXu0xdArPdnB4BPTFJXo3I/DwKw2gx1ee/XEMC3ALYCsJmhLgDjAcwx4f7VHsAOAE29v19nhroq3f8pAB+YoS54TuA5vD8nA8g2SV2fAhjn/fkeAIt12sfuAtAdwG4/fx8A4CsAAkBPAN9Fsj6ztshvBXBYSvmTlLIQwFIAj5S/g5TyaynlFe+vWwHo8U6rpK4L5X6tD0CPs8lB6/J6GcCrAAp0qCmUuvSmpK7JAN6VUp4DACnlaZPUVd4oAB+bpC4JoJH358YAjpukrmQApZ+i/rWPv2tCSvktgLMB7vIIgEXSYyuAJkKIVuGuz6xB3gbAsXK/53pv82ciPO9uWlNUlxDi90KIHwG8BuAPZqjLe+jWVkr5Lx3qUVyX11Dv4eVnQoi2JqmrA4AOQojNQoitQoj+JqkLgKfLAEA7/C+kjK7rrwDGCCFyAayC52jBDHXtBDDE+/NgAA2FEHE61BZMqBkXkFmDXDEhxBgANgCvG11LKSnlu1LK3wCYDuB5o+sRQkQBmAXgj0bX4sOXAJKklF0ArAWw0OB6StWCp3vlbnhavvOEEE2MLKiSkQA+k1IWG12I1ygAH0op4+HpNljs3e+M9icAfYQQOwD0AZAHwCzbTDVm2NC+5AEo3zKL995WgRDiPgAzAAySUv5qlrrKWQrgUS0L8gpWV0MAnQFsFEJkw9Mnt0KHE55Bt5eUMr/cczcfQKrGNSmqC54W0gopZZGU8giAg/AEu9F1lRoJfbpVAGV1TQTwDwCQUm4BUBeei0MZWpeU8riUcoiUshs8WQEp5XmN61Ii1CwJTI+O/zBOFNQC8BM8h46lJzE6VbpPN3hOdLQ3WV3ty/38MAC3GeqqdP+N0Odkp5Lt1arcz4MBbDVJXf0BLPT+3Byew+A4o+vy3u9mANnwTugzyfb6CsB4788d4ekj17Q+hXU1BxDl/XkmgJf02Gbe9SXB/8nOh1DxZOf3Ea1LrwcVxkYYAE8r6EcAM7y3vQRP6xsA1gE4BSDL+7XCJHW9BWCPt6avAwWqnnVVuq8uQa5we/2fd3vt9G6vm01Sl4CnO2ovgF0ARpqhLu/vfwXwih71hLC9kgFs9j6PWQD6maSuYQAOee8zH0Adner6GMAJAEXwHN1NBGAHYC+3f73rrXtXpK9HTtEnIrI4s/aRExGRQgxyIiKLY5ATEVkcg5yIyOIY5EREFscgJyKyOAY5EZHF/T/+gyEj9TYh0wAAAABJ
RU5ErkJggg==\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the decision boundary before the parameters are updated\n",
- "w0 = w[0].data[0]\n",
- "w1 = w[1].data[0]\n",
- "b0 = b.data[0]\n",
- "\n",
- "plot_x = np.arange(0.2, 1, 0.01)\n",
- "plot_y = (-w0.numpy() * plot_x - b0.numpy()) / w1.numpy()\n",
- "\n",
- "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n",
- "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n",
- "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n",
- "plt.legend(loc='best')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
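- "As the plot shows, the classification is essentially random at this point. Let's compute the loss, the binary cross-entropy averaged over the $N$ samples:\n",
- "\n",
- "$$\n",
- "loss = -\\frac{1}{N} \\sum_{i=1}^{N} \\left[ y_i \\log(\\hat{y}_i) + (1 - y_i) \\log(1 - \\hat{y}_i) \\right]\n",
- "$$"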
- ]
- },
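As a sanity check on the loss formula above, here is a minimal sketch comparing a hand-written mean binary cross-entropy with PyTorch's built-in `torch.nn.functional.binary_cross_entropy`, which also averages over samples by default. The tensors `y_hat` and `y` are made-up illustrative values, not data from this notebook.

```python
import torch
import torch.nn.functional as F

# Hypothetical predicted probabilities and 0/1 labels for four samples
y_hat = torch.tensor([0.9, 0.2, 0.6, 0.4])
y = torch.tensor([1.0, 0.0, 1.0, 0.0])

# Mean binary cross-entropy written directly from the formula
manual = -(y * torch.log(y_hat) + (1 - y) * torch.log(1 - y_hat)).mean()

# PyTorch's built-in version (reduction='mean' by default)
builtin = F.binary_cross_entropy(y_hat, y)

print(manual.item(), builtin.item())
```

Both values should agree to floating-point precision.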
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Compute the loss; clamp keeps the argument of log at least 1e-12, so a prediction of exactly 0 or 1 cannot produce -inf\n",
- "def binary_loss(y_pred, y):\n",
- " log_likelihood = (y * y_pred.clamp(1e-12).log() + \\\n",
- " (1 - y) * (1 - y_pred).clamp(1e-12).log()).mean()\n",
- " return -log_likelihood"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note the use of `.clamp` here; see its [documentation](http://pytorch.org/docs/0.3.0/torch.html?highlight=clamp#torch.clamp). Think about whether this function is really necessary here, and what would happen without it."
- ]
- },
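To explore the question above, here is a small sketch with made-up, fully saturated predictions. When a prediction is exactly 0 or 1, `log` returns `-inf`, and multiplying that by a zero label gives `nan`. Clamping keeps every term finite.

```python
import torch

y_pred = torch.tensor([1.0, 0.0])  # hypothetical saturated sigmoid outputs
y = torch.tensor([1.0, 0.0])       # labels that match them perfectly

# Without clamp: log(0) = -inf and 0 * -inf = nan, so the loss becomes nan
raw = -(y * y_pred.log() + (1 - y) * (1 - y_pred).log()).mean()

# With clamp: arguments of log stay >= 1e-12, so every term is finite
safe = -(y * y_pred.clamp(1e-12).log()
         + (1 - y) * (1 - y_pred).clamp(1e-12).log()).mean()

print(raw.item(), safe.item())
```

A single `nan` in the loss propagates into every gradient through `backward()`. That is why the clamp matters even though such extreme predictions are rare.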
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(0.7655, grad_fn=<NegBackward>)\n"
- ]
- }
- ],
- "source": [
- "y_pred = logistic_regression(x_data)\n",
- "loss = binary_loss(y_pred, y_data)\n",
- "loss.backward()\n",
- "print(loss)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
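- "Once we have the loss, we again use gradient descent to update the parameters. Autograd gives us the gradients of the parameters directly; interested readers can derive the gradient formulas by hand."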
- ]
- },
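For readers who want to try the hand derivation, let $\hat{y} = \sigma(Xw + b)$ with the mean loss defined earlier. The chain rule then gives $\partial loss / \partial w = X^T(\hat{y} - y)/N$ and $\partial loss / \partial b = \frac{1}{N}\sum_i (\hat{y}_i - y_i)$. The sketch below checks these formulas against autograd. It uses random made-up data, and its local names `x`, `y`, `w`, `b` are assumptions of this example, not the notebook's variables.

```python
import torch

torch.manual_seed(0)
N = 8
x = torch.randn(N, 2)                   # hypothetical features
y = (torch.rand(N, 1) > 0.5).float()    # hypothetical 0/1 labels
w = torch.randn(2, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

# Forward pass and mean binary cross-entropy, as in binary_loss
y_hat = torch.sigmoid(x.mm(w) + b)
loss = -(y * y_hat.log() + (1 - y) * (1 - y_hat).log()).mean()
loss.backward()

# Analytic gradients: d(loss)/dz = (y_hat - y) / N for z = xw + b
err = (y_hat - y).detach()
grad_w = x.t().mm(err) / N
grad_b = err.mean().reshape(1)

print(torch.allclose(w.grad, grad_w, atol=1e-6),
      torch.allclose(b.grad, grad_b, atol=1e-6))
```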
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- "During Time: 0.306 s\n"
- ]
- }
- ],
- "source": [
- "start = time.time()\n",
- "\n",
- "# Use autograd to compute the gradients and update the parameters\n",
- "for i in range(1000):\n",
- " # Forward pass: compute the loss with the current parameters\n",
- " y_pred = logistic_regression(x_data)\n",
- " loss = binary_loss(y_pred, y_data)\n",
- " \n",
- " # Backpropagate, then take a gradient-descent step (learning rate 0.1)\n",
- " loss.backward()\n",
- " w.data = w.data - 0.1 * w.grad.data\n",
- " b.data = b.data - 0.1 * b.grad.data\n",
- "\n",
- " # Zero the gradients; otherwise backward() accumulates them across iterations\n",
- " w.grad.data.zero_()\n",
- " b.grad.data.zero_()\n",
- " \n",
- "during = time.time() - start\n",
- "print()\n",
- "print('During Time: {:.3f} s'.format(during))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f36cedaf310>"
- ]
- },
- "execution_count": 26,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAzAElEQVR4nO3dd3gU1frA8e9JSOi9WAgB9GJJ6IRyRRRRuIgYLCBwEUWQav0pKgp6I4gCegWUDnKlioANkaLSkRoklAQrRboQlF6S7Pn9sRsNIZvsJtN29/08zz7J7k7mvDuZfefMOWfOKK01QgghAl+Y3QEIIYQwhiR0IYQIEpLQhRAiSEhCF0KIICEJXQghgkQhuwquUKGCrlatml3FCyFEQNqyZctxrXXFnN6zLaFXq1aNxMREu4oXQoiApJTa5+09aXIRQoggIQldCCGChCR0IYQIEpLQhRAiSEhCF0KIICEJXQghgoQkdCGECBIBl9B/OP4Dg5YP4mL6RbtDEUIIRwm4hP7lj18ydM1Q6k2sx8YDG+0ORwghHCPgEvoLTV9gcZfFnLl0hlum3kL/r/tzLu2c3WEJIYTtAi6hA7T+R2t29ttJz/o9+e/6/1J3Ql3W7Ftjd1hCCGGrgEzoAKUKl2JC2wkse2QZ6a50bv/wdp5e/DRnLp2xOzQhhLBFwCb0TC2qt2B73+081egpxmwaQ63xtVi2e5ndYQkhhOUCPqEDlIgswei7R7P6sdVEhkdy14y76P1lb05eOGl3aEIIYZmgSOiZbo2+laTeSbx4y4tM2TqFmuNrsvjnxXaHJYQQlgiqhA5QNKIow1sOZ32P9ZQuXJo2s9vw6OePcuL8CbtDE0IIUwVdQs/UqHIjtvTawqBmg5i1fRax42L5/IfP7Q5LCCFME7QJHaBwocIMaTGEzT03c3WJq7n/4/vpNL8Tx84eszs0IYQwXFAn9Ez1rqnHpsc3MeSOIXy661NixsXw8c6P0VrbHZoQQhgmJBI6QER4BINuG8T3vb+nepnqdPqkEw/OfZAjZ47YHZoQQhgiz4SulJqqlPpdKbXTy/tKKfWeUuoXpdR2pVR948M0Ts1KNVnXYx3D7xrOop8XETM2hhnbZkhtXQgR8HypoX8ItM7l/buBGp5HL2B8wcMyV6GwQrzY9EW29dlGTMUYHvn8Edp+1JYDpw5YH8ysWVCtGoSFuX/OmmV9DEKIoJBnQtdarwZyG/PXDpiu3TYAZZRS1xgVoJlurHAjq7qtYnTr0azcu5LYcbFM3jLZutr6rFnQqxfs2wdau3/26iVJXQiRL0a0oVcG9md5fsDzWkAIDwvn6cZPs73PdupfU59eC3vRckZL9v651/zCBw6Ec9lmijx3zv26EMgJnPCPpZ2iSqleSqlEpVTisWPOGjp4fbnrWfbIMsbfM56NBzdSc1xNxm4ai0u7zCv0t9/8ez2fJCkEJjmBE/4yIqEfBKpkeR7lee0KWutJWus4rXVcxYoVDSjaWGEqjD5xfUjul0zT6KY8ufhJmn/YnJ9TfzanwOho/17PB0kKxrLy4CgncMJfRiT0BcAjntEuTYCTWuvDBqzXNtGlo1nSZQkfxH/A9qPbqTOhDu+uf5cMV4axBQ0dCsWKXf5asWLu1w0iSaFgsibwChWge3frDo7+nsCZfbCRM70AoLXO9QF8BBwG0nC3j/cA+gB9PO8rYCzwK7ADiMtrnVprGjRooAPBwVMH9b2z79UkoBtPbqxTfk8xtoCZM7WuWlVrpdw/Z840dPVKae1OP5c/lDK0mKA0c6bWxYrlvP2yPqpWNaf8qlV9Ly+nWIsVM253Mnv9wndAovaWr729YfYjUBK61lq7XC49e/tsXX54eR05JFK/ufpNnZaR5v+KTE7eOfEnKYjLedt2Vh0c/UmiZv+fZT9yDknoBjly+ojuMLeDJgFdf2J9ve3INt//2KYqjtSs8s/b2Y2VSc3XOoDZZ2JypuccktAN
Nj95vq70diVdaHAh/Z8V/9EX0y/m/Uc2VnFsODEICr7U0J1ycJQaeuiQhG6C42eP6y6fdNEkoGuNq6UTDybm/gdSxQk4OZ3dRERoXb688w6O0oYeOiShm2jBDwv0tf+9Voe/Hq5f/vZlfT7tfM4LShUnIAXS2Y3ZsQbStghmuSV05X7fenFxcToxMdGWso3254U/eX7p80xNmspNFW5iavxU/lnln5cvlDkgPOsYwmLFYNIk6NLF2oCFEAFLKbVFax2X03shM32umcoUKcMH7T5g6cNLOZd2jqZTm/Lc0uc4l5YleXfp4k7eVauCUu6fksxFPsh4cOGNJHQDtbq+FTv77qRPXB9GbhhJ7fG1WbV31d8LdOkCe/eCy+X+Kclc+MmJV/7KAcY5JKEbrGThkoy7ZxwrHl2BRtN8WnOe+OoJTl88bXdoIgg47cpfJx5gQpm0oZvo7KWzDFo+iNEbRxNdOprJ906m5fUt7Q5LBLCwMHfizE4p94mf1apVcyfx7KpWdZ+ECuNJG7pNikcWZ2TrkaztvpYihYrQamYrei7oyckLJ+0OTQQoC+Zz84tFE4YKH0lCt8AtVW5ha++tvNT0JaYmTSV2XCxf/fSV3WGJAGTkfG5GtH2XK+ff68JcktAtUjSiKMPuGsaGHhsoW7QsbT9qS9fPunLifG43gxLickYNlpK27+Akbeg2uJh+kTfXvMmba9+kfNHyjLtnHA/c/IDdYYkQMGuWuwM1p3Zv8L/t22lt+qFA2tAdpnChwrx+x+ts7rmZa0tey4NzH6Tj/I78fvZ3u0MLCDJMLn+y1sq98bft22lt+qFOErqN6l5dl42Pb2Roi6F8/sPnxI6LZc7OOdh11lRQViRaaSrIv5yGPGbnbyK24B4thgr6yoC3OQHMfgTLXC5G2Xl0p240uZEmAd3uo3b64KmDdofkF6smb5IpcfIvr+mA8/v/CpQ5XoJlgjFkLpfAkOHKYNSGUQxaMYgihYow8l8jebTOoyil7A4tT1aNR5Y22/zz9j8C9/9p6NDgu3g5s8/gt9/c+05GDneRDLQx89KGHiDCw8J5/pbn2dZnG7Uq1eKxLx6jzew2/HbSvEG9Rp2CWjUeWdps889b80jfvu7fu3YNrmaI7M1zOSVzCLIx896q7mY/pMkldxmuDP3+xvd18aHFdck3S+oJmydol8tlaBlGnoJa1RQSLKfNdsnePNK3b/BuT19vIWjkPmpF8xMyH3rg2n1it24xrYUmAd1iWgv964lfDVu3kUnYykQr834bJ5j7JHy5hWAg3gREEnqAc7lcelLiJF3yzZK62NBievSG0TrDlVHg9Rp9E6VgSIShdAYwc6b3RBcMN9LydrAKDzdnH7Xq4CgJPUj89udv+u6Zd2sS0E0/aKp/PP5jgdYXzLWz/LJ7m1h1UMzpwBVs+4DVB2er7jIpCT2IuFwu/eHWD3WZYWV0kTeK6BFrR+j0jPR8rSuUaqO+svPWr1bcFzTzYBEe7j2ZB9M+YOVZo9TQRb4dOnVIt/uonSYB3WhyI73z6M58rScYmkmMZGcN3cyy86qRZ32E+j6QX9KGLgrE5XLpOTvm6AojKujIIZH6jVVv6Evpl+wOK6DZedZi5tmBHSM+QpHdo1xkHHoAU0rRsWZHUvqlcP9N9zNoxSAaT2lM0pEku0MLWHbe+tXMMfa+jLV28iX7gcLuu0xKQg8CFYtXZE77OXz60KccOn2IhpMb8tqK17iYftHu0AKSXV9KM+dF8XZQCA+Xe5YHE0noQeT+m+8n5YkU/l3r3wxZPYQGkxqw+eBmu8MSPjLz7MDbwWLaNLlneTCRhB5kyhUtx7T7pvHVv7/izwt/0uSDJrz0zUucTztvd2jCB2adHdjZlCSsI5NzBbGTF07S/+v+TNk6hRvK38DU+Kk0jW5qd1hCiAKQyblCVOkipZkcP5lvun7DxfSLNPtfM55d8ixnL521OzQhhAkkoYeAu667i539dtKvYT9GbxxN
7Qm1WbFnhd1hBaWgv4GCcDRJ6CGiRGQJxrQZw8pHV6JQtJjegr4L+3L64mm7QwsacjclYTdJ6CHm9mq3s73vdp5r8hwTt0yk5viaLP1lqd1hBYWcbvF27pz7dSGsIAk9BBWLKMZ///Vf1vVYR7GIYrSe1ZruX3Tnj/N/2B1aQLPqJh9CeONTQldKtVZK/aiU+kUpNSCH96OVUiuUUluVUtuVUm2MD1UYrUlUE7aVeokT40ox5b7/cebaCmx553m7wwpYcjclYbc8E7pSKhwYC9wNxACdlVIx2RYbBMzVWtcDOgHjjA5UmGDWLCL7PEHZ308RBlT508VNr7zL2Gdv4fi543ZHF3DMvNIzN9IRKzL5UkNvBPyitd6ttb4EzAHaZVtGA6U8v5cGDhkXYgix+puZQ6Nv8TRo++F6YsfFMj9lvrnlB6GiRf/+vXx58y/ekY5YkZUvCb0ysD/L8wOe17JKAB5WSh0AFgFP5bQipVQvpVSiUirx2LFj+Qg3iNnxzfTSuBt9SlGlVBU6zOtA+7ntOXrmqHkxBInMf19q6t+vnbfg4lzpiBVZGdUp2hn4UGsdBbQBZiilrli31nqS1jpOax1XsWJFg4oOEnZ8M7007qroaDY8voG37nyLhT8tJGZcDDO3z8Suq4oDgV2JVTpizRGozVi+JPSDQJUsz6M8r2XVA5gLoLVeDxQBKhgRYMiw45uZS6NvobBCDLh1AFt7b+WG8jfQ9bOuxM+J5+Cp7P96AfYlVumINV4gN2P5ktA3AzWUUtWVUpG4Oz0XZFvmN+BOAKXUzbgTurSp+MOOb6YPMzbdXPFm1j62lndbvcuy3cuIHRfL1K1TpbaejV2J1a6OWH8EWm03oJuxvN35IusDdzPKT8CvwEDPa4OBeM/vMcB3wDYgCWiV1zrljkXZGH2rHBNunfJz6s/6tv/dpklAt5rRSu/9Y2+B12nnLfCMLNvOOx05+TaCgXjfWjvvK+sL5BZ0Dpf5jYS/795bkG+mid+iDFeGHrtprC4+tLgu8WYJPX7zeJ3hynBamLaU7eTEahR/P6Od92jNL6fHLAndyczILBbskXv+2KNbTm+pSUDf8eEd+tcTvzoxTEeWHajys6s6vbabE6efVUhCdzIzMotF3yKXy6WnbJmiS71VShcbWkyPWj9Kp2ekOy1Mx5UdqPKzqwbqgdPJZ1u5JXSZy8VuZgyPsKiHTilFj/o9SO6XTPNqzXl26bPc9uFt/HD8B5/+3s4RGjI6xH/edsl9+7x3eAZCp21O7L7Zc35JQrebGZnF4m9RVKkoFnZeyPT7prPr2C7qTqjL8LXDSXelOylMx5QdqHLbJbWX4X1y6zuLeau6m/2QJhePvBrs8nvuZ9M54+HTh/UDHz+gSUDHTYrTO47ucGKYeZbt5FNuu+S0qwZic0qgQ9rQHc5b9nB674wXLpdLz905V1ccUVFHDI7Qg1cO1pfSL9kdls9y2uyZbe6hntyz7qreErr0Q5hLEnqgCtQeJY/fz/yuO8/vrElA1xlfR39/6Hu7Q/KJt80eQMdUSwT47mkLI878ckvo0obuZAE+UUfF4hWZ/eBsPu/4OUfPHqXh5IYMWj6Ii+kX7Q4tV3lt3oC5atBk0g/hHyumFJCE7mRGdZjafO11u5vakdIvhYdrP8zQNUOpP6k+mw5usjQGf/iyeQPkmGoq6fD0jyVTCnirupv9kCYXHxjRhu6wdvhFPy3SUe9G6bDXw3T/pf31uUvnbIkjN750/kmzgvCXUdc+IE0uAcqIKpDDZhq6u8bd7Oy7k8frPc4769+hzoQ6rP1trS2xeJN1s4N702cV7M0KgTaZVqCw5NoHb5ne7IfU0C3i4Esiv/31W119VHWtEpR+atFT+vTF03aHlKPsHVl9+wbekEZfO+McdkLns0AYZmrUtkVGuYQwhw9FOH3xtH560dNaJShdfVR1vWz3Mq21c7+ggZjw/InZ4btLjgLpf2L2KBdJ6MEuQPb2NfvW
6Brv1XBP9vXcZF20mMuRIQdiwvMn5tzGlzvpwJqVU/4nVl2oJgk91Dm1upvNuUvndP+l/TWl9zriC5oTB7dgeeVPzIE4Bt8J/5Pc6k1G16lyS+jK/b714uLidGJioi1lC2cLC9Nora54XSn3ZEl2qlbNPX44u6pV3ZM4OZE/MWeOlc7ej57X39nJCf+T3GIAY+NTSm3RWsfl9J6MchGOEx19ZTJ3v25xIDkIxItp/Ik5+wifnDhtDL4T/ie5XQNo5fWBktCF4+T0BSXiLNe0e59jZ+29VW0gXkzjb8yZU8d6S+pOOLBm5YT/SW5DEi2dqtlbW4zZD2lDDyA2tMFnLTI62qUfHPiJjhgcoSuMqKDn7JijXS6X6TFYwcndGwHSn+4ITmlDl4Qucuegb/WOozt03KQ4TQL6/jn368OnD1seg5EctGm9cvIBx2lklItwPqeMCfNIy0jTw9YM04WHFNZlh5XV05KmBWxt3WGbNmgYlTydejDLLaHLKBeRu7Awd57JzuYhJz8e/5HuC7qzbv862tRow8S2E4kqFWVbPPnh0E0b0HIapVOsmP9t6katxwwyyiVUmDEJh0NvvnljhRtZ3W01o/41ihV7VlCj538of81pwsJ0wMw/4tBNG9CMmrrIYVMg+c5b1d3shzS5GMysBtkAaOh9d+IRHRZ53skh5igANm3AMeoiIydcrOQN0oYeAsxskHVqY6KHt48eHe38tnWHb9qAY9TXwMn9G7kldGlDDxYh3CDr7aODi59Td/OPcv+wOiRhE2lDF8Eh2Bpk/egP8PYRVZkD1B5fm5HrR5LhyjAlTOEsRl1k5ISLlfLFW9Xd7Ic0uRgsmBpk/fws3hZ/f8oJfe/sezUJ6CZTmuiU31Ms/iBCGA+5Y1EICNgqRQ78HGLg7aM/2aMsX3T6glkPzOKn1J+oO7Eub615i3RXugUfQgjrSRu6cB4T+gOOnjnKk4ufZH7KfBpc04Cp7aZS+6raBQxUCOtJG7oILCb0B1xV4irmdZjHvA7z2H9qPw0mNSBhZQKXMi7le50FIfftDC1W/b8loQvnyWm6RaWgTZsCr7p9THuS+yXzUOxDvL7qdeImxbHl0JYCr9cfmSMo9u1zn4js2+d+Lkk9OFn5/5aE7iRSbXPr0gUefdSdxDNpDdOmGbJNKhSrwKwHZvFFpy84fu44jac05pVlr3Ah/UKB1+2LgL0KUeSLlf9vSehOESrVNl8PWosWXdmObvC3IP7GeFKeSOGROo/w1tq3qDexHuv3rzds/d5YecMDYT/H3eBCKdVaKfWjUuoXpdQAL8s8pJRKUUolK6VmGxtmCAiFaps/By0DvwW5HUPKFCnD1HZTWdJlCWcvnaXp1KY8v/R5zqXlcg+2Agq2SwZE7hx1gwsgHPgVuA6IBLYBMdmWqQFsBcp6nlfKa70yDj0bJ08eYRR/rqc26Nprf4a0n7xwUvf5so8mAX396Ov1yj0r8/EhjY1JBD5H3eAC+CewNMvzl4GXsy0zAng8r3VlfUhCz8bJk0cYxZ+DlkHfgvxs1uW7l+vqo6prEtD9FvbTpy6c8qtMX8gcLqHFMTe4ANoDU7I87wqMybbM556k/h2wAWjtZV29gEQgMTo6Ov+fKBiFQrXN3+xqwLcgvyc+Zy6e0c8sfkarBKWrjqyqv/7la7/LFsIMuSV0ozpFC3maXZoDnYHJSqky2RfSWk/SWsdpreMqVqxoUNFBIpiu9PTG39uzZ96t2OVy/8zHtshv+2XxyOKMaj2KNY+toXChwrSa2YqeC3py8sJJv2MQwiq+JPSDQJUsz6M8r2V1AFigtU7TWu8BfsKd4IU/DEhgjmbDQcvfY0h2TaObktQ7iRdveZGpSVOJHRfLVz99ZXyg4i8yercAvFXd9d/NJIWA3UB1/u4Ujc22TGtgmuf3CsB+oHxu65U2dGEVo9ovNx3YpGPHxmoS0F0/7apTz6UaGabQodHyWFAUpMlFa50OPAksBXYBc7XW
yUqpwUqpeM9iS4FUpVQKsAJ4QWudathRRzhDgFadjDrxaVi5IVt6beG1217jo50fETM2hs92fWZkqCElp90pFEbvmkkm5xK+cfKM/zZIOpJE9y+6s/XIVh6KfYgxd4+hYnHpF/KVt90pezLPFAL3afFZbpNzSUIXvqlWzX0hUHZVq7qrvSEoLSONEd+NYPDqwZQqXIr3736fjrEdUVmnLBA58rY7hYdDRg73Ignh3ewKMtuiKDi5Xv0KEeERDLxtIN/3+p7ryl5H508688DcBzh8+rDdoTmet90mI6NgndihThK68I1cr+5VbKVYvuv+HW+3fJslvywhZlwM05KmYdfZbyDwtttkDnwK5tG7ZpKELnxT0PF/Qa5QWCH639KfbX22UbNSTbp90Y17Zt/D/pP77Q7NkXLbnYJ99K6ZJKEL34TChU8GuKH8Dazqtor3Wr/Hqn2riB0Xy6Qtk6S2no3sTuaQTlEhTLL7j930/LIny/csp0X1Fky5dwrVy1b36W8zh/D99pu7eSKz5iqEdIoKYYPryl7Ht12/ZWLbiWw+uJma42vy/sb3cencx9+FytT4wniS0IU1zLgoyY4LnfwsUylFrwa9SO6XzG1Vb+PpJU/T/MPm/Jz6s9e/kYtrRL55u4TU7Idc+h9CzLie245rxAtYpsvl0h9u/VCXGVZGF3mjiH7nu3d0ekb6FcuFwtT4Iv/I5dJ/aUMX5jPjoiQ7LnQyqMxDpw/R96u+LPhxAY0rN2Zqu6nEVIwxuhgRpKQNXdjLjIuS7LjQyaAyry15LZ93/JzZD8zmlxO/UG9iPd5c8yZpGWmAjBAV+ScJXZjPjIuS7LjQqVw5w8pUStG5VmdSnkjhvpvuY+DygTSe0phtR7bJkD6Rb5LQjRKgMxFawluVs02b/G8zq6uxs2bBqVNXvh4ZWaAyKxWvxMftP+aThz7h0OlDxE2O47UVr9Gh0yW5uEb4z1vjutmPoOoUlUmc85Z9UvK+fQu+zay8Mae32+eVL29YEcfPHtddP+2qSUDXHFdTbzqwybB1i+CBdIqaTHqx/Odtm5UvD8ePWx5OnsLC3Ck8OxPmdf3qp6/ovbA3h88cpv8/+5PQPIGiEUUNLUMELukUNZvMROg/b9smNdWZzVUWttnfc8M9JPdLpnvd7oxYN4J6E+uxbv86w8uxi7ROmkcSuhFkJkL/5bZtnHgFjcVt9qWLlGZy/GS+fvhrLqRf4Napt/Lskmc5e+msKeVZRa6CNZckdCPIODP/5bZtnHhmY9PQk5bXt2RH3x30a9iP0RtHU3tCbVbuXWlqmWbWoOUqWJN5a1w3+xFUnaJaG99BZ2WHn13Kl8+5ozE8PLg/dz6t3LNSXz/6ek0Cuu/CvvrUhVOGl2F2/75cBVtw5NIpKgndiUJl1ExOnzP7Ixg/d058PICfvXRWP7fkOa0SlI4eGa2X/rLU0KK8DeapWtX/j5QTs9cfCiShB5pQ2uuzZpfw8ND53Fnl4wC+7rd1+qYxN2kS0N0/767/OP+HIUWZXYMOlbqKmSShB5pQPC+dOdN7LT2YP7fW+T6An087rwd8M0CHvx6ur/3vtfrLH78scFFW1CUCrTXRafFKQg80oVRD1zrvppdg/dyZCngA33xws641rpYmAd3lky76+Nnj+S4qWGrQRiVhJ24PSeiBxol7kZm8HcCC/XNnMuAAfjH9ov7Piv/oQoML6UpvV9Lzk+fnuyh/k6HTarBGfn2cWLeShB4Icro03knfEjN5qzZCcH/uTAZmoG1Htun6E+trEtDt57bXR88cNasoU9ZnBCOTsBNbPyWhWyW/VRUnfius5MRqkNUMrOamZaTpt9a8pSOHROryw8vrWdtnaZfLZUZRjvzXGZmEnfj5JKFbwdeknNO3yYl7jZVC/YBmkpTfU3STKU00Ceh7Z9+rD5w8YHgZRiRPo5tsvH2dwsP9X7cTd01J6FbwtXEyp73DW3NDsI/uyMppDbFWMflzp2ek63fX
vauLvlFUl36rtJ76/dTLausF5VNdJJfPaNXdCQuybqftmpLQreBLVSW3qkMo19CdyIpvsYXVv5+O/6SbTW2mSUC3mtFK7/tznyHrzfMj5LGAWSenM2cG79dKEroVfNkzc+v8c9p5XSizKtFa3NSW4crQYzaO0cWHFtcl3iyhx28erzNcGQVeb67Hvjw+o5mdjk7s0DSCJHQr+JIEctu5nXZeF8qsSrR2ZJyZM3VaVGWdAXpPafSQXjfrX0/8al55eXxGMzd1sHZNSUK3Sl5J2Yk9LOJKVlz/ntvYe7MyTg7739kIdLcOEXrU+lE6PSPd+DLzyKpmfiWC9esmCd1JpCbufGZW7fK6KtbMjOPlcx2pUESTgL7lg1v0D8d+MLbMmTO1joy8vMzIyCs6Rs36SgTj100SuhD+MLNql1fN3MyM4+XMw6WUnp40XZcdVlYXHlJYD187XKdlpBlT5syZWkdEXF5mRERwZFabFDihA62BH4FfgAG5LPcgoIG4vNYpCV04mllVOzt76vI48zh06pC+b859mgR0w0kN9Y6jO0wvU/gvt4Se5x2LlFLhwFjgbiAG6KyUislhuZLAM8DG/N1qQwgDFfS2O126uG/w7XK5fxp1ZyI7b1eYx521ril5DZ8+9ClzHpzDnj/3UH9ifd5Y/QZpGWn5L1Put2spX25B1wj4RWu9W2t9CZgDtMthuSHAcOCCgfEJ4T8n37jSztsV+nAbPaUUHWt2JKVfCg/c/ACvrniVRlMasfXw1vyVGSj32w2WO1d7q7pnPoD2wJQsz7sCY7ItUx/4xPP7Srw0uQC9gEQgMTo62qITFBFynH6aH0A9dZ/t+kxf/c7VOvz1cD1w2UB9Ie2CfysIhKEm/sZo8/+PgrSh55XQcdfyVwLVdB4JPetD2tCFaYL1ihKbpJ5L1Y9+9qgmAR0zNkZvPLDRvxU4/QDmTwXAAQeo3BK6cr/vnVLqn0CC1vpfnucve2r2b3melwZ+Bc54/uRq4AQQr7VO9LbeuLg4nZjo9W0h8q9aNXczS3ZVq7rbw0W+LPp5Eb0X9ubQ6UM8/8/neb356xSNKGp3WAUXFuZOzdkp5e5DycoB+5ZSaovWOi6n93xpQ98M1FBKVVdKRQKdgAWZb2qtT2qtK2itq2mtqwEbyCOZC2EqO9upg1ibGm3Y2XcnPer14O11b1N3Yl2+++07u8MqOH/a+R3eyZtnQtdapwNPAkuBXcBcrXWyUmqwUire7ACF8JsPnX8if0oXKc2keyfxTddvuJRxiWb/a8Yzi5/h7KWzdoeWf/5UAJzeyeutLcbsh7ShCxHYTl88rZ9a9JQmAV19VHW9bPcyu0PKP1/b+R3ehu5Lk4sQgSFYhp4FiBKRJXjv7vdY3W014WHh3Dn9Tvos7MOpi6fsDs1/vl534PCzvzw7Rc0inaLCUJljz8+d+/u1YsUc9WULZufSzvHaitcYuWEklUtWZtK9k2j9j9Z2hxWUcusUdVRCT0tL48CBA1y4INcmWaFIkSJERUURERFhdygF54DRBwI2HNhA9y+6s+v4LrrV7ca7rd6lbNGydocVVAImoe/Zs4eSJUtSvnx5lFK2xBUqtNakpqZy+vRpqlevbnc4BefP0DNhqovpFxm8ajDDvxtOpeKVmNB2AvE3yvgJoxR02KJlLly4IMncIkopypcvHzxnQ04ffRBCChcqzNA7h7Kp5yYqFq9Iuznt6PJpF46fO253aEHPUQkdkGRuoaDa1jL23HHqX1OfzT03M7j5YOYlzyNmbAzzkudhV6tAKHBcQhciXxw++iBURYZH8urtr7Kl1xaqlqnKQ/Mfov289hw5c8Tu0IKSJPQC2Lt3L7Nnz/7reVJSEosWLfrr+YIFCxg2bJghZXXr1o358+cD8Pjjj5OSkmLIeoOKWVPeigKrdVUt1vdYz7A7h/HVT18ROy6WmdtnSm3dYJLQCyCvhB4fH8+AAQMML3fKlCnExFwxJb0QjlYorBAv3foS
SX2SuLH8jXT9rCvxc+I5eOqg3aEFjUJ2B+DNs0ueJelIkqHrrHt1XUa1HpXrMtOnT+edd95BKUXt2rWZMWMG3bp1o23btrRv3x6AEiVKcObMGQYMGMCuXbuoW7cunTt3ZuzYsZw/f561a9fy8ssvc/78eRITExkzZgzdunWjVKlSJCYmcuTIEUaMGEH79u1xuVw8+eSTLF++nCpVqhAREUH37t3/KisnzZs355133iEuLo4SJUrwzDPPsHDhQooWLcoXX3zBVVddxbFjx+jTpw+/eeaYGDVqFE2bNjVsWwqRXzdVuIk1j63h/U3v88qyV4gZF8O7rd6le73uwdWvYwOpoWeRnJzMG2+8wfLly9m2bRujR4/Odflhw4bRrFkzkpKSeOmllxg8eDAdO3YkKSmJjh07XrH84cOHWbt2LQsXLvyr5v7pp5+yd+9eUlJSmDFjBuvXr/cr5rNnz9KkSRO2bdvGbbfdxuTJkwF45pln+L//+z82b97MJ598wuOPP+7XeoUwU3hYOM82eZYdfXdQ7+p6PP7l4/xr5r/Y92cO1xIInzm2hp5XTdoMy5cvp0OHDlSoUAGAcuXKGbr+++67j7CwMGJiYjh69CgAa9eupUOHDoSFhXH11Vdzxx13+LXOyMhI2rZtC0CDBg345ptvAPj2228va2c/deoUZ86coUSJEgZ9GhGwZs2CgQPdMwRGR7tHAtnU33B9uetZ/uhyJiZO5MVvX6Tm+JoMv2s4feL6EKakvukvxyZ0JylUqBAuz8UpLpeLS5cu5Ws9hQsX/ut3ozqDIiIi/jpNDQ8PJz09HXDHuWHDBooUKWJIOSJIZJ8iIfP2fGBbUg9TYfRt2Jc2NdrQ88uePLHoCeYmz2VK/BT+Ue4ftsQUqOQQmEWLFi2YN28eqampAJw4cQKAatWqsWXLFsA9ciUtzX3T3JIlS3L69Om//j77c180bdqUTz75BJfLxdGjR1m5cqUBnwRatWrF+++//9fzpKQkQ9YrAtzAgZfPdwPu5wMH2hNPFlXLVGXpw0v5IP4Dko4kUXt8bUauH0mGK8Pu0AKGJPQsYmNjGThwILfffjt16tThueeeA6Bnz56sWrWKOnXqsH79eooXLw5A7dq1CQ8Pp06dOowcOZI77riDlJQU6taty8cff+xTmQ8++CBRUVHExMTw8MMPU79+fUqXLl3gz/Lee++RmJhI7dq1iYmJYcKECQVepwgCDr9Bg1KK7vW6k9wvmRbVW/Dc18/R7H/N+OH4D3aHFhAcNZfLrl27uPnmm22Jx06Zbdupqak0atSI7777jquvvtqSskN1m4esAJrETGvN7B2zeXrJ05y9dJaE5gn0v6U/hcJCu6U4YOZyCVVt27albt26NGvWjFdffdWyZC5CUABNkaCUokvtLqT0S6HtDW15ednLNJnShO1Ht9sdmmOF9qHOIYxqNxciT5kdnw4Z5eKLq0pcxfyH5jMveR5PLHqCuElxDGw2kJebvUxkeKTd4TmK1NCFCDUBOkVCh9gOpDyRQofYDiSsSqDh5IZ8f/h7u8NyFEnoQoiAUaFYBWY9MIsvOn3BsbPHaDS5Ea8se4UL6UEyDXQBSUIXQgSc+BvjSe6XzCN1HuGttW9Rf2J9NhzYYHdYtpOELoQISGWLlmVqu6ks6bKEM5fO0HRqU/p/3Z9zaefy/uMgFdgJXe7yLkTI+9c//sXOfjvpVb8X/13/X+pMqMPqfavtDssWgZvQMy9h3rfPfS/JzEuYLUrq06ZNo0aNGtSoUYNp06ZZUqYQImelCpdifNvxLHtkGRmuDG7/8HaeXPQkZy6dsTs0SwVuQrfxEuYTJ07w+uuvs3HjRjZt2sTrr7/OH3/8YXq5Qojctajegh19d/B0o6cZt3kctcbX4tvd39odlmUCN6GbcAnz5s2bqV27NhcuXODs2bPExsayc+fOK5ZbunQpLVu2pFy5cpQtW5aWLVuyZMmSfJcrhDBO8cjijL57NGseW0Nk
eCQtZ7Sk15e9OHnhpN2hmS5wE7oJd3lv2LAh8fHxDBo0iBdffJGHH36YmjVrXrHcwYMHqVKlyl/Po6KiOHhQ7roihJM0jW5KUu8kXrjlBT7Y+gE1x9dk0c+L8v7DABa4Cd2kS5hfe+01vvnmGxITE3nxxRcLtC4hhL2KRhRlRMsRrO+xnlKFS3HP7Ht49PNHOXH+hN2hmSJwE7pJd3lPTU3lzJkznD59mgsXcr5YoXLlyuzfv/+v5wcOHKBy5coFKlcIYZ5GlRvxfa/vGdRsELO2zyJmbAyf7frM7rAMJ7MtZhMfH0+nTp3Ys2cPhw8fZsyYMVcsc+LECRo0aMD337svO65fvz5btmwx/A5HVnDCNhfCSlsPb6X7gu4kHUniodiHGHP3GCoWr2h3WD6T2RZ9NH36dCIiIvj3v//NgAED2Lx5M8uXL79iuXLlyvHqq6/SsGFDGjZsyGuvvRaQyVyIUFTvmnpsenwTb9zxBp/t+oyYcTHM2TnHsLuI2Ulq6CFOtrkIZcm/J/PYF4+x+dBm7rvpPsa1Gcc1Ja+xO6xcSQ1dCCFyEFsplnU91jHirhEs/nkxseNimb5tesDW1iWh52LHjh3UrVv3skfjxo3tDksIYaBCYYV4oekLbOuzjZiKMTz6+aPcM/se9p/cn/cfO4xPCV0p1Vop9aNS6hel1IAc3n9OKZWilNqulFqmlKpqfKjWq1WrFklJSZc9Nm7caHdYQggT3FjhRlY/tprRrUezat8qYsfFMnnL5ICqreeZ0JVS4cBY4G4gBuislIrJtthWIE5rXRuYD4wwOlAhhDBbmArj6cZPs6PvDuKujaPXwl60nNGSPX/ssTs0n/hSQ28E/KK13q21vgTMAdplXUBrvUJrnTmxygYgytgwhRDCOteVvY5vH/mWCfdMYNPBTdQaX4sxm8bg0i67Q8uVLwm9MpC1MemA5zVvegCLc3pDKdVLKZWolEo8duyY71EKIYTFwlQYveN6s7PfTm6NvpWnFj/F7R/ezs+pP9sdmleGdooqpR4G4oC3c3pfaz1Jax2ntY6rWLHgA/llOnQhhNmiS0ezuMti/tfuf+z8fSe1J9TmnXXvkOHKsDu0K/iS0A8CVbI8j/K8dhml1F3AQCBea33RmPC8s3k6dFq3bk2ZMmVo27atNQUKIWyjlKJb3W4k90um1fWteOGbF7hl6i2kHEuxO7TL+JLQNwM1lFLVlVKRQCdgQdYFlFL1gIm4k/nvxod5JRunQwfghRdeYMaMGdYUJoRwhGtLXsvnHT9n9gOz+fXEr9SbWI8317xJWkaa3aEBPiR0rXU68CSwFNgFzNVaJyulBiul4j2LvQ2UAOYppZKUUgu8rM4wJkyH7vN86AB33nknJUuWzH9hQoiApJSic63OpDyRQrsb2zFw+UAaT2nMtiPb7A6NQr4spLVeBCzK9tprWX6/y+C48hQd7W5myen1/Mo6H/r58+e9zocuhBCVildiboe5fJLyCf0W9SNuchyv3PoKA28bSGR4pC0xBeyVoiZNhy7zoQsh/PJgzIOk9Euhc83ODF49mAaTGpB4KDHvPzRBwCZ0k6ZD92k+dCGEyKp8sfJMv386X3b+khPnT9B4SmMGfDuAC+nW5pCATejgTt5794LL5f5Z0GQO0Lt3b4YMGUKXLl146aWXCr5CIUTIaHtDW5L7JfNY3ccY/t1w6k6oy7r96ywrP6ATutF8nQ8doFmzZnTo0IFly5YRFRXF0qVLLY5WCOFEZYqUYUr8FL5++GvOp5/n1qm3krAywZKyZT70ECfbXAjznL54mgHfDqBJVBO61ulqyDpzmw/dp1EuQggh/FeycEnG3jPWsvIkoedix44ddO16+VG1cOHCMoWuEMKRHJfQtdYopewOA/h7PvRgFUjzPAsh8uaoTtEiRYqQmpoqicYCWmtSU1MpUqSI3aEIIQziqBp6VFQUBw4cQKbWtUaRIkWIipKp64UIFo5K6BEREVSvXt3uMIQQ
IiA5qslFCCFE/klCF0KIICEJXQghgoRtV4oqpY4BOUyA65MKwHEDwzGKxOUfict/To1N4vJPQeKqqrXO8R6etiX0glBKJXq79NVOEpd/JC7/OTU2ics/ZsUlTS5CCBEkJKELIUSQCNSEPsnuALyQuPwjcfnPqbFJXP4xJa6AbEMXQghxpUCtoQshhMhGEroQQgQJRyd0pVRrpdSPSqlflFIDcnj/OaVUilJqu1JqmVKqqkPi6qOU2qGUSlJKrVVKxTghrizLPaiU0kopS4Zz+bC9uimljnm2V5JS6nEnxOVZ5iHPPpaslJrthLiUUiOzbKuflFJ/OiSuaKXUCqXUVs93so1D4qrqyQ/blVIrlVKWzEinlJqqlPpdKbXTy/tKKfWeJ+7tSqn6BS5Ua+3IBxAO/ApcB0QC24CYbMvcARTz/N4X+NghcZXK8ns8sMQJcXmWKwmsBjYAcU6IC+gGjHHg/lUD2AqU9Tyv5IS4si3/FDDVCXHh7ujr6/k9BtjrkLjmAY96fm8BzLBoH7sNqA/s9PJ+G2AxoIAmwMaClunkGnoj4Bet9W6t9SVgDtAu6wJa6xVa63OepxsAK468vsR1KsvT4oAVPc95xuUxBBgOXLAgJn/ispovcfUExmqt/wDQWv/ukLiy6gx85JC4NFDK83tp4JBD4ooBMu/2viKH902htV4NnMhlkXbAdO22ASijlLqmIGU6OaFXBvZneX7A85o3PXAf7czmU1xKqSeUUr8CI4CnnRCX55Suitb6Kwvi8Tkujwc9p53zlVJVHBLXDcANSqnvlFIblFKtHRIX4G5KAKrzd7KyO64E4GGl1AFgEe6zByfEtQ14wPP7/UBJpVR5C2LLi785Lk9OTug+U0o9DMQBb9sdSyat9Vit9fXAS8Agu+NRSoUB7wLP2x1LDr4EqmmtawPfANNsjidTIdzNLs1x14QnK6XK2BlQNp2A+VrrDLsD8egMfKi1jsLdnDDDs9/ZrT9wu1JqK3A7cBBwyjYzlBM2tjcHgaw1tSjPa5dRSt0FDATitdYXnRJXFnOA+8wMyCOvuEoCNYGVSqm9uNvsFljQMZrn9tJap2b5300BGpgck09x4a4xLdBap2mt9wA/4U7wdseVqRPWNLeAb3H1AOYCaK3XA0VwT0Jla1xa60Na6we01vVw5wq01n+aHJcv/M0lebOicyCfHQqFgN24TykzOztisy1TD3eHSA2HxVUjy+/3AolOiCvb8iuxplPUl+11TZbf7wc2OCSu1sA0z+8VcJ8el7c7Ls9yNwF78Vwc6JDttRjo5vn9Ztxt6KbG52NcFYAwz+9DgcFWbDNPedXw3il6D5d3im4qcHlWfbB8bow2uGtFvwIDPa8Nxl0bB/gWOAokeR4LHBLXaCDZE9OK3BKrlXFlW9aShO7j9nrLs722ebbXTQ6JS+FupkoBdgCdnBCX53kCMMyKePzYXjHAd57/YxLQyiFxtQd+9iwzBShsUVwfAYeBNNxnez2APkCfLPvXWE/cO4z4Psql/0IIESSc3IYuhBDCD5LQhRAiSEhCF0KIICEJXQghgoQkdCGECBKS0IUQIkhIQhdCiCDx/xOJ1wLZ9VXRAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the result before the parameter update\n",
- "w0 = w[0].data[0]\n",
- "w1 = w[1].data[0]\n",
- "b0 = b.data[0]\n",
- "\n",
- "plot_x = np.arange(0.2, 1, 0.01)\n",
- "plot_y = (-w0.numpy() * plot_x - b0.numpy()) / w1.numpy()\n",
- "\n",
- "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n",
- "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n",
- "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n",
- "plt.legend(loc='best')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 1.4 torch.optim\n",
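- "The parameter updates above are tedious and repetitive: with many parameters, say 100, we would need 100 lines just to update them. We could wrap the update in a function, and in fact PyTorch already provides one: the optimizers in `torch.optim`.\n",
- "\n",
- "Using `torch.optim` requires another data type, `nn.Parameter`. It is essentially the same as a Variable, except that an `nn.Parameter` requires gradients by default, while a Variable does not (since PyTorch 0.4, Variable has been merged into Tensor, which likewise defaults to `requires_grad=False`).\n",
- "\n",
- "`torch.optim.SGD` updates the parameters with gradient descent. PyTorch's optimizers also provide many other optimization algorithms, which later lessons in this chapter cover in more detail.\n",
- "\n",
- "Once the parameters w and b are passed to `torch.optim.SGD` together with a learning rate, `optimizer.step()` performs the update. Below, we pass the parameters to the optimizer with a learning rate of 1.0."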
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Update the parameters with torch.optim\n",
- "from torch import nn\n",
- "\n",
- "w = nn.Parameter(torch.randn(2, 1))\n",
- "b = nn.Parameter(torch.zeros(1))\n",
- "\n",
- "def logistic_regression(x):\n",
- " return torch.sigmoid(torch.mm(x, w) + b)\n",
- "\n",
- "optimizer = torch.optim.SGD([w, b], lr=1.)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 38,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "epoch: 200, Loss: 0.24529, Acc: 0.89000\n",
- "epoch: 400, Loss: 0.23901, Acc: 0.89000\n",
- "epoch: 600, Loss: 0.23409, Acc: 0.89000\n",
- "epoch: 800, Loss: 0.23013, Acc: 0.89000\n",
- "epoch: 1000, Loss: 0.22689, Acc: 0.89000\n",
- "\n",
- "During Time: 0.352 s\n"
- ]
- }
- ],
- "source": [
- "# Run 1000 updates\n",
- "import time\n",
- "\n",
- "start = time.time()\n",
- "for e in range(1000):\n",
- " # Forward pass\n",
- " y_pred = logistic_regression(x_data)\n",
- " loss = binary_loss(y_pred, y_data) # compute the loss\n",
- " \n",
- " # Backward pass\n",
- " optimizer.zero_grad() # reset the gradients to 0 via the optimizer\n",
- " loss.backward()\n",
- " optimizer.step() # update the parameters via the optimizer\n",
- " \n",
- " # Compute accuracy (y_pred is already a sigmoid probability here)\n",
- " mask = y_pred.ge(0.5).float()\n",
- " acc = (mask == y_data).sum().item() / y_data.shape[0]\n",
- " if (e + 1) % 200 == 0:\n",
- " print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e+1, loss.item(), acc))\n",
- "during = time.time() - start\n",
- "print()\n",
- "print('During Time: {:.3f} s'.format(during))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
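- "With an optimizer, updating the parameters is simple: call **`optimizer.zero_grad()`** to reset the gradients to zero before calling `loss.backward()`, then call **`optimizer.step()`** to update the parameters.\n",
- "\n",
- "After 1000 updates, the loss has also come down to a fairly low value."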
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next, we plot the result after the updates."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.legend.Legend at 0x7f36cec7e550>"
- ]
- },
- "execution_count": 33,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAzFElEQVR4nO3dd3wU1fr48c9JSOgIBBSvoYhfvJogNdgQURRF4WIDlQsKIl1Frz9REK83UhRQ77XQlN4UARsqitIEpEiA0IINBaUICIpUU/b5/bGJhpBNdpOZnZnN83699pXs7uycZ2dnnz1z5sw5RkRQSinlfVFOB6CUUsoamtCVUipCaEJXSqkIoQldKaUihCZ0pZSKEKWcKrhatWpSp04dp4pXSilPWr9+/S8iUj2/5xxL6HXq1CElJcWp4pVSypOMMbsCPadNLkopFSE0oSulVITQhK6UUhFCE7pSSkUITehKKRUhNKErpVSE0ISulFIRwnMJ/atfvuKpJU/xR+YfToeilFKu4rmE/sHXHzB8xXAav9aYtbvXOh2OUkq5hucS+oDmA/i488ccSz/GlZOv5LFPH+NExgmnw1JKKcd5LqEDtPm/Nmztt5WeTXry4uoXaTi+ISt2rXA6LKWUcpQnEzpApdKVGN9uPIvvXUyWL4urp17NQwse4lj6MadDU0opR3g2oedodX4rNvfdTP9L+zNm3RguGXcJi75f5HRYSikVdp5P6AAVYivw8k0vs/y+5cRGx9J6Rmt6fdCLI6eOOB2aUkqFTUQk9BxX1bqK1N6pDLhyAJM2TqL+uPp8/O3HToellFJhEVEJHaBsTFlGtR7F6vtXU6l0JW5+42a6vteVwycPOx2aUkrZKuISeo5Lz7uUDb02MLjFYGZtnkXi2ETe++o9p8NSSinbRGxCByhdqjTDWg1jXc91nFP+HG576zbunnc3B48fdDo0pZSyXEQn9ByNz23Mup7rGHbtMN7Z/g4JYxN4a+tbiIjToSmllGVKREIHiImOYfDVg9nYeyPnVz6fu9++mzvm3MHPx352OjSllLJEoQndGDPZGHPAGLM1wPPGGPOKMeY7Y8xmY0wT68O0TuLZiay6fxUjrx/Jgm8XkDAmgRmbZjhXW581C+rUgago/99Zs5yJQynlecHU0KcCbQp4/iagXvatFzCu+GHZq1RUKR5v/jib+mzi4uoXc+9799LuzXbs/n13eAOZNQt69YJdu0DE/7dXL03qSqkiKTShi8hyoKA+f7cA08VvDVDZGHOuVQHa6e/V/s7ybst5uc3LLNu5jMSxiUxYPyF8tfXBg+FEnoHFTpzwP66UUiGyog39POCnXPd3Zz92BmNML2NMijEm5eBBd/Q0iY6Kpv9l/dncZzNNzm1Crw970XpGa3b+ttP+wn/8MbTHVYmjLXIqFGE9KSoir4tIkogkVa9ePZxFF+qCqhew+N7FjGs7jrV71lJ/bH3GfDkGn/jsK7RWrdAeLyJNCt6kLXIqVFYk9D1AzVz347Mf85woE0WfpD5s67eN5rWa8+DHD3LN1Gv49tC39hQ4fDiUK3f6Y+XK+R+3iCYFa4Xzx1Fb5FTIRKTQG1AH2BrgubbAx4ABLge+DGadTZs2FTfz+XwyacMkOeu5s6TssLLy4qoXJTMr0/qCZs4UqV1bxBj/35kzLV197doi/lR++q12bUuLiVi5P564OJHY2NO3Y7lyln9kfzIm/8/OGHvKK4zNu6oKEpAigXJ1oCf+XADeBPYBGfjbx+8H+gB9sp83wBhgB7AFSCpsneKBhJ5jz+975B9v/ENIRi6bcJlsO7DN6ZBC4rak4CUzZ/oTdn7bLxw/jqH+GNuZcPPbFnb+mKnAipXQ7bp5JaGL+Gvrb2x+Q+JGxkns0Fh5dvmzkpGVEfqKHKjiaA296AJtu3D9OIaSRO1OuLofuYcmdIvsP7ZfOs7pKCQjTV9rKpt+3hT8ix2q4mjNqugCHd2EM6kF
WwewO+HqkZ57aEK32Lxt8+Ts58+WmCExkrw0Wf7I/KPwFzlYxdG2z6IJpobulh9HuxOu1tDdo6CEXmLGcrHSHQl3kNYvjTsT7yT582SSXk9i/d71Bb/IwT7nnTvDzp3g8/n/du5se5ERIb9OSDExEBcHxkDt2vD66+7Ynnb3gA1DhyxlAU3oRRRXLo6Zt89k/t3zOXTyEJdNvIwnFz/JqcxT+b8gTH3OlXU6d/Yn7Nq1/0rgU6bAL7+478fR7oSb37Zwy4+ZyiVQ1d3um5ebXPL69eSv0v297kIyctHoi2TVj6vOXEgbs5XNtGmtZECbXOxVuUxlJt0yiYVdFnIi4wTNJzfn0YWPciIj11UhWsVRFgl0cZM2rSnjT/jhl5SUJCkpKY6UbaejfxzliUVPMC5lHBdUuYBJ7SfRsk5Lp8NSESLnyt/cV5CWK6d1g5LEGLNeRJLye05r6BarWLoiY9uOZWnXpQjCNdOu4YGPHuBY+jGnQ1MRwI3DAehYQe6hCd0m19S5hs19NvPIZY8wLmUc9cfW57MdnzkdlvI4tw3QqWMFuYsmdBuVjy3P/9r8j5XdV1KmVBlumHkDPef35MipI06HpjzKbZ2l3HjEUJJpQg+DK2teSWqfVJ5o/gSTUyeTODaRj775yOmwlAe5rT+4244YSjpN6GFSplQZRlw/gjX3r6FK2Sq0e7Md97x7D4dPFjQZlFKns7KzlBVt31Wrhva4spf2cnFAelY6w5cP59mVzxJXNo6xbcdy+8W3Ox2WKkGs6i1TrRocOnTm43Fx/guwlPW0l4vLxEbH8sy1z7Cu5zr+VvFv3DHnDu6adxcHjh9wOjQV4XJq5V26WNP2fTjAAWagx5W9NKE7qFGNRqztsZbhrYbz3lfvkTg2kdlbZ+PUUZNXaDe5osndIyWQUNu+3XaStqTThO6wmOgYnmzxJBt6baBulbp0ersTt711G3uP7nU6NFfSbnJFl1+PlLxCTcRuO0lbmIivDAQaE8DuWySN5WKVzKxMeeGLF6TMsDJSeURlmbJxivh8PqfDClo4xhLRYVyLrrDx3Ys6tJBXxpCJlOGU0PHQveXrX76WFpNbCMlIm5ltZNdvu5wOqVDh+rLoRAtFV9D47m5OxFaJlMpAQQldm1xc6MK4C1nWbRmv3vQqK3atoP7Y+ryW8pqr29bDdYGJttkWXaDmkb59/f/fc0/kNUPkbmIJdO4govrMB8r0dt+0hh6c7w9/L62mtRKSkVbTWsmOwzssXb9Vh8vhqjlHymGzU/J+3n37Ru72dHqSb7ugTS7e5vP55PWU16XisxWl3PBy8vKalyXLl1Xs9VqZHMN5OOuVNlsviJRmiPw4MYVgOPZNTegR4sfffpSbZt4kJCPNJzWXr3/5uljrs/LLHEk155L0gxHJ5yQKOglsx2cbru+AJvQI4vP5ZOrGqVJ5RGUpM6yMjFo5SjKzMou0Lqu/zJGQCJ3+YQrnNpw5UyQ6uuTV0O16b+EqTxN6BNr7+1655c1bhGTk0gmXytb9W0NeRyQfbheVk9sknD8mBbUve/XIKq9w/ziH62hHE3qE8vl8MnvLbKk2qprEDo2VYZ8Pk/TM9KBf73Rt1I2cbIKw+8ckd+0/UM08OjqyPv9wHvFoDV1Z4sCxA3LX3LuEZKTx+Maycd/GoF8bCc0kVnKyhm7nj0mwPT4ioe3cKW5oQ9d+6BGgevnqzO4wm3fufIe9R/fSbEIznl76NH9k/lHoa3Vi4dM5eSm7nX3sg7ns36qySipXzAMfKNPbfdMauj0OnTgk9757r5CMJI5JlC93f+l0SJ7j1FGLnTW8wi771+Y270Br6CVH1bJVmXbrND7650f8duo3Lp90OU989gQnM046HZpnOHXUYmcNL1DNOzrawdqkspxOcBHBjpw6wmOfPsbEjRO5MO5CJrefTPNazZ0OSznAqgktlPN0gosS6qwyZzGh/QQ+u+cz/sj8gxZTWvDIJ49wPP2406GpMHNF
+66yndbQS4hj6ccYuGggY9aNoW6Vukz8x0SuPf9ap8NSSoVIa+iKCrEVGH3zaJZ1XYbB0Gp6K/p+2Jejfxx1OjSllEU0oZcwLeu0ZHPfzTx6+aO8tv416o+rz8LvFjodVsSI+BlxlKtpQi+BysWU48UbX+SL7l9QLqYcbWa1ofv73fn15K9Oh+ZpOj2ecpom9BLsippXsLH3RgZdNYjpm6aTODaRD77+wOmwPCtck3woFYgm9BKuTKkyPHvds6ztsZZq5arRfnZ7Or/TmUMnDjkdmucEmvkmombEUa6mCV0B0PRvTUnplUJyy2TmbJtDwtgE5qXNczosT9Hp8ZTTgkroxpg2xpivjTHfGWMG5vN8LWPMUmPMRmPMZmPMzdaHqmyR6yxe7AUX8p89/8f6XuupWakmHed2pOPcjuw/tt/pKD3ByXFglIIgEroxJhoYA9wEJACdjDEJeRZ7CpgjIo2Bu4GxVgeqbBDgLF6DRVtY02MNz133HPO/nk/C2ARmbZ6FU9cseIVTF+9ozxqVI5ga+qXAdyLyvYikA7OBW/IsI0Cl7P/PAvZaF2IJEu5vZgFn8UpFlWLgVQNJ7Z3KhXEX0uXdLtwy+xb2/L7H3pg8bNYs/yb98Ud/M8vw4eFJ5tqzRv0p0KhdOTegAzAx1/17gNF5ljkX2ALsBn4FmgZYVy8gBUipVauW/cOSeYkTs00EOQB3ZlamvLjqRSk7rKyc9dxZMmnDJPH5fPbF5UFOTRais06VPIRhtMVOwFQRiQduBmYYY85Yt4i8LiJJIpJUvXp1i4qOEE70eQvyLF50VDSPXvEom/psomGNhtw//37azGrDrt922RebxzjVZVF71tjDq81YwST0PUDNXPfjsx/L7X5gDoCIrAbKANWsCLDEcOKbGeJZvHpx9VjadSljbh7DFz9+Qf1x9Rm3bhw+8dkXo0c4lVi1Z431vNyMFUxCXwfUM8acb4yJxX/Sc36eZX4ErgMwxlyMP6EftDLQiOfEN7MIZ/GiTBT9mvVja7+tXBF/Bf0W9OO66dex4/AO++L0AKcSqxd61nittuvpC8QCtcXkvuFvRvkG2AEMzn5sCNA++/8E4AtgE5AK3FDYOnXGojysboQNw7Q7Pp9PJq6fKJWeqyTlhpeTl1a/JJlZmW4LMyxlOznhtpvnhfXiROROThQeDHSSaI+w6psZ5m/RT0d+kptn3SwkI1dOulK2H9zuxjBtL9vNidUpXjxp6/aYNaG7ndWZwIE90ufzyfTU6VJlRBUpPbS0jFgxQjKyMtwWpivK9rJQd1W313bz4/ajCk3obmbH3uPgt2jf0X1y+1u3C8lI0utJsmX/FjeG6clE47Si7Kpe/eF089GWJnQ3s2OPd/hb5PP5ZM7WOVJ9VHWJGRIjQ5YNkfTMdFeF6dVE46RA2yw6OnDic3tt14s0obuZHVVFl3yLDhw7IJ3mdRKSkYbjGsqGvRtcE6ZLNpGnBNpVC9uGbq7tepEmdDezq6room/Re9vfkxov1JDoZ6Jl8OLBcirjlCvCdNEm8oRAu6oe5YSXJnQ3K6yqGCFZ5/CJw9L13a5CMpIwJkHW7l7rdEgFyr3Z4+L8N49/BMWW366q5yHCTxO62wVK2hHYLrDgmwUS/994iXomSgZ8OkBOpJ9wOqQzFJa4PP4RFEvuXTU6WmvoobKifqYJ3asi9MzdkVNHpPcHvYVkpN4r9WTFrhVOh3SaYJoWPP4RWCIC6xu2smp7FZTQdcYiN4vQkZcqla7E+HbjWXTPIjJ8GVw95Wr6f9yf4+nHnQ4NCG7zevwjsIRT4797VTiGFNCE7mZWDRDi0sE0rqt7HVv6buGhSx9i9JejuWTcJSz5YYnTYQW1eXXwK7/OnWHnTvD5/H81mQcWjvqZJnQ3s2LkJZcPHVchtgIv3/Qyy+9bTqmoUlw3/Tp6f9Cb3//43bGY8tvsublt8CvlDWEZwC1QW4zdN21DD1Jxz6J4qB3+RPoJeWzhYxL1
TJTE/zdePv72Y8diKcm9XCKkY5XrhKMNXRN6pPPgNe5rflojCWMShGSk23vd5PCJw06HdBovJrxgY/bqiU6vfCbay0UVj4dq6LlNnZ4ulc4+LJAlUZV/lH+NWud0SCLizYQXSsxe3F28+JkUhyb0ksyDe3u+/cBjjsnl/V+Rg8cPOhqbFxNeKDEXdHm/W2u+XvxMiqOghK4nRSOdB/uW5de9i4zyrJnWnoQxCczZNsdfG3GAF3uShhJzQSfoXHY+/U9u+UwK6kwWto5mgTK93TetoatAAjf7+yTp9SQhGbn9rdtl39F9YY/Ni7XBUGIO5vJ+t71XN3wmBR0IW32QjDa5KC8p6AuakZUhI1eOlNJDS0uVEVVkeup08fl8YYvNgy1YIcecc+IuUEJ32/l0N3wmBe2zVv/gaEJXnhLMF/Srg1/JlZOuFJKRtrPayk9HfgprfF7oUZFbUWJ2Q803WE5/JgV1JrO6o5kmdOU5wXxBM7My5aXVL0nZYWWl0nOVZML6CWGtrUc6N9R8vUJr6JrQvcPp6k8hvjv0nVwz9RohGbl++vXyw68/OB1S0Fy+aV0fn1toG7omdG/wSDUty5cl49aNkwrPVpDyw8vL6LWjJcuX5XRYBfLIplVBKujHz8ofRk3oqui81JAqIjt/3Sk3zLhBSEaunnK1fHvoW6dDCshjm9YzrEqebj06KSihG//z4ZeUlCQpKSmOlK1CEBXlzzN5GeMfYs+FRISpqVP518J/kZ6VzvBWw+l/WX+io6KdDu00Hty0rpczFl3u6xjKlQv90gur1mMHY8x6EUnK7zm9sCiS2HH1QliGiCu+3G/9/PMNsWn3kfZAGtfXvZ5HP32Uq6ZcxfaD250O8zQe2bSeYtWY4+EYu9wWgarudt+0ycVidjXIeqCht6AQfT6fzNo8S6qOrCqxQ2Pl2eXPSkZWhtMhi4gnNq3nWNVF0M1j2qFt6CWAnQ2ybm1MzBbMW//56M/SYU4HIRlp+lpT2fTzJqfCPY3LN63nWPU1cPP5jYISurahR4oS3CAbyluflzaPfh/147dTvzG4xWAGtRhEbHRseAJVttM2dBUZIq1BNoTzAaG89Q4JHUh7II2OiR1J/jyZpNeTWL93vSUhK+dZNRadB8e08wtUdbf7pk0uFoukBtkQ30tR3/r7X70v575wrkQ/Ey2DFg2SkxknbXgzSlkLbUMvISKlQbYIDZhFfeu/nvxV7nvvPiEZuWj0RbLqx1UWvAGl7FNQQtc2dOU+DpwPWPjdQnp+0JPdv+/mX5f/i6GthlIupoCZopVyiLahK29x4HzAjf93I1v7baVPUh/+u+a/NBjXgM93fm5beUrZQRO6cp/hw/1dCnIrV87/uI0qla7E2LZjWXLvEgThmmnX8OCCBzmWfszyssI2g41yhXB93prQlfvkdDGIi/vrsbJlw1b8tedfy+Y+m3n4socZu24s9cfWZ9H3iyxbf06XuF27/C1Lbp3aTVkjnJ+3JnQ30Wrb6U6e/Ov/Q4fCmvXKx5bnpTYvseK+FZQuVZrWM1rTc35Pjpw6Uux1e/ayclUkYf28A50ttfumvVzyiKRuhwUJtjuKiy7VO5F+Qh7/9HGJeiZKznvxPPnw6w+LtT43X1aurBfOGYuCqqEbY9oYY742xnxnjBkYYJk7jTFpxphtxpg3LP3VKQlKQrUtlGNPC6dyL+6BT9mYsoxsPZI196+hcpnKtHuzHfe+ey+HTx4OORaIvGvAVMHC+nkHyvQ5NyAa2AHUBWKBTUBCnmXqARuBKtn3zy5svVpDz6MkVNtCqXVbVEO3+sDnVMYpeXrJ01JqSCk55/lz5J20d0JeR0k5GFN+rpqxCLgCWJjr/iBgUJ5lRgE9CltX7psm9Dxc1MRgm1B+tCz6Fti1WTfu2yiNxzcWkpE7594p+4/tD+n1kXINmAqOa2YsAjoAE3PdvwcYnWeZ97KT+hfAGqBNgHX1AlKAlFq1ahX9HUWiklBtCzW7
WvAtsPPAJz0zXYZ9Pkxih8ZKtVHV5M0tb+ok1cp2BSV0q3q5lMpudrkG6ARMMMZUzruQiLwuIkkiklS9enWLio4Qnh0NKASh9i/v3Bl27vRfHbpzZ5G2hZ3tlzHRMQy+ejAbem2gbpW6dHq7E7e9dRv7ju4r/sqVKoJgEvoeoGau+/HZj+W2G5gvIhki8gPwDf4Er0JhQQJzNQd+tMJxjVLi2Yl80f0Lnm/9PAt3LCRhbALTUqflHJWqEGnv3WIIVHWXv5pJSgHfA+fz10nRxDzLtAGmZf9fDfgJiCtovdqGrsIlnO3VX//ytVw1+SohGblp5k3y428/2ldYBCoJLY/FRXGaXEQkE3gQWAhsB+aIyDZjzBBjTPvsxRYCh4wxacBSYICIHLLsV0e5g0erTuE88Lkw7kI+7/Y5r7R5hc93fU7i2EReX/+61tbzkd/uVBJ679pJR1tUwXHzFC4u9f2v39Pzg54s+WEJrc5vxYR/TKBulbpOh+UKgXanvMk8RwmYeCtoBY22qAldBadOHf+FQHnVru2v9qp8iQgTNkzgsU8fI0uyGHHdCB649AGiTMkedSPQ7hQdDVlZZz6uu9lfdPhcVXwWXrlZkhhj6NW0F9v6bePq2lfT/5P+tJzakm8OfeN0aI4KtNtkZTky0GbE0ISugqPXqxdLzbNqsuCfC5h6y1S2HthKw/ENeWHVC2T58qmOlgCBdpucjk+R3HvXTprQVXAcGqM8khhj6NqoK2n90rjxghsZ8NkArpx8JWkH05wOLewK2p0ivfeunTShq+CUhAufwuTciufy7l3v8uYdb7Lj8A4av9aY4cuHk5GV4XRoYaO7kz30pKhSDjpw/AAPffwQc7bNoXGNxky5ZQoNazR0OizlYnpSVCmXOrv82bzV4S3evvNt9h7dS9KEJP6z9D9Mm5HpxS7/ymGlnA5AKQW3X3w7LWu35JGFjzBk9LeYDzOQdP/XM2fYeNAmCVUwraGr8LDjKlMnrly1scy4cnHMuG0G1ddMRNJPn0NVr5ZUwdAaurJf3ssCrahy2rFOl5T5y75y+T6uXf5VYfSkqLKfHVeZOnHlapjKDFRMxbMPs++n0pSPLW9ZWcp79KSocpYdV5k6ceVqmMrMr492qdJ/cPSqB2kwvgHLdi6ztDwVOTShK/vZcZWpE1euVq0aljLz66M9dVJplr3YG4Ph2mnX0u+jfhz946il5Srv04Su7BfossCbby76CcZwX7k6axb8/vuZj8fG2lJmfldLtqzTks19N/Po5Y8yPmU89cfV59Mdn1petvKwQAOl232LuAkudNbfguXdPn37Fn8mg3Bu80DzocbF2VdmAVb9uEouGn2RkIx0f6+7/HryV0fiUOFHARNc6ElRK+hY4aHz2nC8UVH+FJ6XgwN1n8o8xTPLnuH5Vc9zToVzGN92PP/4+z8ciUWFj54UtZtOsxK6QCcS80vybuDC0SbLlCrDc9c/x5oea4grG0f72e3p8k4XDp3QycJKKk3oVtCxwkMXKBEa487r3F082mTS35JI6ZVCcstk3tr2FgljE3g77W2nwwrIozMZeoImdCu4sPbmesOH+5N3XiLuPLJx+fCAsdGx/Oea/7C+13riK8XTYW4HOs7tyIHjB5wO7TQ5rZO7dvk/6pxrszSpWyRQ47rdt4g6KapTlRdNficZwX+SUxVZRlaGPLv8WYkdGitxI+Nk1uZZ4vP5gn69neeaA51brl3bujIiHQWcFNUauhXsqL2VhOPS2rXzfzwqKrLft81KRZViUItBpPZOpV5cPTq/05lb37qVvUf3Fvpau2vQ2jpps0CZ3u5bRNXQrVZSavz5vc+8t0h83/mxqVqcmZUpL656UcoMKyNnPXeWTN4wWWbO9AUsyu4atNbQi48Cauia0N2oJO31uRNZdHTJed+5heEH/JtfvpGrp1wt3N5JomJPBizKGHtbwUpKXcVOmtC9xu5vlRvNnBm4lh7J71skbD/gWb4sqVLj
SIFFhSMUr12D57Z4C0ro2obuRiWt10xOw20gkfq+c4SpYTnKRPHb/koFFhWO3pnhmATaqlNQnuuVEyjT233TGnoBStpxaaBqYaS/7xxhbGILVFSVGkckMytTREKvkbqtBmvl18eNrZ9ok4sH5DfWiZu+JXYK1MQEkf2+c4TxBzy/oqJiTwq3d5IrJ10pXx38yq2hB83KJOzG1k9N6OFS1KqKG78V4eTGalC4hbGae2ZRPpmeOl2qjKgipYeWlpErR0pGVkZQ63LjR2dlEnbj+9OEHg7FScpu3GvCqaT/oLnE3t/3yq2zbxWSkWavN5Mt+7cU+horkqfVv2WBvk7R0aGv2427pib0cAg2Kee397rxuC7c3NYQW0L5fD6ZvWW2VBtVTWKGxMjQz4dKemZ6wOWLWxexI2EWdHlDUdbttl1TE3o4BJOUA+29cXElu4buRuH6FrstW2Q7cOyA3DX3LiEZaTS+kWzYuyHf5YJKyAW8R7sOTmfOjNzLGjShh0Mwe2ZBkyS47biuJAvXcbYbj+fzeHf7u1LjhRoS/Uy0DF48WE5lnDpjmQJ/kwp5j3YenEbqga8m9HAI5stZ0B7m0ppaiRSucxoeOXdy6MQh6fpuVyEZSRiTIGt3rw3+xYW8Rzs3gUc2b8g0oYdLYUk5UvewSBOuqp0TVchiVBw++uYjif9vvEQ9EyUDPh0gJ9JPFP6iQt6jnQcpHjgAKhJN6G4RqXtYpLH7hzcnqeZXhp0/8Bbsf7+d/E16zu8pJCMXvnqhrNy1suAXBLEt7Tw4jcQDX03obhKJe1ikCXe1MVw/8Bb+UH224zOp81IdMclG+i/oL8f+OJb/gjNnisTGnl5ebKzu98VQUELXsVzCLRwDWajisXN2ovzmn81h9yxIFo4Zc33d69nSdwsPXvogr3z5CpeMu4QlPyzJf2GRgu8ryxgJYuMaY9oALwPRwEQRGRFguTuAeUAzEUkpaJ1JSUmSklLgIkpFnqio/BOaMf4feTvVqZP/JNy1a/srF0W0YtcKus/vzneHv6N3096Maj2KSqUr2VpmSWaMWS8iSfk9V2gN3RgTDYwBbgISgE7GmIR8lqsIPAysLV64SlnArTM+OTmSpk1DKbao3YJNfTbx/674f0zYMIH6Y+vzyXef+J/UKYrCqlQQy1wKfCci3wMYY2YDtwBpeZYbCowEBhQ1mIyMDHbv3s2pU6eKugoVgjJlyhAfH09MTIzToVgrZ8zTnKaNnDFPwfkmruHDT48NrB+fNpCc9z54sD+h1qrlL9eCbVIuphwv3PACHRI60P397tw06ya6NerGxPjziP5p95kvcNuQyLNm2bJdwi5Q43rODeiAv5kl5/49wOg8yzQB3s7+fxmQFGBdvYAUIKVWrVpnNPZ///33cvDgwZAmtFVF4/P55ODBg/L99987HYr13N49NMJPjJ/KOCVPLnpSop+Jlr6dzpKMMqXd3bMr1JPgDn9+FKeXS2EJHX+zzTKgjhSS0HPf8uvlkpaWpsk8jHw+n6SlpTkdhvUi9RJBj1m/d700GNdAOt2OHKhWTnxu/QELpQLggq7HBSX0YHq57AFq5rofn/1YjopAfWCZMWYncDkw3xiTb6N9YYwxRXmZKoKI3dYlbcYnl2pybhPW9VzHxf2HcF7/DM4ZWY25Hz2P/POfTod2ulDa+fPrpXTihP9xFwgmoa8D6hljzjfGxAJ3A/NznhSRIyJSTUTqiEgdYA3QXgrp5aKUbcIxj5oKSmx0LP9u+W/W91pP7cq1uXPenXSY24Gfj/3sdGh/CaUC4PKTvIUmdBHJBB4EFgLbgTkiss0YM8QY097uAN1s586dvPHGG3/eT01NZcGCBX/enz9/PiNG5NvDM2TdunVj3rx5APTo0YO0tLznpNWf7OxHrorkknMuYfX9qxlx3Qg++uYjEscmMnPzzJxmW2eFUgFw+9FfoLYYu2+B2tC9ZOnSpdK2bds/70+ZMkUeeOABW8rq2rWrzJ07
1/L1em2bK+/bfnC7XDHxCiEZafdGO9l9ZLfTIQV/otPlbejBdFt0xCOfPELqz6mWrrNRjUa81OalApeZPn06L7zwAsYYGjRowIwZM+jWrRvt2rWjQ4cOAFSoUIFjx44xcOBAtm/fTqNGjejUqRNjxozh5MmTrFy5kkGDBnHy5ElSUlIYPXo03bp1o1KlSqSkpPDzzz8zatQoOnTogM/n48EHH2TJkiXUrFmTmJgYunfv/mdZ+bnmmmt44YUXSEpKokKFCjz88MN8+OGHlC1blvfff59zzjmHgwcP0qdPH37MPhR86aWXaN68uWXbUqmiuqjaRay4bwWvfvkqTy5+koSxCfz3hv/SvXF3587rdO4c3BGcjV0/raCX/ueybds2hg0bxpIlS9i0aRMvv/xygcuPGDGCFi1akJqayhNPPMGQIUO46667SE1N5a677jpj+X379rFy5Uo+/PBDBg4cCMA777zDzp07SUtLY8aMGaxevTqkmI8fP87ll1/Opk2buPrqq5kwYQIADz/8MP/6179Yt24db7/9Nj169AhpvZ7k1ouJ1Bmio6J55PJH2NJ3C41rNKbHBz24ceaN7Potn6tK3cbFw3e4toZeWE3aDkuWLKFjx45Uq1YNgKpVq1q6/ltvvZWoqCgSEhLYv38/ACtXrqRjx45ERUVRo0YNrr322pDWGRsbS7t27QBo2rQpn332GQCLFi06rZ39999/59ixY1SoUMGid+Mybr6YSAV0QdULWNJ1Ca+lvMbjix6n/rj6jLx+JH2S+hBltL4ZKt1iQShVqhS+7HE2fD4f6enpRVpP6dKl//xfLDoZFBMT8+dhanR0NJmZmYA/zjVr1pCamkpqaip79uyJ3GQOru9OpgKLMlH0bdaXrX23ckX8FTyw4AFaTWvFd4e/czo0z9GEnkurVq2YO3cuhw4dAuDw4cMA1KlTh/Xr1wP+nisZGRkAVKxYkaNHj/75+rz3g9G8eXPefvttfD4f+/fvZ9myZRa8E7jhhht49dVX/7yfmppqyXpdy+XdyVThaleuzcIuC5nUfhKpP6fSYFwD/rf6f2T5spwOzTM0oeeSmJjI4MGDadmyJQ0bNuTRRx8FoGfPnnz++ec0bNiQ1atXU758eQAaNGhAdHQ0DRs25H//+x/XXnstaWlpNGrUiLfeeiuoMu+44w7i4+NJSEigS5cuNGnShLPOOqvY7+WVV14hJSWFBg0akJCQwPjx44u9Tldze3cyFRRjDN0bd2dbv220Or8Vj376KC2mtOCrX75yOjRvCNT9xe5bJHRbtMrRo0dFROSXX36RunXryr59+8JWdsRscxd0J1PW8vl8MnPTTKk6sqqUHlpanlvxnGRkZTgdluPQCS7crV27djRq1IgWLVrw73//mxo1ajgdkvfoxUQRxxhD5wadSeuXRrsL2zFo8SAun3g5m/dvdjo01wpqggs75DfBxfbt27n44osdiaek0m2uvGLutrk8sOABfjv1G4NbDGZQi0HERsc6HVbYFWuCC6WUcoOOiR1JeyCNjokdSf48mWYTmrFh3wanw3IVTehKKc+oVq4as26fxft3v8/B4we5dMKlPLn4SU5l6qQ4oAldKeVB7f/enm39tnFvw3t5buVzNHmtCWt2r3E6LMdpQleqpImQIRKqlK3C5Fsm80nnTziWfozmk5vz2KePcSLjROEvjlDeTugRsmMqFTY5QyTs2uXv3JkzRIKHvzs3/t+NbO23lV5NevHi6hdpOL4hy3ctdzosR3g3oTu8Y06bNo169epRr149pk2bFpYylSq2CB0ioVLpSoxrN47F9y4my5dFy6kteXDBgxxLP+Z0aGHl3W6Lder4k3hetWv7R0Cz0eHDh0lKSiIlJQVjDE2bNmX9+vVUqVLF1nLtoN0WS5ioKH8FKC9j/KMHRoDj6cd5cvGTvPrlq9SuXJsJ/5jA9XWvdzosy0Rmt0Ubxu5Yt24dDRo04NSpUxw/fpzExES2bt16xnIL
Fy6kdevWVK1alSpVqtC6dWs++eSTIperVNiUgCESyseW5+WbXmbFfSuIjY6l9YzW9PqgF0dOHXE6NNt5N6HbsGM2a9aM9u3b89RTT/H444/TpUsX6tevf8Zye/bsoWbNv+bNjo+PZ8+ePWcsp5TrlKD5VpvXak5q71QGXDmASRsnUX9cfRZ8u6DwF3qYdxO6TTvm008/zWeffUZKSgqPP/54sdallOuUsCESysaUZVTrUay+fzWVSlei7Rtt6fpeVw6fPOx0aLbwbkK3acc8dOgQx44d4+jRo5w6lf/FCueddx4//fTTn/d3797NeeedV6xylQobF8+4Y5dLz7uUDb028FSLp5i1eRYJYxJ4d/u7TodlOe+eFLVJ+/btufvuu/nhhx/Yt28fo0ePPmOZw4cP07RpUzZs8F923KRJE9avX2/5DEfh4IZtrlQ4bdy3ke7zu5P6cyp3Jt7J6JtGU718dafDClpknhS1wfTp04mJieGf//wnAwcOZN26dSxZsuSM5apWrcq///1vmjVrRrNmzXj66ac9mcyVKokan9uYL3t8ybBrh/Hu9ndJGJvA7K2zLZtFzElaQy/hdJurkmzbgW3c9/59rNu7jlsvupWxN4/l3IrnOh1WgbSGrpRS+Ug8O5FV969i1PWj+Pjbj0kcm8j0TdM9W1vXhF6ALVu20KhRo9Nul112mdNhKaUsVCqqFAOaD2BTn00kVE+g63tdaftGW3468lPhL3aZUk4H4GaXXHJJ5E+urJQC4O/V/s7y+5Yz+svRDFo8iMSxibx4w4v0aNIDY4zT4QVFa+hKKZUtykTR/7L+bOm7haS/JdHrw160ntGaH379wenQgqIJXSml8qhbpS6L7l3E+Lbj+XLPl1wy7hJGfzkan7h7vBtN6EoplY8oE0XvpN5s7beVq2pdxUMfP0TLqS359tC3TocWkKcTug6HrpSyW62zavFx54+ZcssUth7YSoPxDXhh1Qtk+bKcDu0Mnk3oTo/T36ZNGypXrky7du3CU6BSyjHGGLo16sa2ftu44YIbGPDZAK6cfCVpB9OcDu00nk3oTo/TP2DAAGbMmBGewpRSrvC3in/jvbve443b32DH4R00fq0xz654loysDKdDAzyc0G0YDj3o8dABrrvuOipWrFj0wpRSnmSModMlnUh7II1b/n4Lg5cM5rKJl7Hp501Oh+bdhG7HOP3BjoeulFJnlz+bOR3nMK/jPPYc3UPShCT+s/Q/pGelOxaTZxO6XeP063joSqlQ3JFwB2n90uhUvxNDlg+h6etNSdmbUvgLbeDZhG7XOP3BjIeulFK5xZWLY/pt0/mg0wccPnmYyyZexsBFAzmVGd4c4tmEDvaM09+7d2+GDh1K586deeKJJ4q/QqVUidHuwnZs67eN+xrdx8gvRtJofCNW/bQqbOV7OqFbLdjx0AFatGhBx44dWbx4MfHx8SxcuDDM0Sql3KhymcpMbD+RT7t8ysnMk1w1+SqSlyWHpeygxkM3xrQBXgaigYkiMiLP848CPYBM4CDQXUR2FbROHQ/dHXSbK2Wfo38cZeCigVwefzn3NLzHknUWNB56oaMtGmOigTFAa2A3sM4YM19Ecveo3wgkicgJY0xfYBRwV/FDV0op76pYuiJj2o4JW3nBDJ97KfCdiHwPYIyZDdwC/JnQRWRpruXXAF2sDNIpW7Zs4Z57Tv9VLV26NGvXrnUoIqWUCiyYhH4ekHuk991AQbM83A98nN8TxpheQC+AWgE6jIuIa8YejvTx0L06K4tSKn+WnhQ1xnQBkoDn83teRF4XkSQRSape/cxZtsuUKcOhQ4c00YSBiHDo0CHKlCnjdChKKYsEU0PfA9TMdT8++7HTGGOuBwYDLUXkj6IEEx8fz+7duzl48GBRXq5CVKZMGeLj450OQyllkWAS+jqgnjHmfPyJ/G7gn7kXMMY0Bl4D2ojIgaIGExMTw/nnn1/UlyulVIlWaJOLiGQCDwILge3AHBHZ
ZowZYoxpn73Y80AFYK4xJtUYM9+2iJVSSuUrqEmiRWQBsCDPY0/n+v96i+NSSikVIr1SVCmlIkRQV4raUrAxB4ECryYtQDXgFwvDsYrGFRqNK3RujU3jCk1x4qotImd2E8TBhF4cxpiUQJe+OknjCo3GFTq3xqZxhcauuLTJRSmlIoQmdKWUihBeTeivOx1AABpXaDSu0Lk1No0rNLbE5ck2dKWUUmfyag1dKaVUHprQlVIqQrg6oRtj2hhjvjbGfGeMGZjP848aY9KMMZuNMYuNMbVdElcfY8yW7GEQVhpjEtwQV67l7jDGiDEmLN25gthe3YwxB7O3V6oxpocb4spe5s7sfWybMeYNN8RljPlfrm31jTHmN5fEVcsYs9QYszH7O3mzS+KqnZ0fNhtjlhljwjIinTFmsjHmgDFma4DnjTHmley4NxtjmhS7UBFx5Q3/dHc7gLpALLAJSMizzLVAuez/+wJvuSSuSrn+bw984oa4sperCCzHPxFJkhviAroBo124f9XDPxtXlez7Z7shrjzLPwRMdkNc+E/09c3+PwHY6ZK45gJds/9vBcwI0z52NdAE2Brg+Zvxzx1hgMuBtcUt08019D9nShKRdCBnpqQ/ichSETmRfXcN/qF93RDX77nulgfCcea50LiyDQVGAqfCEFMocYVbMHH1BMaIyK8AUoyRRC2OK7dOwJsuiUuAStn/nwXsdUlcCUDObO9L83neFiKyHDhcwCK3ANPFbw1Q2RhzbnHKdHNCz2+mpPMKWD7gTEkWCyouY8wDxpgd+OdX7e+GuLIP6WqKyEdhiCfouLLdkX3YOc8YUzOf552I60LgQmPMF8aYNdmTpbshLsDflACcz1/Jyum4koEuxpjd+Afze8glcW0Cbs/+/zagojEmLgyxFSbUHFcoNyf0oBU2U5ITRGSMiFwAPAE85XQ8xpgo4L/A/3M6lnx8ANQRkQbAZ8A0h+PJUQp/s8s1+GvCE4wxlZ0MKI+7gXkikuV0INk6AVNFJB5/c8KM7P3OaY8BLY0xG4GW+Od1cMs2s5QbNnYgoc6U1F6KOFOSHXHlMhu41c6AshUWV0WgPrDMGLMTf5vd/DCcGC10e4nIoVyf3USgqc0xBRUX/hrTfBHJEJEfgG/wJ3in48pxN+FpboHg4rofmAMgIquBMvgHoXI0LhHZKyK3i0hj/LkCEfnN5riCEWouKVw4Tg4U8YRCKeB7/IeUOSc7EvMs0xj/CZF6LourXq7//wGkuCGuPMsvIzwnRYPZXufm+v82YI1L4moDTMv+vxr+w+M4p+PKXu4iYCfZFwe6ZHt9DHTL/v9i/G3otsYXZFzVgKjs/4cDQ8KxzbLLq0Pgk6JtOf2k6JfFLi9cb6yIG+Nm/LWiHcDg7MeG4K+NAywC9gOp2bf5LonrZWBbdkxLC0qs4Ywrz7JhSehBbq/nsrfXpuztdZFL4jL4m6nSgC3A3W6IK/t+MjAiHPGEsL0SgC+yP8dU4AaXxNUB+DZ7mYlA6TDF9SawD8jAf7R3P9AH6JNr/xqTHfcWK76Peum/UkpFCDe3oSullAqBJnSllIoQmtCVUipCaEJXSqkIoQldKaUihCZ0pZSKEJrQlVIqQvx/vHDJPHFjoEsAAAAASUVORK5CYII=\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Plot the result after the updates\n",
- "w0 = w[0].data[0]\n",
- "w1 = w[1].data[0]\n",
- "b0 = b.data[0]\n",
- "\n",
- "plot_x = np.arange(0.2, 1, 0.01)\n",
- "plot_y = (-w0.numpy() * plot_x - b0.numpy()) / w1.numpy()\n",
- "\n",
- "plt.plot(plot_x, plot_y, 'g', label='cutting line')\n",
- "plt.plot(plot_x0, plot_y0, 'ro', label='x_0')\n",
- "plt.plot(plot_x1, plot_y1, 'bo', label='x_1')\n",
- "plt.legend(loc='best')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "After the updates, the model can roughly separate the two classes of points."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
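- "### 1.5 PyTorch Loss Functions\n",
- "So far we have used a loss we wrote ourselves, but PyTorch already implements many common losses. For example, linear regression uses `nn.MSELoss()`, and the binary classification loss for logistic regression is `nn.BCEWithLogitsLoss()`. For more losses, see the [documentation](http://pytorch.org/docs/0.3.0/nn.html#loss-functions).\n",
- "\n",
- "PyTorch's built-in losses have two advantages. First, they are convenient, so we don't have to reinvent the wheel. Second, they are implemented in the underlying C++ code, so they are generally faster and more numerically stable than our own implementation.\n",
- "\n",
- "In addition, for numerical stability, `nn.BCEWithLogitsLoss()` combines the model's Sigmoid and the final loss into one operation. So when using this built-in loss, the model should output raw logits without applying Sigmoid itself. To turn such an output into a class prediction, apply `torch.sigmoid` and threshold at 0.5, or equivalently threshold the raw logit at 0."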
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Use the built-in loss\n",
- "criterion = nn.BCEWithLogitsLoss() # fuses sigmoid and the loss into one layer for better speed and stability\n",
- "\n",
- "w = nn.Parameter(torch.randn(2, 1))\n",
- "b = nn.Parameter(torch.zeros(1))\n",
- "\n",
- "def logistic_reg(x):\n",
- " return torch.mm(x, w) + b\n",
- "\n",
- "optimizer = torch.optim.SGD([w, b], 1.)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 118,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(0.6314)\n"
- ]
- }
- ],
- "source": [
- "y_pred = logistic_reg(x_data)\n",
- "loss = criterion(y_pred, y_data)\n",
- "print(loss.data)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 39,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "epoch: 200, Loss: 0.22419, Acc: 0.89000\n",
- "epoch: 400, Loss: 0.22191, Acc: 0.89000\n",
- "epoch: 600, Loss: 0.21997, Acc: 0.89000\n",
- "epoch: 800, Loss: 0.21830, Acc: 0.88000\n",
- "epoch: 1000, Loss: 0.21685, Acc: 0.88000\n",
- "\n",
- "During Time: 0.215 s\n"
- ]
- }
- ],
- "source": [
- "# Again run 1000 updates\n",
- "\n",
- "start = time.time()\n",
- "for e in range(1000):\n",
- " # Forward pass\n",
- " y_pred = logistic_reg(x_data)\n",
- " loss = criterion(y_pred, y_data)\n",
- " \n",
- " # Backward pass\n",
- " optimizer.zero_grad()\n",
- " loss.backward()\n",
- " optimizer.step()\n",
- " \n",
- " # Compute accuracy: y_pred holds raw logits here, so apply sigmoid before thresholding at 0.5\n",
- " mask = torch.sigmoid(y_pred).ge(0.5).float()\n",
- " acc = (mask == y_data).sum().item() / y_data.shape[0]\n",
- " if (e + 1) % 200 == 0:\n",
- " print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e+1, loss.item(), acc))\n",
- "\n",
- "during = time.time() - start\n",
- "print()\n",
- "print('During Time: {:.3f} s'.format(during))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
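- "With PyTorch's built-in loss, this run was somewhat faster (0.215 s versus 0.352 s above). The gain looks small because this is a tiny model. For larger models, the built-in loss typically brings more noticeable benefits in both numerical stability and speed, and it also saves us from reinventing the wheel."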
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|
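The claims in sections 1.4 and 1.5 can be checked with a small standalone sketch. It does not use the notebook's `x_data`/`y_data`; the toy tensors below are made up for illustration. It assumes PyTorch 0.4 or later, where tensors carry `requires_grad` directly. The sketch checks four things:

- An `nn.Parameter` requires gradients by default, while a plain tensor does not.
- One plain `torch.optim.SGD` step (no momentum, no weight decay) equals the manual update `w - lr * grad`.
- `nn.BCEWithLogitsLoss` on logits matches `nn.BCELoss` on `torch.sigmoid(logits)` for moderate values. For an extreme logit, the fused loss stays exact, while `BCELoss` hits its documented clamp of the log at -100.
- Thresholding logits at 0 gives the same predictions as thresholding sigmoid outputs at 0.5.

```python
import torch
from torch import nn

torch.manual_seed(0)

# --- 1) nn.Parameter vs. a plain tensor, and what optimizer.step() does ---
w = nn.Parameter(torch.randn(2, 1))      # requires_grad=True by default
plain = torch.randn(2, 1)                 # plain tensor: requires_grad=False by default
param_requires_grad = w.requires_grad
plain_requires_grad = plain.requires_grad

x = torch.randn(4, 2)                     # hypothetical toy inputs
lr = 1.0
optimizer = torch.optim.SGD([w], lr=lr)   # plain SGD: no momentum, no weight decay

optimizer.zero_grad()
loss = torch.mm(x, w).pow(2).mean()       # any differentiable scalar works here
loss.backward()

w_before = w.detach().clone()
manual_update = w_before - lr * w.grad    # the hand-written update rule
optimizer.step()                          # in-place update of w
sgd_matches_manual = torch.allclose(w.detach(), manual_update)

# --- 2) BCEWithLogitsLoss vs. Sigmoid + BCELoss ---
logits = torch.randn(8, 1) * 3            # hypothetical raw model outputs
targets = torch.randint(0, 2, (8, 1)).float()
fused = nn.BCEWithLogitsLoss()(logits, targets).item()
two_step = nn.BCELoss()(torch.sigmoid(logits), targets).item()

# With an extreme logit, sigmoid(-200) underflows to 0 in float32 and
# BCELoss clamps log(0) to -100, so the two-step loss is capped at 100
z = torch.tensor([[-200.0]])
y = torch.tensor([[1.0]])
exact_loss = nn.BCEWithLogitsLoss()(z, y).item()         # about 200
clamped_loss = nn.BCELoss()(torch.sigmoid(z), y).item()  # capped at 100

# --- 3) thresholding logits at 0 == thresholding probabilities at 0.5 ---
mask_from_logits = logits.ge(0).float()
mask_from_probs = torch.sigmoid(logits).ge(0.5).float()
masks_match = torch.equal(mask_from_logits, mask_from_probs)

print(param_requires_grad, plain_requires_grad, sgd_matches_manual)
print(round(fused, 5), round(two_step, 5))
print(exact_loss, clamped_loss, masks_match)
```

The first printed line should read `True False True`. The two losses on the second line should agree to within float32 rounding. The exact values depend on the random seed.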