{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# kNN Classification\n", "\n", "\n", "The k-Nearest Neighbor (kNN) classification algorithm is a theoretically mature method and one of the simplest machine learning algorithms. Its idea is: ***if the majority of the k samples most similar to a given sample (i.e., its nearest neighbors in feature space) belong to a certain class, then that sample also belongs to this class***.\n", "\n", "Although kNN in principle also relies on [limit theorems](https://baike.baidu.com/item/%E6%9E%81%E9%99%90%E5%AE%9A%E7%90%86/13672616), its class decision involves only a very small number of neighboring samples. Because kNN decides the class mainly from a limited set of nearby samples rather than by discriminating between class regions, it is better suited than other methods to sample sets whose class regions cross or overlap heavily.\n", "\n", "kNN can be used for regression as well as classification: find the `k` nearest neighbors of a sample and assign the average of their target values to that sample. A more useful variant weights neighbors by distance, for example with weights inversely proportional to distance, so that closer neighbors have more influence on the result.\n", "\n", "kNN is arguably the most direct method for classifying unknown data. The figure below and the accompanying explanation are essentially enough to understand what kNN does:\n", "![knn](images/knn.png)\n", "\n", "In short, kNN works like this: **you have a collection of data whose classes are already known. When a new data point arrives, compute its distance to every point in the training data, pick the k training points closest to it, look at which classes those points belong to, and assign the new point to the majority class**.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Problems with this algorithm:\n", "1. When the classes are imbalanced, e.g. one class has many samples while the others have few, the k neighbors of a new sample may be dominated by the large class, which can cause misclassification. The influence of class size on the result should therefore be reduced; one improvement is to weight the neighbors (neighbors closer to the sample get larger weights).\n", "2. The computational cost is high, because for every sample to be classified we must compute its distance to all known samples to find its k nearest neighbors. A common remedy is to edit (prune) the known samples beforehand, removing those that contribute little to classification. The algorithm works better for automatic classification of classes with large sample sizes; classes with small sample sizes are more prone to misclassification.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Algorithm Steps\n", "\n", "Input:\n", "* Training data: $T=\\{(x_1,y_1),(x_2,y_2), ..., (x_N,y_N)\\}$, where $x_i \\in X=\\mathbb{R}^n$, $y_i \\in Y = \\{0, 1, ..., C-1\\}$ (C classes), $i=1,2,...,N$\n", "* User input: $x_u$\n", "\n", "Output: the predicted class $y_{pred}$\n", "\n", "\n", "1. Prepare the data;\n", "2. Compute the **distance** between the test sample and each training sample;\n", "3. Sort by increasing distance;\n", "4. Select the `k` points with the smallest distances;\n", "5. Count how often each class occurs among these `k` points;\n", "6. Return the most frequent class among the `k` points as the predicted class of the test sample.\n", "\n", "\n", "\n", "**Further thinking:**\n", "* What are the difficult parts of the procedure above?\n", "* How can each step be expressed in a programming language?\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1.1 Distance Computation\n", "\n", "There are several ways to measure the distance between points in a space, such as the common Manhattan distance and Euclidean distance; kNN usually uses the Euclidean distance. Taking the two-dimensional plane as a brief example, the Euclidean distance between two points is:\n", "$$\n", "d = \\sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}\n", "$$\n", "\n", "In two dimensions this is the distance between $(x_1,y_1)$ and $(x_2, y_2)$. Extended to $n$ dimensions, the distance between points $p$ and $q$ becomes:\n", "$$\n", "d(p, q) = \\sqrt{ (p_1-q_1)^2 + (p_2-q_2)^2 + \\cdots + (p_n-q_n)^2 } = \\sqrt{ \\sum_{i=1}^{n} (p_i-q_i)^2}\n", "$$\n", "\n", "The simplest brute-force form of kNN computes the distance from the `query point` to `all points`, stores and sorts these distances, then takes the first k and checks which class is most common among them." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 2. A Way of Thinking About Machine Learning\n", "\n", "Working through kNN from principle to algorithm to implementation yields a general way of thinking about machine learning: given a problem, how to reason about it and solve it.\n", "\n", "![machine learning - methodology](images/ml_methodology.png)\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The figure above shows the classic machine learning workflow:\n", "* Problem: what problem do we need to solve?\n", "* Core idea: by what means do we solve it?\n", "* Mathematical theory: how do we build the mathematical model, and which mathematical methods do we use?\n", "* Algorithm: how do we turn the mathematical theory and processing flow into a procedure a computer can carry out?\n", "* Programming: how do we turn the algorithm into a program the computer can execute?\n", "* Testing: how do we use training and test data to validate the algorithm?\n", "* Further thinking: what can the chosen method achieve, what problems does it have, and how can it be improved?\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Generate Data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import matplotlib as mpl\n", "\n", "# Generate synthetic data\n", "np.random.seed(314)\n", "\n", "# Class 0: 100 points drawn from a 2-D standard normal distribution shifted to (4, 4)\n", "data_size1 = 100\n", "x1 = np.random.randn(data_size1, 2) + np.array([4,4])\n", "y1 = [0 for _ in range(data_size1)]\n", "\n", "# Class 1: 100 points with standard deviation 2, shifted to (10, 10)\n", "data_size2 = 100\n", "x2 = np.random.randn(data_size2, 2)*2 + np.array([10,10])\n", "y2 = [1 for _ in range(data_size2)]\n", "\n", "\n", "# Merge both classes into one data set\n", "x = np.concatenate((x1, x2), axis=0)\n", "y = np.concatenate((y1, y2), axis=0)\n", "\n", "# Shuffle samples and labels with the same permutation\n", "data_size_all = data_size1 + data_size2\n", "shuffled_index = np.random.permutation(data_size_all)\n", "x = x[shuffled_index]\n", "y = y[shuffled_index]\n", "\n", "# Split into training (70%) and test (30%) data\n", "split_index = int(data_size_all*0.7)\n", "x_train = x[:split_index]\n", "y_train = y[:split_index]\n", "x_test = x[split_index:]\n", "y_test = y[split_index:]\n", "\n", "\n", "# Plot the training data (class 0: red dots, class 1: blue triangles)\n", "for i in range(split_index):\n", "    if y_train[i] == 0:\n", "        plt.scatter(x_train[i,0],x_train[i,1], s=38, c = 'r', marker='.')\n", "    else:\n", "        plt.scatter(x_train[i,0],x_train[i,1], s=38, c = 'b', marker='^')\n", "#plt.rcParams['figure.figsize']=(12.0, 8.0)\n", "plt.title(\"Training Data\")\n", "plt.savefig(\"fig-res-knn-traindata.pdf\")\n", "plt.show()\n", "\n", "# Plot the test data\n", "for i in range(data_size_all - split_index):\n", "    if y_test[i] == 0:\n", "        plt.scatter(x_test[i,0],x_test[i,1], s=38, c = 'r', marker='.')\n", "    else:\n", "        plt.scatter(x_test[i,0],x_test[i,1], s=38, c = 'b', marker='^')\n", "#plt.rcParams['figure.figsize']=(12.0, 8.0)\n", "plt.title(\"Test Data\")\n", "plt.savefig(\"fig-res-knn-testdata.pdf\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. The Simplest Implementation" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0]\n" ] } ], "source": [ "import numpy as np\n", "import operator\n", "\n", "def knn_distance(v1, v2):\n", "    \"\"\"Squared Euclidean distance between two feature vectors.\n", "    The square root is omitted because it does not change the neighbor ranking.\"\"\"\n", "    return np.sum(np.square(v1-v2))\n", "\n", "def knn_vote(ys):\n", "    \"\"\"Return the most frequent class among the labels ys (majority vote)\"\"\"\n", "    vote_dict = {}\n", "    for y in ys:\n", "        if y not in vote_dict.keys():\n", "            vote_dict[y] = 1\n", "        else:\n", "            vote_dict[y] += 1\n", "    \n", "    method = 1\n", "    \n", "    # Method 1 - sort the classes by vote count in descending order\n", "    if method == 1:\n", "        sorted_vote_dict = sorted(vote_dict.items(),\n", "                                  key=lambda item: item[1],  # or: key=operator.itemgetter(1)\n", "                                  reverse=True)\n", "        return sorted_vote_dict[0][0]\n", "    \n", "    # Method 2 - loop over the classes and keep the one with the most votes\n", "    if method == 2:\n", "        maxv = maxk = 0\n", "        for y in np.unique(ys):\n", "            if maxv < vote_dict[y]:\n", "                maxv = vote_dict[y]\n", "                maxk = y\n", "        return maxk\n", "    \n", "def knn_predict(x, train_x, train_y, k=3):\n", "    \"\"\"\n", "    Classify a single sample\n", "    Parameters\n", "        x       - the sample to classify\n", "        train_x - training samples (features)\n", "        train_y - training labels\n", "        k       - number of nearest neighbors\n", "    \"\"\"\n", "    dist_arr = [knn_distance(x, train_x[j]) for j in range(len(train_x))]\n", "    sorted_index = np.argsort(dist_arr)\n", "    top_k_index = sorted_index[:k]\n", "    ys = train_y[top_k_index]\n", "    return knn_vote(ys)\n", "    \n", "\n", "# Classify every training sample\n", "# (each training sample is its own nearest neighbor at distance 0)\n", "y_train_est = [knn_predict(x_train[i], x_train, y_train, k=5) for i in range(len(x_train))]\n", "print(y_train_est)" ] }, { "cell_type": "code", 
"execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Plot the predicted classes of the training data\n", "for i in range(len(y_train_est)):\n", "    if y_train_est[i] == 0:\n", "        plt.scatter(x_train[i,0],x_train[i,1], s=38, c = 'r', marker='.')\n", "    else:\n", "        plt.scatter(x_train[i,0],x_train[i,1], s=38, c = 'b', marker='^')\n", "#plt.rcParams['figure.figsize']=(12.0, 8.0)\n", "plt.title(\"Train Results\")\n", "plt.savefig(\"fig-res-knn-train-res.pdf\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Accuracy: 100.000000%\n" ] } ], "source": [ "# Compute the accuracy on the training data\n", "n_correct = 0\n", "for i in range(len(x_train)):\n", "    if y_train_est[i] == y_train[i]:\n", "        n_correct += 1\n", "accuracy = n_correct / len(x_train) * 100.0\n", "print(\"Train Accuracy: %f%%\" % accuracy)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Accuracy: 96.666667%\n", "58 60\n" ] } ], "source": [ "# Compute the accuracy on the test data (note: k=3 here, while the training estimate above used k=5)\n", "y_test_est = [knn_predict(x_test[i], x_train, y_train, 3) for i in range(len(x_test))]\n", "n_correct = 0\n", "for i in range(len(x_test)):\n", "    if y_test_est[i] == y_test[i]:\n", "        n_correct += 1\n", "accuracy = n_correct / len(x_test) * 100.0\n", "print(\"Test Accuracy: %f%%\" % accuracy)\n", "print(n_correct, len(x_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Implementing kNN as a Class" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np\n", "import operator\n", "\n", "class KNN(object):\n", "    def __init__(self, k=3):\n", "        \"\"\"Constructor. Parameters:\n", "        k - number of nearest neighbors\"\"\"\n", "        self.k = k\n", "\n", "    def fit(self, x, y):\n", "        \"\"\"Fit the given data (kNN simply stores it). Parameters:\n", "        x - sample features; y - sample labels\"\"\"\n", "        self.x = x\n", "        self.y = y\n", "        return self\n", "\n", "    def _square_distance(self, v1, v2):\n", "        \"\"\"Squared Euclidean distance between two samples in feature space. Parameters:\n", "        v1 - sample 1; v2 - sample 2\"\"\"\n", "        return np.sum(np.square(v1-v2))\n", "\n", "    def _vote(self, ys):\n", "        \"\"\"Majority vote. Parameters:\n", "        ys - classes of the k nearest neighbors\"\"\"\n", "        vote_dict = {}\n", "        for y in ys:\n", "            if y not in vote_dict.keys():\n", "                vote_dict[y] = 1\n", "            else:\n", "                vote_dict[y] += 1\n", "        sorted_vote_dict = sorted(vote_dict.items(), key=operator.itemgetter(1), reverse=True)\n", "        return sorted_vote_dict[0][0]\n", "\n", "    def predict(self, x):\n", "        \"\"\"Predict the class of every sample (row) in x\"\"\"\n", "        y_pred = []\n", "        for i in range(len(x)):\n", "            dist_arr = [self._square_distance(x[i], self.x[j]) for j in range(len(self.x))]\n", "            sorted_index = np.argsort(dist_arr)\n", "            top_k_index = sorted_index[:self.k]\n", "            y_pred.append(self._vote(ys=self.y[top_k_index]))\n", "        return np.array(y_pred)\n", "\n", "    def score(self, y_true=None, y_pred=None):\n", "        \"\"\"Fraction of correct predictions; if no labels are given,\n", "        evaluate on the stored training data\"\"\"\n", "        if y_true is None and y_pred is None:\n", "            y_pred = self.predict(self.x)\n", "            y_true = self.y\n", "        score = 0.0\n", "        for i in range(len(y_true)):\n", "            if y_true[i] == y_pred[i]:\n", "                score += 1\n", "        score /= len(y_true)\n", "        return score" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train accuracy: 100.000000 %\n", "test accuracy: 96.666667 %\n" ] } ], "source": [ "# Data preprocessing (optional): min-max scaling.\n", "# Compute the scaling statistics on the training data only and apply them to both sets.\n", "#x_min, x_max = np.min(x_train, axis=0), np.max(x_train, axis=0)\n", "#x_train = (x_train - x_min) / (x_max - x_min)\n", "#x_test = (x_test - x_min) / (x_max - x_min)\n", "\n", "# kNN classifier\n", "clf = KNN(k=3)\n", "train_acc = clf.fit(x_train, y_train).score() * 100.0\n", "\n", "y_test_pred = clf.predict(x_test)\n", "test_acc = clf.score(y_test, y_test_pred) * 100.0\n", "\n", "print('train accuracy: %f %%' % train_acc)\n", "print('test accuracy: %f %%' % test_acc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Using sklearn" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Feature dimensions: (1797, 64)\n", "Label dimensions: (1797,)\n" ] } ], "source": [ "#% matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "from sklearn import datasets, neighbors, linear_model\n", "\n", "# load data\n", "digits = datasets.load_digits()\n", "X_digits = digits.data\n", "y_digits = digits.target\n", "\n", "print(\"Feature dimensions: \", X_digits.shape)\n", "print(\"Label dimensions: \", y_digits.shape)\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# plot sample images\n", "nplot = 10\n", "fig, axes = plt.subplots(nrows=1, ncols=nplot)\n", "\n", "for i in range(nplot):\n", "    img = X_digits[i].reshape(8, 8)\n", "    axes[i].imshow(img)\n", "    axes[i].set_title(y_digits[i])\n", "fig.set_size_inches(16,9)\n", "fig.savefig('fig-res-digits.pdf')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# split train / test data (first 40% for training, the rest for testing)\n", "n_samples = len(X_digits)\n", "n_train = int(0.4 * n_samples)\n", "\n", "X_train = X_digits[:n_train]\n", "y_train = y_digits[:n_train]\n", "X_test = X_digits[n_train:]\n", "y_test = y_digits[n_train:]\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KNN score: 0.953661\n", "LogisticRegression score: 0.927711\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/bushuhui/anaconda3/envs/dl/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py:765: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" ] } ], "source": [ "# KNN classification (sklearn default n_neighbors=5) compared with logistic regression\n", "knn = neighbors.KNeighborsClassifier()\n", "logistic = linear_model.LogisticRegression()\n", "\n", "print('KNN score: %f' % knn.fit(X_train, y_train).score(X_test, y_test))\n", "print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Further Thinking\n", "\n", "* If there is a very large amount of input data, how can the distances be computed quickly?\n", "  - [kd-tree](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html#sklearn.neighbors.KDTree)\n", "  - Fast Library for Approximate Nearest Neighbors (FLANN)\n", "  - [PyNNDescent for fast Approximate Nearest Neighbors](https://pynndescent.readthedocs.io/en/latest/)\n", "* How do we choose the best `k`?\n", "  - https://zhuanlan.zhihu.com/p/143092725\n", "* What problems does kNN have?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "* [Digits Classification Exercise](http://scikit-learn.org/stable/auto_examples/exercises/plot_digits_classification_exercise.html)\n", "* [Principles and implementation of the kNN algorithm (in Chinese)](https://zhuanlan.zhihu.com/p/36549000)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.4" } }, "nbformat": 4, "nbformat_minor": 2 }