
knn_classification.ipynb 126 kB


  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# kNN Classification\n",
  8. "\n",
  9. "\n",
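  10. "The k-nearest neighbor (kNN) classification algorithm is a theoretically mature method and one of the simplest machine learning algorithms. Its idea is: ***if most of the k samples most similar to a given sample (that is, its k nearest neighbors in feature space) belong to one class, then the sample is assigned to that class as well***. Although kNN relies in principle on limit theorems, each class decision involves only a very small number of neighboring samples. Because kNN decides class membership from a limited set of nearby samples rather than by discriminating class regions, it tends to be better suited than other methods to sample sets whose class regions cross or overlap heavily.\n",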
  11. "\n",
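  12. "kNN can be used for regression as well as classification. Find the k nearest neighbors of a sample and assign the average of their attribute values to that sample. A more useful variant gives neighbors at different distances different weights, for example a weight inversely proportional to the distance, so that closer neighbors contribute more.\n",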
  13. "\n",
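  14. "A main weakness in classification is class imbalance: when one class is much larger than the others, the k neighbors of a new sample may be dominated by the large class simply because it has more samples. The decision should depend on which samples are actually close to the target sample, not on how many samples a class has, so the vote can be improved by weighting it (neighbors at a smaller distance get a larger weight). Another weakness is the computational cost: for every sample to be classified, the distance to all known samples must be computed before its k nearest neighbors can be found. A common remedy is to edit the known samples in advance and remove those that contribute little to classification. The algorithm is better suited to automatic classification of classes with relatively many samples; classes with few samples are more prone to misclassification.\n",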
  15. "\n",
  16. "kNN is arguably the most direct way to classify unknown data. The figure and the explanation below show what kNN does.\n",
  17. "![knn](images/knn.png)\n",
  18. "\n",
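  19. "In short, kNN can be summarized as: **given a set of data whose classes are already known, when a new data point arrives, compute its distance to every point in the training data, pick the k training points nearest to the new point, look at which classes they belong to, and assign the new point to a class by majority vote**.\n"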
  20. ]
  21. },
  22. {
  23. "cell_type": "markdown",
  24. "metadata": {},
  25. "source": [
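  26. "## Algorithm steps (FIXME)\n",
  27. "\n",
  28. "* step.1---Initialize the distance to the maximum value\n",
  29. "* step.2---Compute the distance dist between the unknown sample and each training sample\n",
  30. "* step.3---Get the maximum distance maxdist among the current k nearest samples\n",
  31. "* step.4---If dist is less than maxdist, take this training sample as one of the k nearest neighbors\n",
  32. "* step.5---Repeat steps 2, 3, and 4 until the distances between the unknown sample and all training samples have been computed\n",
  33. "* step.6---Count how often each class label occurs among the k nearest neighbors\n",
  34. "* step.7---Choose the most frequent class label as the class label of the unknown sample"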
  35. ]
  36. },
  37. {
  38. "cell_type": "markdown",
  39. "metadata": {},
  40. "source": [
  41. "## Generate sample data (FIXME)"
  42. ]
  43. },
  44. {
  45. "cell_type": "code",
  46. "execution_count": 1,
  47. "metadata": {},
  48. "outputs": [
  49. {
  50. "data": {
  51. "image/png": "\n",
  52. "text/plain": [
  53. "<Figure size 432x288 with 1 Axes>"
  54. ]
  55. },
  56. "metadata": {
  57. "needs_background": "light"
  58. },
  59. "output_type": "display_data"
  60. }
  61. ],
  62. "source": [
  63. "%matplotlib inline\n",
  64. "import numpy as np\n",
  65. "import matplotlib.pyplot as plt\n",
  66. "\n",
  67. "# generate sample data\n",
  68. "n = 100\n",
  69. "x_1_1 = 10 + (np.random.rand(n, 1)*2 -1)*4\n",
  70. "x_1_2 = 15 + (np.random.rand(n, 1)*2 -1)*4\n",
  71. "x1 = np.concatenate((x_1_1, x_1_2), axis=1)\n",
  72. "y1 = np.zeros([n, 1])\n",
  73. "\n",
  74. "x_2_1 = 20 + (np.random.rand(n, 1)*2 -1)*4\n",
  75. "x_2_2 = 5 + (np.random.rand(n, 1)*2 -1)*4\n",
  76. "x2 = np.concatenate((x_2_1, x_2_2), axis=1)\n",
  77. "y2 = np.ones([n, 1])\n",
  78. "\n",
  79. "x = np.concatenate((x1, x2), axis=0)\n",
  80. "y = np.concatenate((y1, y2), axis=0)\n",
  81. "y = y.flatten()\n",
  82. "\n",
  83. "# draw sample data\n",
  84. "plt.scatter(x[:,0], x[:,1], c=y)\n",
  85. "plt.show()\n",
  86. "\n"
  87. ]
  88. },
  89. {
  90. "cell_type": "code",
  91. "execution_count": 2,
  92. "metadata": {},
  93. "outputs": [
  94. {
  95. "name": "stdout",
  96. "output_type": "stream",
  97. "text": [
  98. "[0.0, 0.0, 0.0, 0.0, 0.0]\n",
  99. "[1.0, 1.0, 1.0, 1.0, 1.0]\n"
  100. ]
  101. }
  102. ],
  103. "source": [
  104. "# generate test data\n",
  105. "x_test = np.array([[12.5, 10.0], [15.4, 8.0]])\n",
  106. "\n",
  107. "k = 5\n",
  108. "# do knn\n",
  109. "for s in x_test:\n",
  110. "    d = np.sum((s - x)**2, axis=1)  # squared Euclidean distance to every sample\n",
  111. "    idx = np.argsort(d)\n",
  112. "    ys_k = list(y[idx[:k]])  # labels of the k nearest samples\n",
  113. "    print(ys_k)\n",
  114. "\n",
  115. "    # TODO: you need to implement the vote algorithm"
  116. ]
  117. },
  118. {
  119. "cell_type": "markdown",
  120. "metadata": {},
  121. "source": [
  122. "## Program"
  123. ]
  124. },
  125. {
  126. "cell_type": "code",
  127. "execution_count": 5,
  128. "metadata": {},
  129. "outputs": [],
  130. "source": [
  131. "import numpy as np\n",
  132. "import operator\n",
  133. "\n",
  134. "class KNN(object):\n",
  135. "\n",
  136. "    def __init__(self, k=3):\n",
  137. "        self.k = k\n",
  138. "\n",
  139. "    def fit(self, x, y):\n",
  140. "        self.x = x\n",
  141. "        self.y = y\n",
  142. "\n",
  143. "    def _square_distance(self, v1, v2):\n",
  144. "        return np.sum(np.square(v1-v2))\n",
  145. "\n",
  146. "    def _vote(self, ys):\n",
  147. "        # count how often each label occurs among the k neighbors\n",
  148. "        vote_dict = {}\n",
  149. "        for y in ys:\n",
  150. "            if y not in vote_dict:\n",
  151. "                vote_dict[y] = 1\n",
  152. "            else:\n",
  153. "                vote_dict[y] += 1\n",
  154. "        sorted_vote_dict = sorted(vote_dict.items(), key=operator.itemgetter(1), reverse=True)\n",
  155. "        return sorted_vote_dict[0][0]\n",
  156. "\n",
  157. "    def predict(self, x):\n",
  158. "        y_pred = []\n",
  159. "        for i in range(len(x)):\n",
  160. "            dist_arr = [self._square_distance(x[i], self.x[j]) for j in range(len(self.x))]\n",
  161. "            sorted_index = np.argsort(dist_arr)\n",
  162. "            top_k_index = sorted_index[:self.k]\n",
  163. "            y_pred.append(self._vote(ys=self.y[top_k_index]))\n",
  164. "        return np.array(y_pred)\n",
  165. "\n",
  166. "    def score(self, y_true=None, y_pred=None):\n",
  167. "        if y_true is None and y_pred is None:\n",
  168. "            y_pred = self.predict(self.x)\n",
  169. "            y_true = self.y\n",
  170. "        score = 0.0\n",
  171. "        for i in range(len(y_true)):\n",
  172. "            if y_true[i] == y_pred[i]:\n",
  173. "                score += 1\n",
  174. "        score /= len(y_true)\n",
  175. "        return score"
  176. ]
  177. },
  178. {
  179. "cell_type": "code",
  180. "execution_count": 3,
  181. "metadata": {},
  182. "outputs": [
  183. {
  184. "data": {
  185. "image/png": "\n",
  186. "text/plain": [
  187. "<Figure size 432x288 with 1 Axes>"
  188. ]
  189. },
  190. "metadata": {
  191. "needs_background": "light"
  192. },
  193. "output_type": "display_data"
  194. },
  195. {
  196. "data": {
  197. "image/png": "\n",
  198. "text/plain": [
  199. "<Figure size 432x288 with 1 Axes>"
  200. ]
  201. },
  202. "metadata": {
  203. "needs_background": "light"
  204. },
  205. "output_type": "display_data"
  206. }
  207. ],
  208. "source": [
  209. "%matplotlib inline\n",
  210. "\n",
  211. "import numpy as np\n",
  212. "import matplotlib.pyplot as plt\n",
  213. "\n",
  214. "# data generation\n",
  215. "np.random.seed(314)\n",
  216. "data_size_1 = 300\n",
  217. "x1_1 = np.random.normal(loc=5.0, scale=1.0, size=data_size_1)\n",
  218. "x2_1 = np.random.normal(loc=4.0, scale=1.0, size=data_size_1)\n",
  219. "y_1 = [0 for _ in range(data_size_1)]\n",
  220. "\n",
  221. "data_size_2 = 400\n",
  222. "x1_2 = np.random.normal(loc=10.0, scale=2.0, size=data_size_2)\n",
  223. "x2_2 = np.random.normal(loc=8.0, scale=2.0, size=data_size_2)\n",
  224. "y_2 = [1 for _ in range(data_size_2)]\n",
  225. "\n",
  226. "x1 = np.concatenate((x1_1, x1_2), axis=0)\n",
  227. "x2 = np.concatenate((x2_1, x2_2), axis=0)\n",
  228. "x = np.hstack((x1.reshape(-1,1), x2.reshape(-1,1)))\n",
  229. "y = np.concatenate((y_1, y_2), axis=0)\n",
  230. "\n",
  231. "data_size_all = data_size_1+data_size_2\n",
  232. "shuffled_index = np.random.permutation(data_size_all)\n",
  233. "x = x[shuffled_index]\n",
  234. "y = y[shuffled_index]\n",
  235. "\n",
  236. "split_index = int(data_size_all*0.7)\n",
  237. "x_train = x[:split_index]\n",
  238. "y_train = y[:split_index]\n",
  239. "x_test = x[split_index:]\n",
  240. "y_test = y[split_index:]\n",
  241. "\n",
  242. "# visualize data\n",
  243. "plt.scatter(x_train[:,0], x_train[:,1], c=y_train, marker='.')\n",
  244. "plt.title(\"train data\")\n",
  245. "plt.show()\n",
  246. "plt.scatter(x_test[:,0], x_test[:,1], c=y_test, marker='.')\n",
  247. "plt.title(\"test data\")\n",
  248. "plt.show()\n",
  249. "\n"
  250. ]
  251. },
  252. {
  253. "cell_type": "code",
  254. "execution_count": 6,
  255. "metadata": {},
  256. "outputs": [
  257. {
  258. "name": "stdout",
  259. "output_type": "stream",
  260. "text": [
  261. "train accuracy: 0.986\n",
  262. "test accuracy: 0.957\n"
  263. ]
  264. }
  265. ],
  266. "source": [
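  267. "# data preprocessing: min-max scaling (here each set is scaled with its own min/max; in practice the test set is usually scaled with the training min/max)\n",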
  268. "x_train = (x_train - np.min(x_train, axis=0)) / (np.max(x_train, axis=0) - np.min(x_train, axis=0))\n",
  269. "x_test = (x_test - np.min(x_test, axis=0)) / (np.max(x_test, axis=0) - np.min(x_test, axis=0))\n",
  270. "\n",
  271. "# knn classifier\n",
  272. "clf = KNN(k=3)\n",
  273. "clf.fit(x_train, y_train)\n",
  274. "\n",
  275. "print('train accuracy: {:.3}'.format(clf.score()))\n",
  276. "\n",
  277. "y_test_pred = clf.predict(x_test)\n",
  278. "print('test accuracy: {:.3}'.format(clf.score(y_test, y_test_pred)))"
  279. ]
  280. },
  281. {
  282. "cell_type": "markdown",
  283. "metadata": {},
  284. "source": [
  285. "## sklearn program"
  286. ]
  287. },
  288. {
  289. "cell_type": "code",
  290. "execution_count": 7,
  291. "metadata": {},
  292. "outputs": [
  293. {
  294. "name": "stdout",
  295. "output_type": "stream",
  296. "text": [
  297. "Feature dimensions: (1797, 64)\n",
  298. "Label dimensions: (1797,)\n"
  299. ]
  300. }
  301. ],
  302. "source": [
  303. "%matplotlib inline\n",
  304. "\n",
  305. "import matplotlib.pyplot as plt\n",
  306. "from sklearn import datasets, neighbors, linear_model\n",
  307. "\n",
  308. "# load data\n",
  309. "digits = datasets.load_digits()\n",
  310. "X_digits = digits.data\n",
  311. "y_digits = digits.target\n",
  312. "\n",
  313. "print(\"Feature dimensions: \", X_digits.shape)\n",
  314. "print(\"Label dimensions: \", y_digits.shape)\n"
  315. ]
  316. },
  317. {
  318. "cell_type": "code",
  319. "execution_count": 8,
  320. "metadata": {},
  321. "outputs": [
  322. {
  323. "data": {
  324. "image/png": "\n",
  325. "text/plain": [
  326. "<Figure size 432x288 with 10 Axes>"
  327. ]
  328. },
  329. "metadata": {
  330. "needs_background": "light"
  331. },
  332. "output_type": "display_data"
  333. }
  334. ],
  335. "source": [
  336. "# plot sample images\n",
  337. "nplot = 10\n",
  338. "fig, axes = plt.subplots(nrows=1, ncols=nplot)\n",
  339. "\n",
  340. "for i in range(nplot):\n",
  341. "    img = X_digits[i].reshape(8, 8)\n",
  342. "    axes[i].imshow(img)\n",
  343. "    axes[i].set_title(y_digits[i])\n"
  344. ]
  345. },
  346. {
  347. "cell_type": "code",
  348. "execution_count": 9,
  349. "metadata": {},
  350. "outputs": [],
  351. "source": [
  352. "# split train / test data\n",
  353. "n_samples = len(X_digits)\n",
  354. "n_train = int(0.4 * n_samples)\n",
  355. "\n",
  356. "X_train = X_digits[:n_train]\n",
  357. "y_train = y_digits[:n_train]\n",
  358. "X_test = X_digits[n_train:]\n",
  359. "y_test = y_digits[n_train:]\n"
  360. ]
  361. },
  362. {
  363. "cell_type": "code",
  364. "execution_count": 12,
  365. "metadata": {},
  366. "outputs": [
  367. {
  368. "name": "stdout",
  369. "output_type": "stream",
  370. "text": [
  371. "KNN score: 0.953661\n",
  372. "LogisticRegression score: 0.908248\n"
  373. ]
  374. }
  375. ],
  376. "source": [
  377. "# do KNN classification\n",
  378. "knn = neighbors.KNeighborsClassifier()\n",
  379. "logistic = linear_model.LogisticRegression()\n",
  380. "\n",
  381. "print('KNN score: %f' % knn.fit(X_train, y_train).score(X_test, y_test))\n",
  382. "print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))"
  383. ]
  384. },
  385. {
  386. "cell_type": "markdown",
  387. "metadata": {},
  388. "source": [
  389. "## References\n",
  390. "* [Digits Classification Exercise](http://scikit-learn.org/stable/auto_examples/exercises/plot_digits_classification_exercise.html)\n",
  391. "* [Principles and implementation of the kNN algorithm](https://zhuanlan.zhihu.com/p/36549000)"
  392. ]
  393. }
  394. ],
  395. "metadata": {
  396. "kernelspec": {
  397. "display_name": "Python 3",
  398. "language": "python",
  399. "name": "python3"
  400. },
  401. "language_info": {
  402. "codemirror_mode": {
  403. "name": "ipython",
  404. "version": 3
  405. },
  406. "file_extension": ".py",
  407. "mimetype": "text/x-python",
  408. "name": "python",
  409. "nbconvert_exporter": "python",
  410. "pygments_lexer": "ipython3",
  411. "version": "3.5.2"
  412. }
  413. },
  414. "nbformat": 4,
  415. "nbformat_minor": 2
  416. }
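
The notebook cell that builds `x_test` ends with a TODO for the vote step. A minimal sketch of one way to finish it follows, using `collections.Counter` for the majority vote over the labels of the k nearest samples. The function name `knn_predict` and the seeded sample data are assumptions made for illustration and are not part of the notebook.

```python
from collections import Counter

import numpy as np


def knn_predict(x_train, y_train, x_query, k=5):
    """Classify each row of x_query by majority vote among its k nearest training points."""
    predictions = []
    for s in x_query:
        # Squared Euclidean distance from the query point to every training point.
        d = np.sum((s - x_train) ** 2, axis=1)
        # Indices of the k smallest distances.
        nearest = np.argsort(d)[:k]
        # Majority vote over the labels of those k neighbors.
        votes = Counter(y_train[nearest].tolist())
        predictions.append(votes.most_common(1)[0][0])
    return np.array(predictions)


# Two separated clusters, similar in spirit to the notebook's sample data
# (the seed and the generator are assumptions, chosen only for reproducibility).
rng = np.random.default_rng(0)
x0 = rng.uniform(-4, 4, size=(100, 2)) + np.array([10.0, 15.0])
x1 = rng.uniform(-4, 4, size=(100, 2)) + np.array([20.0, 5.0])
x = np.concatenate((x0, x1), axis=0)
y = np.concatenate((np.zeros(100), np.ones(100)))

x_test = np.array([[12.5, 10.0], [15.4, 8.0]])
print(knn_predict(x, y, x_test, k=5))
```

Sorting all distances with `np.argsort` keeps the sketch close to the notebook's own loop; for large training sets `np.argpartition` would avoid the full sort.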

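
The seven steps in the algorithm-steps cell (marked FIXME) can also be followed literally: scan the training set once and keep only the k nearest samples seen so far. The sketch below is one possible reading of those steps, not the notebook's own implementation. It assumes squared Euclidean distance and uses `heapq` with negated distances as a bounded max-heap; `knn_classify_streaming` is a hypothetical name.

```python
import heapq
from collections import Counter

import numpy as np


def knn_classify_streaming(x_train, y_train, query, k=3):
    """Scan the training set once, keeping the k nearest samples, then vote."""
    # heapq is a min-heap, so distances are stored negated: the root is then
    # the farthest of the current k candidates (the maxdist of step 3).
    heap = []  # entries are (-dist, index, label)
    for i, (xi, yi) in enumerate(zip(x_train, y_train)):
        dist = float(np.sum((xi - query) ** 2))  # step 2
        if len(heap) < k:
            # Fewer than k candidates so far: accept unconditionally.
            heapq.heappush(heap, (-dist, i, yi))
        elif dist < -heap[0][0]:  # steps 3 and 4
            # Closer than the current farthest candidate: replace it.
            heapq.heapreplace(heap, (-dist, i, yi))
    # Steps 6 and 7: count the labels and return the most frequent one.
    votes = Counter(label for _, _, label in heap)
    return votes.most_common(1)[0][0]


x_train = np.array([[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]], dtype=float)
y_train = np.array([0, 0, 0, 1, 1, 1])
print(knn_classify_streaming(x_train, y_train, np.array([0.5, 0.5]), k=3))  # prints 0
```

Keeping a heap of size k costs O(n log k) per query instead of the O(n log n) of a full sort, which matters only when the training set is large relative to k.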
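Machine learning is increasingly applied in fields such as aircraft and robotics, with the goal of using computers to achieve human-like intelligence and thereby make equipment intelligent and unmanned. This course aims to guide students in mastering the basic knowledge, typical methods, and techniques of machine learning, to spark their interest in the subject through concrete application cases, and to encourage them to analyze and solve the problems and challenges faced by aircraft and robots from the perspective of artificial intelligence. The course covers Python programming basics, machine learning models, and the fundamentals and implementation of unsupervised learning, supervised learning, and deep learning. It also teaches how to apply machine learning to practical problems, so that students improve their overall abilities.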