{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 感知机\n", "\n", "感知机(perceptron)是二分类的线性分类模型,输入为实例的特征向量,输出为实例的类别(取+1和-1)。感知机对应于输入空间中将实例划分为两类的分离超平面,感知机旨在求出该超平面。为求得超平面导入了基于误分类的损失函数,利用梯度下降法 对损失函数进行最优化(最优化)。感知机的学习算法具有简单而易于实现的优点,感知机预测是用学习得到的感知机模型对新的实例进行预测的,因此属于判别模型。感知机由Rosenblatt于1957年提出的,是神经网络和支持向量机的基础。\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 生物学解释\n", "心理学家Rosenblatt构想了感知机,它作为简化的数学模型解释大脑神经元如何工作:它取一组二进制输入值(附近的神经元),将每个输入值乘以一个连续值权重(每个附近神经元的突触强度),并设立一个阈值,如果这些加权输入值的和超过这个阈值,就输出1,否则输出0(同理于神经元是否放电)。对于感知机,绝大多数输入值不是一些数据,就是别的感知机的输出值。\n", "\n", "唐纳德·赫布提出了一个出人意料并影响深远的想法,称知识和学习发生在大脑主要是通过神经元间突触的形成与变化,简要表述为赫布法则:\n", "\n", "> 当细胞A的轴突足以接近以激发细胞B,并反复持续地对细胞B放电,一些生长过程或代谢变化将发生在某一个或这两个细胞内,以致A作为对B放电的细胞中的一个,效率增加。\n", "\n", "\n", "感知机并没有完全遵循这个想法,**但通过调输入值的权重,可以有一个非常简单直观的学习方案:给定一个有输入输出实例的训练集,感知机应该「学习」一个函数:对每个例子,若感知机的输出值比实例低太多,则增加它的权重,否则若设比实例高太多,则减少它的权重。**\n", "\n", "\n", "模仿的是生物神经系统内的神经元,它能够接受来自多个源的信号输入,然后将信号转化为便于传播的信号在进行输出(在生物体内表现为电信号)。\n", "\n", "![neuron](images/neuron.png)\n", "\n", "* dendrites - 树突\n", "* nucleus - 细胞核\n", "* axon - 轴突\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 感知机模型\n", "\n", "假设输入空间(特征向量)为$X \\subseteq R^n$,输出空间为$Y=\\{-1, +1\\}$。输入$x \\in X$ 表示实例的特征向量,对应于输入空间的点;输出$y \\in Y$表示示例的类别。由输入空间到输出空间的函数为\n", "\n", "$$\n", "f(x) = sign(w x + b)\n", "$$\n", "\n", "称为感知机。其中,参数$w$叫做权值向量,$b$称为偏置。$w·x$表示$w$和$x$的内积。$sign$为符号函数,即\n", "![sign_function](images/sign.png)\n", "\n", "### 2.1 几何解释 \n", "感知机模型是线性分类模型,感知机模型的假设空间是定义在特征空间中的所有线性分类模型,即函数集合{f|f(x)=w·x+b}。线性方程 w·x+b=0对应于特征空间Rn中的一个超平面S,其中w是超平面的法向量,b是超平面的截踞。这个超平面把特征空间划分为两部分。位于两侧的点分别为正负两类。超平面S称为分离超平面,如下图:\n", "![perceptron_geometry_def](images/perceptron_geometry_def.png)\n", "\n", "### 2.2 生物学类比\n", "![perceptron_2](images/perceptron_2.PNG)\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 感知机学习策略\n", "\n", "假设训练数据集是线性可分的,感知机学习的目标是求得一个能够将训练数据的正负实例点完全分开的分离超平面,即最终求得参数w、b。这需要一个学习策略,即定义(经验)损失函数并将损失函数最小化。\n", "\n", "损失函数的一个自然的选择是误分类的点的总数。但是这样得到的损失函数不是参数w、b的连续可导函数,不宜优化。损失函数的另一个选择是误分类点到分类面的距离之和。\n", "\n", "首先,对于任意一点xo到超平面的距离为\n", "$$\n", "\\frac{1}{||w||} | w \\cdot xo + b |\n", "$$\n", "\n", "其次,对于误分类点$(x_i,y_i)$来说 $-y_i(w \\cdot x_i + b) > 0$\n", "\n", "这样,假设超平面S的总的误分类点集合为M,那么所有误分类点到S的距离之和为\n", "$$\n", "-\\frac{1}{||w||} \\sum_{x_i \\in M} y_i (w \\cdot x_i + b)\n", "$$\n", "不考虑1/||w||,就得到了感知机学习的损失函数。\n", "\n", "### 3.1 经验风险函数\n", "\n", "给定数据集$T = \\{(x_1,y_1), (x_2, y_2), ... (x_N, y_N)\\}$(其中$x_i \\in R^n$, $y_i \\in \\{-1, +1\\},i=1,2...N$),感知机sign(w·x+b)学习的损失函数定义为\n", "$$\n", "L(w, b) = - \\sum_{x_i \\in M} y_i (w \\cdot x_i + b)\n", "$$\n", "其中M为误分类点的集合,这个损失函数就是感知机学习的[《经验风险函数》](https://blog.csdn.net/zhzhx1204/article/details/70163099)。\n", "\n", "显然,损失函数$L(w,b)$是非负的。如果没有误分类点,那么$L(w,b)$为0,误分类点数越少,$L(w,b)$值越小。一个特定的损失函数:在误分类时是参数$w,b$的线性函数,在正确分类时,是0.因此,给定训练数据集T,损失函数$L(w,b)$是$w,b$的连续可导函数。\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 感知机学习算法\n", "\n", "\n", "最优化问题:给定数据集$T = \\{(x_1,y_1), (x_2, y_2), ... (x_N, y_N)\\}$,其中$x_i \\in R^n$, $y_i \\in \\{-1, +1\\},i=1,2...N$,求参数$w,b$,使其成为损失函数的解($M$为误分类的集合):\n", "\n", "$$\n", "min_{w,b} L(w, b) = - \\sum_{x_i \\in M} y_i (w \\cdot x_i + b)\n", "$$\n", "\n", "感知机学习是误分类驱动的,具体采用随机梯度下降法。首先,任意选定$w_0$、$b_0$,然后用梯度下降法不断极小化目标函数,极小化的过程不是一次性的把$M$中的所有误分类点梯度下降,而是一次随机选取一个误分类点使其梯度下降。\n", "\n", "假设误分类集合$M$是固定的,那么损失函数$L(w,b)$的梯度为\n", "$$\n", "\\triangledown_w L(w, b) = - \\sum_{x_i \\in M} y_i x_i \\\\\n", "\\triangledown_b L(w, b) = - \\sum_{x_i \\in M} y_i \\\\\n", "$$\n", "\n", "随机选取一个误分类点$(x_i,y_i)$,对$w,b$进行更新:\n", "$$\n", "w = w + \\eta y_i x_i \\\\\n", "b = b + \\eta y_i\n", "$$\n", "\n", "式中$\\eta$(0 ≤ $ \\eta $ ≤ 1)是学习速率(步长)。步长越大,梯度下降的速度越快,更能接近极小点。如果步长过大,有可能导致跨过极小点,导致函数发散;如果步长过小,有可能会耗很长时间才能达到极小点。\n", "\n", "直观解释:当一个实例点被误分类时,调整w,b,使分离超平面向该误分类点的一侧移动,以减少该误分类点与超平面的距离,直至超越该点被正确分类。\n", "\n", "\n", "\n", "### 4.1 算法\n", "\n", "\n", "输入:T={(x1,y1),(x2,y2)...(xN,yN)}(其中xi∈X=Rn,yi∈Y={-1, +1},i=1,2...N,学习速率为η)\n", "\n", "输出:w, b;感知机模型f(x)=sign(w·x+b)\n", "\n", "1. 初始化$w_0$,$b_0$\n", "2. 在训练数据集中选取$(x_i, y_i)$\n", "3. 如果$y_i(w * x_i+b)≤0$\n", " \n", " $w = w + η y_i x_i$\n", " \n", " $b = b + η y_i$\n", "\n", "4. 如果所有的样本都正确分类,或者迭代次数超过设定值,则终止\n", "5. 否则,跳转至(2)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 示例程序\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAEICAYAAABGaK+TAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAcTklEQVR4nO3de5yUdd3/8ddnZnaXPXBm1RBjFU8BIepqnlKT7tQ8Z3k2stQ7szIztdMjtfKuOw/pnaZxI4o/TTM1RX+pGKCmeVpEBFGUBAFFWM6w55353H/MqAgsLDPXzDXXzvv5ePjYmYvZ63qP7r798p3r+l7m7oiISPTEwg4gIiLZUYGLiESUClxEJKJU4CIiEaUCFxGJKBW4iEhEqcBFRCJKBS49jpktMLMWM1tnZqvN7F9m9m0z2+rPu5nVmZmbWaIQWUVyoQKXnuo4d+8NDAV+C1wO3BZuJJFgqcClR3P3Ne4+CTgVGGtmI83sGDObYWZrzWyRmV25wbc8k/m62szWm9mBZjbMzKaa2QozW25md5tZv0K/F5GNqcClJLj7S8Bi4PNAE/B1oB9wDHCBmZ2Yeemhma/93L3G3Z8HDPgNMBj4DLATcGWhsot0RQUupeR9YIC7P+Xus9w95e6vAfcAh3X1Te4+z92fdPc2d28Ert/S60UKRR/USCnZEVhpZp8jPS8+EigHKoC/dvVNZrY9cCPp0Xtv0gOfVXlPK7IVGoFLSTCz/UgX+LPAn4FJwE7u3he4lfQ0CcDmluf8r8z2z7p7H+CsDV4vEhoVuPRoZtbHzI4F7gXucvdZpEfRK9291cz2B87Y4FsagRSwywbbegPrgTVmtiNwaWHSi2yZaT1w6WnMbAGwPdBJuoznAHcBt7p70sy+ClwHDACeBhaQ/tDyrMz3/xK4ACgDjgLWAXcCewDzgP8HXOzuQwr3rkQ2pQIXEYkoTaGIiESUClxEJKJU4CIiEaUCFxGJqIJeyDNo0CCvq6sr5CFFRCJv+vTpy929duPtBS3wuro6GhoaCnlIEZHIM7N3N7ddUygiIhHVnQXuJ5jZMjObvcG2a8zsTTN7zcz+pqU1RUQKrzsj8DtIX422oSeBke4+CngL+EnAuUREZCu2WuDu/gywcqNtk929M/P0BUCXFIuIFFgQc+DfBB4LYD8iIrINcipwM/sZ6QWD7t7Ca843swYza2hsbMzlcCIisoGsC9zMvgEcC5zpW1gRy93HuXu9u9fX1m5yGqNIZLmnSK37PanlXyXVNCHsOFKCsjoP3MyOAi4DDnP35mAjiUREywPQdAfQAuvfxhPDsArdaU0KpzunEd4DPA/sYWaLzexbwE2kF7l/0sxeNbNb85xTpOh4chHQmnmSguR7oeaR0rPVEbi7n76ZzbflIYtIpFjlV/DmuwEH6wUVXwo7kpQY3dRYJEuWqIPaqdC5ABK7YrHqsCNJiVGBi+TAYn2hfK+wY0iJ0looIiIRpQIXEYkoFbiISESpwEVEIkoFLiISUSpwEZGIUoGLiESUClxEJKJ0IY+IZM29FVqfAKuCijGYaUxYSCpwEcmKu+Mrz4bOt8CByuOxvr8KO1ZJ0f8uRSQ7vg46XgdvAVqg9fGwE5UcFbhEmnuS1KqLSH0wgtTyk/DUqrAjlQ6rgdh2QBwoh7LRIQcqPSpwiba2qdD+NNABnXPxpvFhJyoZZjFs4H1QdQ7UXID1uzHsSCVHc+AScalPPvXU5l8meWHx7bA+l4Udo2RpBC7RVjEGyg8EYpDYBas5L+xEIgWjEbhEmlkC638L7o6ZhR1HpKA0ApceQeUtpUgFLiISUZpCESky3vYC3vY0Vr4/1usLYceRIqYCFyki3v4qvup8oDV9x/v+N2EVh4YdS4qUplBEiknHTD4+NbIVb38lzDRS5DQCFykm5QeBxcF7AWAVh4UcSIqZClykiFjZbjDwAWh/CcpGY2XDw44kRUwFLlJkLLErJHYNO4ZEwFbnwM1sgpktM7PZG2wbYGZPmtnbma/98xtTREQ21p0PMe8Ajtpo24+BKe6+GzAl81xERApoqwXu7s8AKzfafAIwMfN4InBisLFERGRrsj2NcHt3X5J5/AGwfVcvNLPzzazBzBoaGxuzPJyIiGws5/PA3d1J31Cpqz8f5+717l5fW1ub6+FERCQj2wJfamafAsh8XRZcJBER6Y5sC3wSMDbzeCzwcDBxRGRDnlpPas3PSa38Bt72YthxpMhs9TxwM7sHOBwYZGaLgSuA3wL3mdm3gHeBU/IZUqRU+dorMjcL7sA7ZkDtNCw2IOxYUiS2WuDufnoXfzQm4CwisrHOfwMd6cdukFwGKnDJ0GJWIsWs+jygF1g1lOkKTfkkXUovUsRilcfgZSMhtSy9NorpV1Y+pp8GkSJniaHA0LBjSBHSFIqISESpwEVEIkoFLiISUSpwEZGIUoGLiESUClxEJKJU4CIiEaUCFxGJKBW4iEhEqcBFIsi9BU+tCTuGhEwFLrKNvGMuqXXX4S2TSN+QqsDHb52GL90fX3YQqbW/K/jxpXhoLRSRbeDJD/CVp4I341RCahVWPXbr3xhkhrW/BNrST5on4jXnYbH+Bc0gxUEjcJFt0fkmH//atEDbs4XPEOsN2IdPgPLCZ5CioAIX2RZlo0j/2lQAldDrywWPYP1+D4ndIfYp6Hs9FqsueAYpDppCEdkGFhsAgyZB6xRIDMMqDip8hsQwbNAjBT9usXJvx1dfDO3/gvIDsH43YlYafyvRCFxkG1l8MFZ9diDl7amVuLcHkKqEtTwIbf8Eb4K256DlgbATFYwKXCQE7ilSqy7El30eX3Yg3vF6ZnsH3jELTzaGnDBCvB348GygFHhrmGkKSgUuEoaOWdD+HNABvg5fd216KmDF1/CVZ+ONX8Tbng87ZTRUfgUSuwCW/lr5tbATFYzmwEXCEKsGT2WexCHWFzpmQvJd8GYAvOlPWMWB4WWMCIvVYIMexr0Ns4qs9+PeClRgZlt9bbHQCFwkBJbYFXpfArFaKNsb6/1ziG0Hnsy8ohzidWFGjJxsy9s9mZ7OWjoabzwE71wYcLL8UYGLhCRWPZbYds8RG/hnLD4offPivtdAYhRUHof1vizsiKWh/cXMdFYKUsvx9TeEnajbNIUiUkRilUdC5ZFhxygt1gs+WhIhnn4eERqBi0hpK9sbqs4Eq4GyEVjNJWEn6racRuBmdjFwLulzeGYB57iX0Dk8IhJ5Zob1uQz6RG/KKusRuJntCHwfqHf3kUAcOC2oYCIismW5TqEkgEozSwBVwPu5RxIRke7IusDd/T3gWmAhsARY4+6TN36dmZ1vZg1m1tDYqKvLRESCkssUSn/gBGBnYDBQbWZnbfw6dx/n7vXuXl9bW5t9UhER+YRcplC+CMx390Z37wAeBAq/NJuISInKpcAXAgeYWZWlrz0dA7wRTCyRLXNvIT1uEClducyBvwjcD7xC+hTCGDAuoFwiXUqtvRZfug++tF4LPklJy+ksFHe/wt33dPeR7n62u7cFFUxkczy5HJpvB5JAC772ynADiYRIV2JKtFgZH98PErCq0KKIhE0FLpFisb7Q59dgAyC+C9bvmrAjiYRGi1lJ5MSqToSqE8OOIRI6jcBFRCJKBS4iElEqcBGRPPLOBXjzX/GOtwPft+bARUTyxDvn4StOTi+4jcPAu7CyUYHtXyNwEZF8afsneCfQArTirdMC3b0KXEQkX8pGkb5VAkAvrHyvQHevKRQRkTyx8n2h/01461Ss4mCs4vBA968CFxHJI6s4FKs4NC/71hSKiEhEqcBFRCJKBS4iElEqcBGRiFKBi4hElApcRCSiVOAiIhGlAhcRiSgVuIhIRKnARUQiSgUuIhJRKnCREuXJD0gtP5HU0v1IrR8XdhzJggpcpEi5t+Od83Fvz8/+1/4SOt8EXwPrb8I7/52X40j+aDVCkSLkyeX4ihPB14PVwMC/YfHagA/SDKTSj83AW4Pdv+SdRuAixaj1YUitSpdsalX6ecCs9+VgfYE4VBwJieGBH0PyK6cRuJn1A8YDI0nf9e2b7v58ALlESltsIOlfz47019igwA9hZZ+B7V4Ab8NiVYHvX/Iv1ymUG4HH3f2rZlYO6KdAJAi9joeOOdA2DSq+kH7eDd65KD1qT+yOmW319WZxMP3aRlXWBW5mfYFDgW8AePqTlvx82iISUZ5aha/9NSQbsd4/wMr36db3mcWwPj8FftrtY6Wa7oV1VwMx6DUG63d9dqElMnKZA98ZaARuN7MZZjbezKo3fpGZnW9mDWbW0NjYmMPhRKLHV18KrY9Bxwv4qm/iqfX5O1jTTUAb0AKtT+Cplfk7lhSFXAo8AewD3OLuewNNwI83fpG7j3P3enevr60N+FN0kWKXXAB0ph97Mv2BZL7EB/PRr7SVp89ekR4tlwJfDCx29xczz+8nXegi8qHqC4Be6Xnm8nqID8nboazf/0D5YVC2L9Z/AumPpaQny3oO3N0/MLNFZraHu88FxgBzgosmEn2xqpPx8npIrYaykd36YHFr3B1ffyO0PgEVn8d6X45ZHIvvgA34U+6hC8i9BVqnQKw/lB8UyL+fUpLrWSjfA+7OnIHyDnBO7pFEehZLDAWGBrfDtieg6XagBZrfg8RuUPW14PZfIO5JfMWpkFwIOFSfi9V8L+xYkZJTgbv7q0B9MFFEpFuSy/joCko68ORSIjluTX0AnQuAzBWgzQ+ACnyb6EpMkaipPBZi/dLz6tYXqzo57ETZiQ0CqwQMqIDy0SEHih6thSISMRYbALX/gM6FkBiCWWXgx0g13Q3rfgNWifW/FSvfN/BjmFXAwPvwpgkQG4TVnBf4MXo6FbhIBJlVQNluedm3p9bDuv8COsDb8TWXY7X/yMuxLDEU63tVXvZdCjSFIiISUSpwEfkEi9VA758B5WD9sL7/HXYk6YKmUERkE7HqM6D6jLBjyFZoBC4iElEqcBGRiFKBi4hElApcRCSiVOAiIhGlAhcRiSgVuIhIRKnARUQiSgUuIhJRKvAupFIpXpkyi5lPv467hx1HRGQTupS+C78563944dHp4M4Xzz6Mi/6opS5FpLhoBL4ZqVSKp+/7F63rW2ltamPyHdPCjiQisgkV+GbEYjG2H1pLLB4jnogzZPfBYUcSEdmEplC6cN20K7njF38hnohzzq9PCzuOiMgmVOBd2O7TtVx2x3fDjiEi0iVNoYiIRJQKXEQkolTgIiIRpQIXEYkoFbiISETlXOBmFjezGWb2aBCBRESke4IYgV8EvBHAfkREZBvkVOBmNgQ4BhgfTBwREemuXEfgNwCXAamuXmBm55tZg5k1NDY25ng4ERH5UNYFbmbHAsvcffqWXufu49y93t3ra2trsz2ciIhsJJcR+MHA8Wa2ALgXOMLM7goklRSN9tZ2lryzlGRnMuwoIrKRrAvc3X/i7kPcvQ44DZjq7mcFlkxCt2T+Us4YegHnjfoh5426hOZ1LWFHEpEN6Dxw6dLDNz3O2hXraGtup3HRcp7720thRxKRDQRS4O7+lLsfG8S+pHj036EvZeUfL1jZt7ZPiGlEZGNaTla6dNL3j2HB7EXMfvZNjjjz8+x31OiwI4nIBlTg0qXyijIun/i9sGOISBc0By4iElEq8BylUimSSZ1iJyKFpwLPwctPvMoJfb7OMVVn8rc//D3sOCJSYlTgObj+3FtobW4j2ZFk3I/upK2lLexIIlJCVOA5SJR9/BmwxQwzCzGNiJQaFXgOfnL39xk4uD/V/ar40YTvUN6rPOxIW7Ro7nv884EXWN24JuwoIhIAnUaYg+EH7sG9i8eFHaNbZj71Oj879jfE4jHKKhKMn/17+m/XN+xYIpIDjcBLxON3TKOtuY2WdS10tHbw6tTZYUcSkRypwEvEnvsNo6KqAoBUyhk6fEjIiUQkV5pCyWhcvILFb73PHvvtSlXvyrDjBO64C44k2ZlizvNvceQ3DmeXUUPDjiQiOTJ3L9jB6uvrvaGhoWDH6645z8/l8i/9ilg8RlWfKsbNvJZJNz/OI7dOZpdRdfzsnouo7lsddkwRKVFmNt3d6zferikU4JFbJtPa1Ebz2haaVjcx6Y9PcO9/P8SK91cxY+osJl5x3ybfs3blOlYty+1sjvbWdh67bQqTJz5FR3tHTvsSkdKjKRSgbuROVFSV09bcTiqVoqp35UfndCc7kqxZvu4Tr///457k5osmAHDaj0/i61ecktVxf3L01cx9eR5gPPvgi/zy4ctzeh8iUlo0Age++sPjOPniY9l7zGe5dMKFHPvt/2DXvXemrKKMmv7VnPnzkz/x+v+9/C462jrpaOvkz1c/QGdH5zYfM5VKMeuZN2hrbqetuY2GyTODejsiUiI0AgfiiTjn/Or0T2y77qmrWLV0Nb0H1FBWXvaJP+vdv5qmNc0AVFRWEE/Et/mYsViMYaPrWPD6IsyMzxywW/ZvQERKkgq8C2bGgB36b/bPrnrocq455ybaWzu5+E/nZ30J/bVTr2DSH58gXhbnuAuOzCWuiJQgnYUiIlLkujoLRSPwEpVKpXj2wRdZv6qJw049iOo+VWFHEpFtpALPg3kz5jPzqdf57KGfYfd9h4UdZ7P++IPbeeL2aXjKeeCGRxn32nXE49s+ly8i4VGBB+yt6f/mh4f9glRnilgixu/+cQXDD9g97FibePbBF2ltSq9f/sH8ZaxauoZBgweEnEpEtoVOIwzYjCmz6WxP0tHeSUdrBzP+MSvsSJs16rARlPcqI56I0Wdgb61MKBJBGoEHbOQhe5Ioi5PsTFLWq4yRn98z7Eibdent32HPz+3K2hXrOPY/v5TVqZAiEi6dhZIHrz0zhxlTZrHX4SMY/YWRXb7ulX+8xtVn3EAqmeLSCRdy0An7FTCliERFV2ehqMAzJt/5FC8/NoNDvnIAh33twIIc8+Tab7J2Rfoy/YrKch5Zf5duyyYimwj8NEIz2wm4E9gecGCcu9+YfcTwPPfQS/zhO+NpbW7j+Uem02dgDXsf8dm8HzeZTH70OJVy3F0FLiLdlsuHmJ3AJe4+HDgAuNDMhgcTq7DmvTqf1swd5ZOdSea/trAgx710woWUV5ZTVpHg4nH/SSymz5RFpPuyHoG7+xJgSebxOjN7A9gRmBNQtoI55KTPcf91j2CxGAbsf8w+BTnuwSfuz6Pr78LdVd4iss0CmQM3szrgGWCku6/d6M/OB84H+PSnP73vu+++m/Px8mHJO0t5a/o7fOaA3dhup0FhxxER+UjePsQ0sxrgaeBqd39wS68t5g8xRUSKVV7uyGNmZcADwN1bK28REQlW1gVu6dMlbgPecPfrg4skIiLdkcsI/GDgbOAIM3s188+XA8olIiJbkctZKM8COmlZRCQkOndNRCSiVOAiIhGlAhcRiaiiL/Cmtc384bvjufIr1/D2K++EHSdQ7s4137yZoypO49yRF7NiyaqwI4lIhBR9gV9zzs38ffwUnnvoJX50xJW0rG8JO1JgXp02m2f++jzJjiSL577PxCv+kvW+5s2Yz0M3Pcb8WcV5pauIBK/ob+gw/7WFdLZ3ApDsSLJq6RoqaypDThUMdyCz+qADnsruqtg5L7zFZV+8Ck85Fotx43O/ZthedYHlFJHiVPQj8JMvPoaKqnIqa3oxbHQdO+y8XdiRAjP6CyM45MT9icVjDB62A2OvOiWr/UyfPJP2lg7aWztIdnbySpHexk1EglX0I/Djv3MUIw7ek9XL1jDqsOGBrNq3bNFy7r/uEar7VXHKpSdQWd0rgKTbLhaLcfmd3+Oyid/NaR3wEQfvSXllGW3N7cQTcUYcVHw3URaR4JXcHXmSySRnDr2AVUvXEE/E2Pc/9uJXk34cyL7XLF/LY7dNpbKmF0efO4byirJA9tsdDZNn8urUWdQfOXqLt3ETkegJ/I48UbV+VRNrlq8jlUyRSqZ486W3c97n6sY1XDrmKhbMXoSZkahIMPPp1/nFfZcEkLh76r+0F/Vf2qtgxxOR8BX9HHjQ+gzszbDRdfSqrqBXdQVHnH5Izvv8y+8eZtGb7wPpUwM7WjuYOW12zvsVEdmSkhuBmxnXTbuS5x56meq+Vex/9N4B7JNPrApTVp5g/6MLc1cfESldJTcCB6ioTI+8P/flfQK5ifCpl53IziM/Tbwszi6jhrLrvruwaO57PP+Ibl4hIvlTkgUetL6D+nDL9N/xeNu97LbvLsx75R3mvvxvrj7t9yxb2Bh2PBHpoVTgAXvvrSV0tKUvPIrFY6xYsjrcQCLSY6nAA3bGz0/+6MKjoSN2Yrd9dg47koj0UCX3IWa+7XfkaO546w+seH8Vu46uI56Ihx1JRHooFXgeDBo8gEGDB4QdQ0R6OE2hFIGWplaSncmwY4hIxKjAQ+Tu3Pid/+Wk/mP5ysBzmPP83LAjiUiEqMBD9P6/P2DyxKdIdqZoXtfCzRfdHnYkEYkQFXiIynuVZxYFB4sZlTXhrIooItGkAg9R7ZCBfPv6sfQd1Ie6ETtxyfgLwo4kIhFScsvJiohETVfLyWoELiISUSpwEZGIyqnAzewoM5trZvPMLJjb2oiISLdkXeBmFgduBo4GhgOnm9nwoIKJiMiW5TIC3x+Y5+7vuHs7cC9wQjCxRERka3Ip8B2BRRs8X5zZ9glmdr6ZNZhZQ2Oj1sYWEQlK3j/EdPdx7l7v7vW1tbX5PpyISMnIZTXC94CdNng+JLOtS9OnT19uZu9uwzEGAcuzyFbMeuJ7Ar2vqNH7ipahm9uY9YU8ZpYA3gLGkC7ul4Ez3P31bBNu5hgNmzt5Pcp64nsCva+o0fvqGbIegbt7p5l9F3gCiAMTgixvERHZspxu6ODufwf+HlAWERHZBsV+Jea4sAPkQU98T6D3FTV6Xz1AQRezEhGR4BT7CFxERLqgAhcRiaiiLPCeuEiWme1kZtPMbI6ZvW5mF4WdKUhmFjezGWb2aNhZgmJm/czsfjN708zeMLMDw84UBDO7OPMzONvM7jGzSN4KyswmmNkyM5u9wbYBZvakmb2d+do/zIz5VnQF3oMXyeoELnH34cABwIU95H196CLgjbBDBOxG4HF33xPYix7w/sxsR+D7QL27jyR9CvBp4abK2h3AURtt+zEwxd13A6ZknvdYRVfg9NBFstx9ibu/knm8jnQZbLJ2TBSZ2RDgGGB82FmCYmZ9gUOB2wDcvd3dV4caKjgJoDJzMV4V8H7IebLi7s8AKzfafAIwMfN4InBiITMVWjEWeLcWyYoyM6sD9gZeDDlKUG4ALgNSIecI0s5AI3B7ZmpovJlVhx0qV+7+HnAtsBBYAqxx98nhpgrU9u6+JPP4A2D7MMPkWzEWeI9mZjXAA8AP3H1t2HlyZWbHAsvcfXrYWQKWAPYBbnH3vYEmesBfxzNzwieQ/h/UYKDazM4KN1V+ePoc6R59nnQxFvg2L5IVFWZWRrq873b3B8POE5CDgePNbAHp6a4jzOyucCMFYjGw2N0//FvS/aQLPeq+CMx390Z37wAeBA4KOVOQlprZpwAyX5eFnCevirHAXwZ2M7Odzayc9Acsk0LOlDMzM9LzqW+4+/Vh5wmKu//E3Ye4ex3p/1ZT3T3yIzp3/wBYZGZ7ZDaNAeaEGCkoC4EDzKwq8zM5hh7w4ewGJgFjM4/HAg+HmCXvcloLJR968CJZBwNnA7PM7NXMtp9m1pOR4vQ94O7MQOId4JyQ8+TM3V80s/uBV0ifGTWDiF5+bmb3AIcDg8xsMXAF8FvgPjP7FvAucEp4CfNPl9KLiERUMU6hiIhIN6jARUQiSgUuIhJRKnARkYhSgYuIRJQKXEQkolTgIiIR9X8UiHOiYTe1HwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# data generation\n", "np.random.seed(314)\n", "\n", "data_size1 = 20\n", "x1 = np.random.randn(data_size1, 2) + np.array([2,2])\n", "y1 = [-1 for _ in range(data_size1)]\n", "\n", "data_size2 = 20\n", "x2 = np.random.randn(data_size2, 2)*2 + np.array([8,8])\n", "y2 = [1 for _ in range(data_size2)]\n", "\n", "# all sample data\n", "x = np.concatenate((x1, x2), axis=0)\n", "y = np.concatenate((y1, y2), axis=0)\n", "\n", "shuffled_index = np.random.permutation(data_size1 + data_size2)\n", "x = x[shuffled_index]\n", "y = y[shuffled_index]\n", "\n", "\n", "train_data = np.concatenate((x, y[:, np.newaxis]), axis=1)\n", "\n", "# plot data\n", "plt.scatter(train_data[:,0], train_data[:,1], c=train_data[:,2], marker='.')\n", "plt.title(\"Data\")\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "lines_to_end_of_cell_marker": 2 }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "update weight and bias: 3.7185416425430744 4.415441255489154 0.5\n", "update weight and bias: 2.63445626503829 3.324758332173668 0.0\n", "update weight and bias: 2.0207439860270506 1.893264739036497 -0.5\n", "update weight and bias: 1.2322166573021474 0.22799125733467895 -1.0\n", "update weight and bias: -0.6006428863030957 -1.6411028433624717 -1.5\n", "update weight and bias: 5.109204567082854 1.9813789454818795 -1.0\n", "update weight and bias: 4.489662780423084 -0.3506475821384758 -1.5\n", "update weight and bias: 3.3939105634317013 -1.0412656669082738 -2.0\n", "update weight and bias: 2.504805392463895 -2.232127455753232 -2.5\n", "update weight and bias: 6.223347035006969 2.183313799735922 -2.0\n", "update weight and bias: 5.140304315636703 0.7923315584800992 -2.5\n", "update weight and bias: 4.5265920366254635 -0.6391620346570717 -3.0\n", "update weight and bias: 3.6672220497312384 -1.1731993200750104 -3.5\n", "update weight and bias: 2.0065425403755413 -2.3921089617658904 -4.0\n", "update weight and bias: 4.533466246969885 1.4498630856120154 -3.5\n", "update weight and bias: 3.4493808694651005 0.3591801622965294 -4.0\n", "update weight and bias: 2.5900108825708754 -0.17485712312140933 -4.5\n", "update weight and bias: 0.9293313732151782 -1.3937667648122891 -5.0\n", "update weight and bias: 4.046427755076163 3.0054356400349302 -4.5\n", "update weight and bias: 3.2956963801872607 1.6232728971525268 -5.0\n", "update weight and bias: 2.6761545935274915 -0.7087536304678286 -5.5\n", "update weight and bias: 1.0154750841717943 -1.9276632721587084 -6.0\n", "update weight and bias: 5.37894106506809 3.555613771026151 -5.5\n", "update weight and bias: 4.7593992784083206 1.2235872434057957 -6.0\n", "update weight and bias: 4.139857491748551 -1.1084392842145596 -6.5\n", "update weight and bias: 2.479177982392854 -2.327348925905439 -7.0\n", "update weight and bias: 5.809083413002281 0.646001016876383 -6.5\n", "update weight and bias: 5.195371133991041 -0.7854925762607878 -7.0\n", "update weight and bias: 4.306265963023234 -1.9763543651057456 -7.5\n", "update weight and bias: 2.88012341642224 -2.622818885634949 -8.0\n", "update weight and bias: 6.331829427476793 1.6208587585142347 -7.5\n", "update weight and bias: 4.905686880875798 0.9743942379850314 -8.0\n", "update weight and bias: 4.371515476010486 -0.46893620150732485 -8.5\n", "update weight and bias: 2.7108359666547885 -1.6878458431982046 -9.0\n", "update weight and bias: 6.254961478724868 2.5307472883846627 -8.5\n", "update weight and bias: 4.663848346587878 0.9710594278881524 -9.0\n", "update weight and bias: 4.044306559928109 -1.360967099732203 -9.5\n", "update weight and bias: 2.6181640133271147 -2.007431620261406 -10.0\n", "update weight and bias: 5.145087719921459 1.8345404271164996 -9.5\n", "update weight and bias: 4.061002342416674 0.7438575038010136 -10.0\n", "w = [4.061002342416674, 0.7438575038010136]\n", "b = -10.0\n", "\n", "\n", "ground_truth: [ 1. -1. 1. 1. -1. -1. -1. 1. -1. 1. 1. -1. 1. -1. -1. -1. 1. 1.\n", " -1. -1. -1. -1. -1. -1. 1. 1. 1. 1. 1. -1. -1. 1. 1. -1. 1. -1.\n", " 1. 1. 1. -1.]\n", "predicted: [ 1. -1. 1. 1. 1. -1. 1. 1. -1. 1. 1. -1. 1. -1. -1. -1. 1. 1.\n", " 1. -1. 1. -1. -1. -1. 1. 1. 1. 1. 1. 1. 1. 1. 1. -1. 1. -1.\n", " 1. 1. 1. -1.]\n" ] } ], "source": [ "import random\n", "import numpy as np\n", "\n", "# 符号函数\n", "def sign(v):\n", " if v > 0: return 1\n", " else: return -1\n", " \n", "def perceptron_train(train_data, eta=0.5, n_iter=100):\n", " weight = [0, 0] # 权重\n", " bias = 0 # 偏置量\n", " learning_rate = eta # 学习速率\n", "\n", " train_num = n_iter # 迭代次数\n", "\n", " for i in range(train_num):\n", " # select one data\n", " ti = np.random.randint(len(train_data))\n", " (x1, x2, y) = train_data[ti]\n", " \n", " y_pred = sign(weight[0] * x1 + weight[1] * x2 + bias) \n", " \n", " if y * y_pred <= 0: # 判断误分类点\n", " weight[0] = weight[0] + learning_rate * y * x1 # 更新权重\n", " weight[1] = weight[1] + learning_rate * y * x2\n", " bias = bias + learning_rate * y # 更新偏置量\n", " print(\"update weight and bias: \", weight[0], weight[1], bias)\n", "\n", " return weight, bias\n", "\n", "def perceptron_pred(data, w, b):\n", " y_pred = []\n", " for d in data:\n", " x1, x2, y = d\n", " yi = sign(w[0]*x1 + w[1]*x2 + b)\n", " y_pred.append(yi)\n", " \n", " return np.array(y_pred, dtype=float)\n", "\n", "\n", "# do training\n", "w, b = perceptron_train(train_data)\n", "print(\"w = \", w)\n", "print(\"b = \", b)\n", "\n", "# predict \n", "y_pred = perceptron_pred(train_data, w, b)\n", "\n", "print(\"\\n\")\n", "print(\"ground_truth: \", train_data[:, 2])\n", "print(\"predicted: \", y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reference\n", "* [感知机(Python实现)](http://www.cnblogs.com/kaituorensheng/p/3561091.html)\n", "* [Programming a Perceptron in Python](https://blog.dbrgn.ch/2013/3/26/perceptrons-in-python/)\n", "* [损失函数、风险函数、经验风险最小化、结构风险最小化](https://blog.csdn.net/zhzhx1204/article/details/70163099)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 2 }