{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 感知机\n", "\n", "感知机(Perceptron)是二分类的线性分类模型,输入为实例的特征向量,输出为实例的类别(取`+1`或`-1`)。感知机对应于输入空间中将实例划分为两类的分离超平面,感知机旨在求出该超平面。为求得超平面导入了基于误分类的损失函数,利用梯度下降法对损失函数进行最优化。感知机的学习算法具有简单而易于实现的优点,感知机预测是用学习得到的感知机模型对新的实例进行预测的,因此属于判别模型。\n", "\n", "![perceptron](images/perceptron.png)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 生物学、心理学解释\n", "\n", "心理学家唐纳德·赫布(Donald Olding Hebb)于1949年提出`赫布理论`,该理论能够解释学习的过程中脑中的神经元所发生的变化。赫布理论描述了突触可塑性的基本原理,即突触前神经元向突触后神经元的持续重复的刺激,可以导致突触传递效能的增加:\n", "> 当细胞A的轴突足以接近以激发细胞B,并反复持续地对细胞B放电,一些生长过程或代谢变化将发生在某一个或这两个细胞内,以致A作为对B放电的细胞中的一个,导致突触传递效能的增加。\n", "\n", "![neuron_cell](images/neuron_cell_cn.png)\n", "\n", "\n", "心理学家弗兰克·罗森布拉特(Frank Rosenblatt)于1957年提出了`感知机`,它作为简化的数学模型解释大脑神经元如何工作:它取一组二进制输入值(附近的神经元),将每个输入值乘以一个连续值权重(每个附近神经元的突触强度),并设立一个阈值,如果这些加权输入值的和超过这个阈值,就输出1,否则输出0,这样的工作原理就相当于神经元是否放电。\n", "\n", "\n", "感知机并没有完全遵循赫布理论,**但通过调输入值的权重,可以有一个非常简单直观的学习方案:给定一个有输入输出实例的训练集,感知机应该「学习」一个函数:对每个例子,若感知机的输出值比实例低太多,则增加它的权重,否则若设比实例高太多,则减少它的权重。**\n", "\n", "\n", "模仿的是生物神经系统内的神经元,它能够接受来自多个源的信号输入,然后将信号转化为便于传播的信号在进行输出(在生物体内表现为电信号)。\n", "\n", "![neuron](images/neuron.png)\n", "\n", "* dendrites - 树突\n", "* nucleus - 细胞核\n", "* axon - 轴突\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 感知机模型\n", "\n", "假设输入空间(特征向量)为$\\mathbf{X} \\subseteq \\mathbb{R}^n$,输出空间为$\\mathbf{Y} \\in \\{-1, +1\\}$。输入$x \\in \\mathbf{X}$ 表示实例的特征向量,对应于输入空间的点;输出$y \\in \\mathbf{Y}$表示示例的类别。由输入空间到输出空间的函数为\n", "\n", "$$\n", "f(x) = sign(w x + b)\n", "$$\n", "\n", "称为感知机。其中,参数$w$叫做权值向量,$b$称为偏置。$w·x$表示$w$和$x$的内积。$sign$为符号函数,即\n", "![sign_function](images/sign.png)\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 几何解释 \n", "感知机模型是`线性分类模型`,感知机模型的假设空间是定义在特征空间中的所有线性分类模型,即函数集合$\\{ f | f(x)=w·x+b\\}$。线性方程 $w·x+b=0$对应于特征空间$\\mathbb{R}^n$中的一个超平面$S$,其中$w$是超平面的法向量,$b$是超平面的截距。这个超平面把特征空间划分为两部分,位于两侧的点分别为正负两类。超平面$S$称为分离超平面,如下图:\n", "![perceptron_geometry_def](images/perceptron_geometry_def.png)\n", "\n", "### 2.2 生物学类比\n", "![perceptron_2](images/perceptron_2.PNG)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 感知机学习策略\n", "\n", "假设训练数据集是 **线性可分**,感知机学习的目标是求得一个能够将训练数据的正负实例点完全分开的分离超平面,即最终求得参数$w, b$。这需要一个学习策略,即定义(经验)损失函数并将损失函数最小化。\n", "\n", "损失函数的一个自然的选择是误分类的点的总数,但是这样得到的损失函数不是参数$w,b$的连续可导函数,不宜优化。损失函数的另一个选择是误分类点到分类面的距离之和。\n", "\n", "首先,对于任意一点$x$到超平面的距离([参考资料](https://www.cnblogs.com/graphics/archive/2010/07/10/1774809.html))为\n", "$$\n", "\\frac{1}{||w||} | w \\cdot x + b |\n", "$$\n", "\n", "其次,对于误分类点$(x_i,y_i)$来说 $-y_i(w \\cdot x_i + b) > 0$\n", "\n", "这样,假设超平面$S$的总的误分类点集合为$\\mathbf{M}$,那么所有误分类点到$S$的距离之和为\n", "$$\n", "-\\frac{1}{||w||} \\sum_{x_i \\in \\mathbf{M}} y_i (w \\cdot x_i + b)\n", "$$\n", "不考虑$1/||w||$,就得到了感知机学习的损失函数\n", "$$\n", "L = - \\sum_{x_i \\in \\mathbf{M}} y_i (w \\cdot x_i + b)\n", "$$\n", "\n", "### 3.1 经验风险函数\n", "\n", "给定数据集$\\mathbf{T} = \\{(x_1,y_1), (x_2, y_2), ... (x_N, y_N)\\}$,其中\n", "* $x_i \\in \\mathbb{R}^n$\n", "* $y_i \\in \\{-1, +1\\},i=1,2...N$\n", "\n", "感知机$sign(w·x+b)$学习的损失函数定义为\n", "$$\n", "L(w, b) = - \\sum_{x_i \\in \\mathbf{M}} y_i (w \\cdot x_i + b)\n", "$$\n", "其中$\\mathbf{M}$为误分类点的集合,这个损失函数就是感知机学习的[经验风险函数](https://blog.csdn.net/zhzhx1204/article/details/70163099)。\n", "\n", "显然,损失函数$L(w,b)$是非负的。\n", "* 如果没有误分类点,那么$L(w,b)$为0\n", "* 误分类点数越少,$L(w,b)$值越小\n", "\n", "该损失函数在误分类时是参数$w,b$的线性函数,在正确分类时该损失函数是0。因此,给定训练数据集$\\mathbf{T}$,损失函数$L(w,b)$是$w,b$的连续可导函数。\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 感知机学习算法\n", "\n", "\n", "最优化问题:给定数据集$\\mathbf{T} = \\{(x_1,y_1), (x_2, y_2), ... (x_N, y_N)\\}$,其中\n", "* $x_i \\in \\mathbb{R}^n$\n", "* $y_i \\in \\{-1, +1\\},i=1,2...N$\n", "\n", "求参数$w,b$,使其成为损失函数的解,其中$\\mathbf{M}$为误分类数据的集合:\n", "\n", "$$\n", "min_{w,b} L(w, b) = - \\sum_{x_i \\in \\mathbf{M}} y_i (w \\cdot x_i + b)\n", "$$\n", "\n", "感知机学习是误分类驱动的,具体采用随机梯度下降法:\n", "* 首先,任意选定$w_0$、$b_0$,\n", "* 然后用梯度下降法不断极小化目标函数\n", " - 极小化的过程不是一次性的把$\\mathbf{M}$中的所有误分类点梯度下降\n", " - 而是一次随机选取一个误分类点使其梯度下降。\n", "\n", "假设误分类集合$\\mathbf{M}$是固定的,那么损失函数$L(w,b)$的梯度为\n", "$$\n", "\\triangledown_w L(w, b) = - \\sum_{x_i \\in \\mathbf{M}} y_i x_i \\\\\n", "\\triangledown_b L(w, b) = - \\sum_{x_i \\in \\mathbf{M}} y_i \\\\\n", "$$\n", "\n", "随机选取一个误分类点$(x_i,y_i)$,对$w,b$进行更新:\n", "\\begin{eqnarray}\n", "w & = & w + \\eta y_i x_i \\\\\n", "b & = & b + \\eta y_i\n", "\\end{eqnarray}\n", "\n", "式中$\\eta$(0 ≤ $ \\eta $ ≤ 1)是学习速率(步长):\n", "* 步长越大,梯度下降的速度越快,更能接近极小点。如果步长过大,有可能导致跨过极小点,导致函数发散;\n", "* 如果步长过小,有可能会耗很长时间才能达到极小点。\n", "\n", "> 直观解释:当一个实例点被误分类时,调整$w,b$,使分离超平面向该误分类点的一侧移动,以减少该误分类点与超平面的距离,直至超越该点被正确分类。\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.1 算法\n", "\n", "\n", "**输入:**\n", "* $\\mathbf{T}=\\{(x_1,y_1),(x_2,y_2), ..., (x_N,y_N)\\}$, 其中$x_i \\in \\mathbf{X}=\\mathbb{R}^n$,\n", "* $y_i \\in \\mathbf{Y} = {-1, +1},i=1,2...N$,\n", "* 学习速率为η\n", "\n", "**输出:**\n", "* $w$, $b$; \n", "* 感知机模型$f(x)=sign(w·x+b)$\n", "\n", "**处理过程:**\n", "1. 初始化$w_0$,$b_0$\n", "2. 在训练数据集中选取$(x_i, y_i)$\n", "3. 如果$y_i(w * x_i+b)≤0$\n", " \n", " $w = w + η y_i x_i$\n", " \n", " $b = b + η y_i$\n", "\n", "4. 如果所有的样本都正确分类,或者迭代次数超过设定值,则终止\n", "5. 否则,跳转至(2)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 示例程序\n", "\n", "生成数据:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAEICAYAAABGaK+TAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAU/UlEQVR4nO3dfZBddZ3n8fe3u9N5hhjSRk2iYQYGcGBmoHpEZMQawC0GRHAXa2AWRQc3PlA8qQs4DjuU7Di4MC7O4uqkEAVl4giDAz6uLCoWK8J0EBAIEJEAgYR0nkPS6af73T+6cUInodP33r6nT9/3qyrV956+95zPSXU++fXvnHNPZCaSpPJpKTqAJKk6FrgklZQFLkklZYFLUklZ4JJUUha4JJWUBS5JJWWBa9KJiFUR0RMR2yJic0T8PCI+EhGj/rxHxOKIyIhoa0RWqRYWuCarUzNzNvAm4CrgUuArxUaS6ssC16SWmVsy8w7gz4FzIuLwiDglIn4ZEVsj4rmIuGKXt/xs+OvmiHgpIo6JiN+NiB9HxIaIWB8RN0fEnEbvizSSBa6mkJn3A6uBtwPbgfcDc4BTgI9GxOnDLz1u+OuczJyVmfcCAfwd8AbgMGARcEWjskt7Y4GrmbwAzM3Mn2bmrzKzkpkPA8uAd+ztTZn568y8MzN7M7Mb+PyrvV5qFA/UqJksADZGxNEMzYsfDrQDU4Fb9vamiJgPfIGh0ftshgY+m8Y9rTQKR+BqChHxxwwV+D3APwF3AIsyc3/gywxNkwDs6eM5Pzu8/IjM3A84e5fXS4WxwDWpRcR+EfEu4JvANzLzVwyNojdm5s6IeAvwF7u8pRuoAL+zy7LZwEvAlohYAPzXxqSXXl34eeCabCJiFTAfGGCojB8DvgF8OTMHI+IM4O+BucDdwCqGDlqePfz+zwAfBaYAJwHbgJuAQ4BfA18HLs7MhY3bK2l3FrgklZRTKJJUUha4JJWUBS5JJWWBS1JJNfRCnnnz5uXixYsbuUlJKr3ly5evz8yOkcsbWuCLFy+mq6urkZuUpNKLiGf2tNwpFEkqKQtckkrKApekkrLAJamkLHBJKikLXJJKygKX6igH11PZdD6VDX9B9i0vOo4mOQtcqqPccjH0/l/o7yI3nUtWdhQdSZOYBS7V08BqYHDocQ5Abi00jiY3C1yqp1nnMXSLzekw9e3QMr/oRJrEvKmxVEctM84g24+G3AZthxLhrTM1fixwqc6ibVHREdQknEKRpJKywCWppCxwSSopC1ySSsoCl6SSssAlqaQscEkqKQtckkrKApekkrLAJamkLHBJKikLXJJKygKXpJKywCWppEYt8Ii4ISLWRcQjuyy7OiIej4iHI+LbETFnXFNKknazLyPwrwEnjVh2J3B4Zv4B8CTwqTrnkiSNYtQCz8yfARtHLPtRZg4MP/0FsHAcskmSXkU95sD/EvhBHdYjSRqDmgo8Ij4NDAA3v8prlkREV0R0dXd317I5SdIuqi7wiPgA8C7gP2dm7u11mbk0Mzszs7Ojo6PazUmSRqjqpsYRcRJwCfCOzNxR30iSpH2xL6cRLgPuBQ6JiNURcS5wHTAbuDMiHoyIL49zTknSCKOOwDPzrD0s/so4ZJEkjYFXYkpSSVngklRSFrgklZQFLkklZYFLUklZ4JJUUha4JJWUBS5JJWWBS1JJWeCSVFIWuCSVlAUuSSVlgUtSSVngklRSFrhUApUdt1JZfwaVrX9LZl/RcTRBVHVHHkmNk30PwbYrIXtg4EmyZS4x66NFx9IE4AhcmugqLwAx/GQnDK4qMIwmEkfg0kTX/nZoOQAqASQx431FJ9IEYYFLE1y0zIJ534P+J6BtEdEyt+hImiAscKkEIqZB+x8WHUMTjHPgklRSFrgklZQFLkklZYFLUkmNWuARcUNErIuIR3ZZNjci7oyIlcNfXzO+MSVJI+3LCPxrwEkjll0G3JWZBwN3DT+XJDXQqAWemT8DNo5YfBpw4/DjG4HT6xtLkjSaaufA52fmmuHHa4H5e3thRCyJiK6I6Oru7q5yc5KkkWo+iJmZCeSrfH9pZnZmZmdHR0etm5MkDau2wF+MiNcDDH9dV79IkqR9UW2B3wGcM/z4HOD2+sSRJO2rfTmNcBlwL3BIRKyOiHOBq4B3RsRK4MTh55KkBhr1w6wy86y9fOuEOmeRJI2BV2JKUklZ4JJUUha4JJWUBS5JJWWBS1JJWeCSVFIWuCSVlAUuSSVlgUtSSVngklRSFrgkjZPMPiqbPkJl7RFUNp5DZk9d12+BS9J46fkO9N0L9ELfctixrK6rt8AladwMQL58v5skc6Cua7fAJWm8TD8Npvw+ENB2MDFjbx/uWp1RP05WklSdiGnEAcvIHCSite7rdwQuSeNsPMobLHBJKi0LXJJKygKXVJWsbKey5TNUNi4h+5YXHacpeRBTUlVy6xWw8wdAH7npfuj4CdHymoJTNRdH4NI+yMpLZN9DZGVr0VEmjoEngL5/fz74YmFRmpUFLo0iB18ku08kN32Q7D6eHHi26EgTw4xzgWkQM6H1QGg7qOhETccpFGk0O38I+RJDo80Wsud2Yvb5RacqXMuM08j234fBddDeSYR10mg1jcAj4uKIeDQiHomIZRExrV7BpAmjdSHw8nm8U4m2hUWmmVCi7SBi6tuIaC86SlOqusAjYgFwAdCZmYcz9BN+Zr2CSRPG1ONh1vnQ9gcw81yYdlrRiSSg9imUNmB6RPQDM4AXao8kTSwRQcz6EMz6UNFRpFeoegSemc8D1wDPAmuALZn5o5Gvi4glEdEVEV3d3d3VJ5UkvUItUyivAU4DDgTeAMyMiLNHvi4zl2ZmZ2Z2dnR0VJ9UkvQKtRzEPBF4OjO7M7MfuA14W31iSZJGU0uBPwu8NSJmREQAJwAr6hNLkjSaWubA7wNuBR4AfjW8rqV1yiVJGkVNZ6Fk5t8Af1OnLJKkMfBSekkqKQtckkrKApekkrLAJamkLHBJKikLXJJKygKXpJKywCWppCxwSSopC1ySSsoCl6SSssAlqaQscFUlc4DsvYfse6joKFLTqvWemGpCmUlu/EsYeBgyyVkfpmXWx4qOJTUdR+Aau8oG6H8AcgfQAztuLjqR1JQscI1dy34Q04EApkDb7xWdSGpKTqFozCLa4YBl5EvXQexPzP540ZGkpmSBqyrRdhAx59qiY0hNzSkUSSopC1ySSsoCV9Oq9PyQyvrTqWy+mKxsKzqONGbOgasp5cBq2HIJsBMGVpIxldj/qqJjSWPiCFzNqbIe4uUf/34YeK7QOFI1airwiJgTEbdGxOMRsSIijqlXMGlcTTkc2g4BZkBMJ2adV3SiMcv+J4emgLpPIvvuLzqOClDrFMoXgB9m5hkR0Q7MqEMmadxFtMHcf4KBldDSQbQeUHSkMcvN58HgM0OPNy2B1y4norXgVGqkqgs8IvYHjgM+AJCZfUBffWJJ4y+iFaYcWnSM6lU2//vj7AUGAQu8mdQyhXIg0A18NSJ+GRHXR8TMkS+KiCUR0RURXd3d3TVsTtIrzL4UaAemwMwlQ1fIqqlEZlb3xohO4BfAsZl5X0R8AdiamZfv7T2dnZ3Z1dVVXVJJu8nKZsgBonVe0VE0jiJieWZ2jlxeywh8NbA6M+8bfn4rcFQN65M0RtEyx/JuYlUXeGauBZ6LiEOGF50APFaXVJJqUun5PpV1x1PZcBY5uKboOBontZ6Fcj5w8/AZKL8BPlh7JEm1yMENsOVSoBcqL5BbLiPm3lh0LI2Dmgo8Mx8EdpuXkVSg3L7LkwpUNhYWRePLKzGlyaZ1EUw7iaEzVKYRsy8rOpHGiZ+FIk0yEUHMuZoc/CTETKJlVtGRNE4scGmSitb5RUfQOHMKRZJKygKXpJKywCWppCxwSSopC1ySSsoCl6SS8jRCSTXJ7IEd/wwMwvQ/97zzBrLAJdUkN30E+pYPPen5HjHvtmIDNRELXFJt+h7gtzfjGniUzEFv7dYgzoFLqs3UY4BpQ3+mHGV5N5AjcEk1iTnXQc/tDM2Bn150nKZigUuqSUQ7zHhv0TGaklMoklRSFrgklZQFLkklZYFLUklZ4JJUUha4JJWUBS5JJWWBS1JJ1VzgEdEaEb+MiO/WI5Akad/UYwR+IbCiDuuRJI1BTQUeEQuBU4Dr6xNHkrSvah2BXwtcAlRqjyJJGouqCzwi3gWsy8zlo7xuSUR0RURXd3d3tZuTJI1Qywj8WODdEbEK+CZwfER8Y+SLMnNpZnZmZmdHR0cNm1OtMpMNazbR29NbdBRJdVB1gWfmpzJzYWYuBs4EfpyZZ9ctmeqqUqlw+buv4n2/8zHe+7r/wuP3ryw6kqQaeR54k3iy6ykeuvsx+nsH6NnWww2fXlZ0JEk1qssNHTLzp8BP67EujY9Zc2ZSGRw61tza1sL+HfsVnEhSrRyBN4mFv/cGPnzN++lYNI8jjnszH7v2g0VHklSjyMyGbayzszO7uroatj1JmgwiYnlmdo5c7ghckkrKApekkrLAJamkLHBJKikLvGR6e3pZ8/SLDA4OFh1FUsHqch64GuPZx5/nomM/Td/OfhYduoBr77mSqdOnFh1LUkEcgZfIt66+nZc276C3p4/nV67h/h88WHQkSQWywEtk7uvm0NY+9EtTZjLHqymlpjbhp1Ayk+V3PszmdVs49vQ/Zvqs6b/9XqVSoaWlef4POuuv/iMv/HotT/zbU5y85ESOePthRUeSVKAJfyXm1z9zC9+6+nYAXvvGeSx96O95dsVqLnnnlWzdsI3/dPEpLPkf7x+PuJI0IZT2Ssw7b7qbndt72bm9lxef6WbtqnVcd8ENbF63hcpghTu++H947onnd3tf384+7rzpbu7+1s/rdsbG+uc38N/P+p9c/u6reOax5+qyTkmq1oSfQjn8Tw5lwwsb6e8dYOr0qcxbMJeW1hYigswkgZbW3f8fuuSdn+GpB1cB8IvvLufSm86vOcvlp32O3zz0DFmpsOK+ldyy9noioub1SlI1JnyBX/SPH+aNhy5gw5pNnH7+nzF1+lQu+OKH+KuTP8vGtZs587LTWXDQ61/xnt6eXlbc+ySVytD00M/v+Le6ZFnz1Iu//UjWbRtfor+3n/Zp7XVZtySN1YQv8PapUzjzsve8YtmiQxbw9ae+uPf3TGvn9b/7OtY+vY6W1uCwow+uS5b3XHgyt1xzBxHB2057i+UtqVAT/iBmtTZ3b+Hb//B9ps6YynsuOJnpM6fVZb1PPbSK3p4+Djv6YKdPJDXE3g5iTtoCl6TJorRnoYyn/r5+Hv35E6x7trvoKJI0ZhN+Dny89PX2c/5bPzV0YLJS4b/d8kne8mdHFh1LkvZZ047An+x6ijW/eZGel3bSu6OPb37u20VHkqQxadoC71h4AJWBoVMCp0ydwhsPW1BwIkkam6Yt8Plv6uDyWz7BEccdxn/4wDv48DXnFB1JksakaefAAY4++SiOPvmoomNIUlWqHoFHxKKI+ElEPBYRj0bEhfUMJkl6dbWMwAeAT2TmAxExG1geEXdm5mN1yiZJehVVj8Azc01mPjD8eBuwAvBIoCQ1SF0OYkbEYuBI4L49fG9JRHRFRFd3txfMSFK91FzgETEL+BfgoszcOvL7mbk0Mzszs7Ojo6PWzUmShtVU4BExhaHyvjkzb6tPJEnSvqjlLJQAvgKsyMzP1y+SJGlf1DICPxZ4H3B8RDw4/OfkOuWSJI2i6tMIM/MewA/ElqSCNO2l9JJUdhZ4FTZ3b2HHtp6iY0hqchb4GP3vi7/KWYs+wnvnn8v/+9f7i44jqYlZ4GOwZf1WvvOlHzHQN0Dfzn6+9PGv1WW9v37waf761L/jcx+4ji3rdzuVXpL2qKk/jXCs2qdNoaVl6LhtBMyeO6vmdfb19vPJP72C7Vt20DallfXPrefqu66oeb2SJj9H4GMwfdZ0/vqfP87rFr+Wg448kE8vu7jmdW7fvJ3enj4ABvoHee6JF2pep6Tm0LQj8J7tO5nS3kbblLH9FRxzaifHnLrbzaGrNue1+3Pk8UfwyD0rqFQqnPHxU+u2bkmTW1MW+NJLvs5t136PKVPbuPKOy/ijPz28sCwRwZXfuZQV9z7JzP1ncOARbyosi6RyaboplA1rNvGv/+v7DA4MsnN7L/9w3vVFR6K1tZXD/+Qwy1vSmDRdgU9pb+PlC0gjYPqsacUGkqQqNV2B73fAbC76xyXM6diPRYcu5NKbzi86kiRVJTKzYRvr7OzMrq6uhm1PkiaDiFiembudPdF0I3BJmiwscEkqKQtckkrKApekkrLAJamkLHBJKikLXJJKqqHngUdEN/BMwzZYH/OA9UWHKEgz7zu4/828/xNt39+UmR0jFza0wMsoIrr2dAJ9M2jmfQf3v5n3vyz77hSKJJWUBS5JJWWBj25p0QEK1Mz7Du5/M+9/KfbdOXBJKilH4JJUUha4JJWUBb4HEbEoIn4SEY9FxKMRcWHRmYoQEa0R8cuI+G7RWRotIuZExK0R8XhErIiIY4rO1CgRcfHwz/0jEbEsIib1basi4oaIWBcRj+yybG5E3BkRK4e/vqbIjHtjge/ZAPCJzHwz8FbgvIh4c8GZinAhsKLoEAX5AvDDzDwU+EOa5O8hIhYAFwCdmXk40AqcWWyqcfc14KQRyy4D7srMg4G7hp9POBb4HmTmmsx8YPjxNob+8S4oNlVjRcRC4BSg+Ls+N1hE7A8cB3wFIDP7MnNzoaEaqw2YHhFtwAzghYLzjKvM/BmwccTi04Abhx/fCJzeyEz7ygIfRUQsBo4E7is4SqNdC1wCVArOUYQDgW7gq8NTSNdHxMyiQzVCZj4PXAM8C6wBtmTmj4pNVYj5mblm+PFaYH6RYfbGAn8VETEL+BfgoszcWnSeRomIdwHrMnN50VkK0gYcBXwpM48EtjNBf4Wut+G53tMY+k/sDcDMiDi72FTFyqFzrSfk+dYW+F5ExBSGyvvmzLyt6DwNdizw7ohYBXwTOD4ivlFspIZaDazOzJd/67qVoUJvBicCT2dmd2b2A7cBbys4UxFejIjXAwx/XVdwnj2ywPcgIoKh+c8Vmfn5ovM0WmZ+KjMXZuZihg5g/Tgzm2YUlplrgeci4pDhRScAjxUYqZGeBd4aETOG/x2cQJMcwB3hDuCc4cfnALcXmGWvLPA9OxZ4H0MjzweH/5xcdCg11PnAzRHxMPBHwGeLjdMYw7913Ao8APyKoY4oxWXl1YqIZcC9wCERsToizgWuAt4ZESsZ+q3kqiIz7o2X0ktSSTkCl6SSssAlqaQscEkqKQtckkrKApekkrLAJamkLHBJKqn/D+qhvYKqDzFRAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# data generation\n", "np.random.seed(314)\n", "\n", "data_size1 = 10\n", "x1 = np.random.randn(data_size1, 2) + np.array([2,2])\n", "y1 = [-1 for _ in range(data_size1)]\n", "\n", "data_size2 = 10\n", "x2 = np.random.randn(data_size2, 2)*2 + np.array([8,8])\n", "y2 = [1 for _ in range(data_size2)]\n", "\n", "# all sample data\n", "x = np.concatenate((x1, x2), axis=0)\n", "y = np.concatenate((y1, y2), axis=0)\n", "\n", "shuffled_index = np.random.permutation(data_size1 + data_size2)\n", "x = x[shuffled_index]\n", "y = y[shuffled_index]\n", "\n", "\n", "train_data = np.concatenate((x, y[:, np.newaxis]), axis=1)\n", "\n", "# plot data\n", "plt.scatter(train_data[:,0], train_data[:,1], c=train_data[:,2], marker='.')\n", "plt.title(\"Data\")\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "学习模型:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "lines_to_end_of_cell_marker": 2 }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "update weight/bias: 3.662519607024163 3.038628576040485 0.5\n", "update weight/bias: 3.1283482021588505 1.595298136548129 0.0\n", "update weight/bias: 1.7022056555578562 0.9488336160189257 -0.5\n", "update weight/bias: 0.6064534385664728 0.25821553124912766 -1.0\n", "update weight/bias: -0.0072588404447664345 -1.1732780618880432 -1.5\n", "update weight/bias: 3.56979581700504 4.157268901515593 -1.0\n", "update weight/bias: 2.7104258301108155 3.6232316160976543 -1.5\n", "update weight/bias: 1.6146736131194321 2.9326135313278563 -2.0\n", "update weight/bias: 1.0805022082541196 1.4892830918355 -2.5\n", "update weight/bias: -0.3456403383468747 0.8428185713062968 -3.0\n", "update weight/bias: 3.4230241861565407 3.6955935657768997 -2.5\n", "update weight/bias: 2.5636541992623156 3.161556280358961 -3.0\n", "update weight/bias: 1.1375116526613214 2.5150917598297577 -3.5\n", "update weight/bias: 0.38678027777241897 1.1329290169473543 -4.0\n", "update weight/bias: -1.446079265832824 -0.7361650837497964 -4.5\n", "update weight/bias: 1.7930043074867144 5.9278879714909145 -4.0\n", "update weight/bias: 1.2588329026214018 4.484557531998558 -4.5\n", "update weight/bias: 0.3697277316535954 3.2936957431536 -5.0\n", "update weight/bias: -0.519377439314211 2.1028339543086423 -5.5\n", "update weight/bias: -2.352236982919454 0.23373985361149163 -6.0\n", "update weight/bias: 1.8159337720901148 4.415105700242464 -5.5\n", "update weight/bias: 0.7328910527198487 3.024123458986641 -6.0\n", "update weight/bias: -0.3501516666504174 1.6331412177308182 -6.5\n", "w = [-0.3501516666504174, 1.6331412177308182]\n", "b = -6.5\n", "\n", "\n", "ground_truth: [-1. -1. -1. -1. -1. 1. 1. -1. 1. 1. 1. 1. -1. 1. 1. 1. 1. -1.\n", " -1. -1.]\n", "predicted: [-1. -1. -1. -1. -1. -1. 1. -1. 1. 1. 1. 1. -1. 1. 1. 1. 1. -1.\n", " -1. -1.]\n" ] } ], "source": [ "import random\n", "import numpy as np\n", "\n", "# 符号函数\n", "def sign(v):\n", " if v > 0: return 1\n", " else: return -1\n", " \n", "def perceptron_train(train_data, eta=0.5, n_iter=100):\n", " weight = [0, 0] # 权重\n", " bias = 0 # 偏置量\n", " learning_rate = eta # 学习速率\n", "\n", " train_num = n_iter # 迭代次数\n", "\n", " for i in range(train_num):\n", " # select one data\n", " ti = np.random.randint(len(train_data))\n", " (x1, x2, y) = train_data[ti]\n", " \n", " y_pred = sign(weight[0] * x1 + weight[1] * x2 + bias) \n", " \n", " if y * y_pred <= 0: # 判断误分类点\n", " weight[0] = weight[0] + learning_rate * y * x1 # 更新权重\n", " weight[1] = weight[1] + learning_rate * y * x2\n", " bias = bias + learning_rate * y # 更新偏置量\n", " print(\"update weight/bias: \", weight[0], weight[1], bias)\n", "\n", " return weight, bias\n", "\n", "def perceptron_pred(data, w, b):\n", " y_pred = []\n", " for d in data:\n", " x1, x2, y = d\n", " yi = sign(w[0]*x1 + w[1]*x2 + b)\n", " y_pred.append(yi)\n", " \n", " return np.array(y_pred, dtype=float)\n", "\n", "\n", "# do training\n", "w, b = perceptron_train(train_data)\n", "print(\"w = \", w)\n", "print(\"b = \", b)\n", "\n", "# predict \n", "y_pred = perceptron_pred(train_data, w, b)\n", "\n", "print(\"\\n\")\n", "print(\"ground_truth: \", train_data[:, 2])\n", "print(\"predicted: \", y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 参考资料\n", "* [感知机(Python实现)](http://www.cnblogs.com/kaituorensheng/p/3561091.html)\n", "* [Programming a Perceptron in Python](https://blog.dbrgn.ch/2013/3/26/perceptrons-in-python/)\n", "* [损失函数、风险函数、经验风险最小化、结构风险最小化](https://blog.csdn.net/zhzhx1204/article/details/70163099)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 2 }