{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# KNN Classification\n", "\n", "Classify scikit-learn's handwritten digit images (8x8 pixels, flattened to 64 features) with k-nearest neighbors and compare its test score against logistic regression.\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Feature dimensions: (1797, 64)\n", "Label dimensions: (1797,)\n" ] } ], "source": [ "# line magic: no space after % (newer IPython releases reject the spaced form)\n", "%matplotlib inline\n", "\n", "import matplotlib.pyplot as plt\n", "from sklearn import datasets, neighbors, linear_model\n", "\n", "# load the digits dataset bundled with scikit-learn: one row of features per image, one label per row\n", "digits = datasets.load_digits()\n", "X_digits = digits.data\n", "y_digits = digits.target\n", "\n", "# sanity-check that features and labels have the same number of samples\n", "print(\"Feature dimensions: \", X_digits.shape)\n", "print(\"Label dimensions: \", y_digits.shape)\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAABLCAYAAABQtG2+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFI5JREFUeJztnXmcFNW1x79nNgZmYAQGB9kExBEhUVRC1ERxeUZM3guo+USjiXlGJYGHL0bNxjMfSWIkLyaicSGSIHGLS94n6Iu7LwqK4jIRA0EZIovsy7DOvvV5f1RPV912ehime7q6M+f7+fRn7u1bXfc3t27dqjp17j2iqhiGYRjZQ07YAgzDMIzDwwZuwzCMLMMGbsMwjCzDBm7DMIwswwZuwzCMLMMGbsMwjCzDBm7DMIwsIyMGbhEZICKLRaRWRD4SkctC0DBLRCpEpFFEfp/u+gM6eonIwmg7VIvIeyJyQUhaHhaR7SJyUETWisjVYegI6DlWRBpE5OGQ6l8Srb8m+qkMQ0dUy6Ui8kH0nFknImekuf6auE+riNyVTg0BLSNF5FkR2SciO0TkbhHJC0HH8SLysogcEJEPReTC7qorIwZu4B6gCSgDLgfmi8j4NGvYBtwC3J/meuPJAzYDk4ES4CbgCREZGYKWucBIVe0HfBG4RUROCUFHG/cA74RYP8AsVS2Ofo4LQ4CInAf8N3Al0Bc4E1ifTg2BNigGBgP1wB/TqSHAvcAu4ChgAt65MzOdAqIXiqeAp4EBwHTgYREp7476Qh+4RaQIuBj4karWqOoy4H+Br6VTh6r+SVWfBPaks952dNSq6hxV3aiqEVV9GtgApH3AVNXVqtrYlo1+jkm3DvDuMIH9wF/CqD/D+DHwE1V9M9pHtqrq1hD1XIw3cL4WUv2jgCdUtUFVdwDPA+m+8RsLDAHmqWqrqr4MvE43jWOhD9xAOdCiqmsD3/2N9Dd8RiIiZXhttDqk+u8VkTpgDbAdeDYEDf2AnwDXp7vudpgrIlUi8rqInJXuykUkF5gIDIo+jm+JmgZ6p1tLgK8DD2p462fcAVwqIn1EZChwAd7gHTYCfKI7dpwJA3cxcDDuuwN4j4A9GhHJBx4BHlDVNWFoUNWZeMfiDOBPQGPHv+gWfgosVNUtIdQd5PvAaGAosAD4s4ik+wmkDMgHvoR3TCYAJ+GZ1NKOiBy
NZ5p4IIz6o7yKd6N3ENgCVABPpllDJd5Tx3dFJF9EPofXLn26o7JMGLhrgH5x3/UDqkPQkjGISA7wEJ7tf1aYWqKPfsuAYcCMdNYtIhOAfwHmpbPe9lDVt1S1WlUbVfUBvEfhz6dZRn30712qul1Vq4DbQ9DRxteAZaq6IYzKo+fJ83g3FUVAKdAf7x1A2lDVZmAa8AVgB3AD8ATehSTlZMLAvRbIE5FjA9+dSEimgUxARARYiHd3dXG0U2QCeaTfxn0WMBLYJCI7gBuBi0Xk3TTraA/FexxOX4Wq+/AGg6BZIswlPq8g3LvtAcAI4O7oBXUPsIgQLmSqulJVJ6vqQFU9H+/p7O3uqCv0gVtVa/Gulj8RkSIR+QwwFe9uM22ISJ6IFAK5QK6IFIbhUhRlPnA88G+qWn+ojbsDETky6nJWLCK5InI+8BXS/3JwAd7FYkL08xvgGeD8dIoQkSNE5Py2fiEil+N5c4RhS10EXBs9Rv2B7+B5M6QVETkdz2wUljcJ0SeODcCM6HE5As/mvjLdWkTkhGj/6CMiN+J5ufy+WypT1dA/eFfNJ4FaYBNwWQga5uB7TrR95oSg4+ho3Q14ZqS2z+Vp1jEIWIrnyXEQWAVckwF9ZQ7wcAj1DsJzRayOtsmbwHkhtUE+ngvcfrzH8l8DhSHouA94KAP6xARgCbAPqMIzUZSFoOO2qIYa4DlgTHfVJdEKDcMwjCwhdFOJYRiGcXjYwG0YhpFl2MBtGIaRZXRq4BaRKSJSGZ2p9YPuFmU6TIfpMB3/rDpSwSFfTkan2K4FzsPzH30H+Iqqvp/oNwXSSwsparespdT9fvDgvbH01tojnLLCLb77sqpS07KXPhQj5FBHNYUUkUsuDdTSpI0f86ftSMfHth3rX8N65bQ4Zft3+pM4VZWGvdu7TUfkCH+7kcN3OmU7mv15SqrKvsr9KdPRNNT9/hMDd8fSeyO5TtmeSn/b7j4ukud7ZEZGu/cZsrbJ14FSy8GU6Qj2B4Da5oJYOn9dQ0K9qdbRka74flr9vl+Wah1NQ9zvNdAlSvu6c+WOyvPbR1VZVdnEyFF55OVB5Wqld24/ciWP+tZqmiL1h6WjcaQ7EXF4sT9+bD4w0Ckr3O5P8lVValpT10+1vMDJB49F05pIu785FIl0tEdn/JQnAR+q6noAEXkMz8864cBdSBGflnPbLau6+DQn/90bHoulf/TXqU5Z+fXbY+l9TTv4x55lnBxdvXJDdAb4KBnLW9q+a3FHOuIZ8oA/OB/bZ5dT9uTt58TSNbs2svuZJ7pNR905n46lF95xu1M2d/uUWHr3ql28dXVFynRsuNY9Lm9/fX4s/Vh1f6fsocmTYunuPi65pUfG0vX3ustxFJz3USy9X/ewnvdTpiPYHwDe3joilh52ceK5YanW0ZGu+H669AS/fVKtY9M3T3fyTSX+4HTVua84ZbNL/dVul1fU8/3bdvO7h71B9YvjvHYcXXwKy6vad//uSMfamyc6+V+c4Y8fNzz9VafsuJ/7Cybub9rBP/a+nrL2aLr3aCc/sq9/Adl2atcmfSfS0R6dMZUMxVtmtI0t0e8cRGS6eOtZVzR3w3IWja01FOJ3zEJ608jH56Z0t47mugMZoaNuV11G6MiU49JIvenIQB1bd7QyeIh/e16YW0xDpDbtOhoitRnRHqkiZS8nVXWBqk5U1Yn59ErVbk2H6TAdpqPH6TgUnTGVbAWGB/LDot91iaBpBODSvvti6TuOqHHKnnn3hVh6eUU9F8woouoL3iN9w31r6EXqVrLcWD0gll40wl1W+Ldn+sFFGocU0fyKf6VuoD4pHZHJJzn51+65L5ZeG7dCydSBK2LpyjG1rCI5HWvn+yaPuee4x+UTd/rr0P/92/c6ZXedMTKWbq6ChpffS0pHR2yYMSaWbvq7azscg28q6UVvGpJsjyDBtoa4PrHN3fbJ2uJYuvLdWn755dTp2PfvrgnrhRG+CeuYx7/llI3hzVg
61e0RT8EB/57vuZvPcspemjk2lj5Yv42dG5Yyd7u3QkHdgXcoAFrrd6Hq2ug7w1njEgcd+tW/uoGRnjrNP7dyVuWx+erk2iN3vB8345XxjyfeMK5/3FrlxtsImrS6SmfuuN8BjhWRUSJSAFyKF+ggrXxqQiGNB6poPLiHSGsLO9nMII5KtwwKRg2jnhrqtZaIRkLTMeaEPhmho3jA8IzQ0Y/+GaEjU45LprRH3+MGU735INXbqmltbg1Nx8DjSzOiPVLFIe+4VbVFRGYBL+AtwHS/qqZ95b68PGHYZy9i/bMLUFWGM4xiKUm3DCQ3l+OYwApeQ1GGMDIUHbl5khE6JCcz2iNHcjhOw9eRKcclU9pDcnOYdOOp/N9/vohGlDJGhNMeeTkZcVxSRadWv1PVZ0ki8knLOX7UrUv7vueUXTDl0li6ZKUbK+DLy9w3ui1Tj6Fs6vcAGDUjudUS400U95XfHci5LkD9VrmuP6VyFKUpulqvn+ba0YKPVQv/crZTtu6S3zj5+TImKR1j5/vxKx768SSn7Kalj8bS8V4lxX98y82nsD1yy4508l+7yH/T/vgitz8EH10ByjiOMs4CoHV1cnF83693379PK/L3t7bZfbn2Xysvd/JHD95NGSd6Ona6nh+Hy7TrX05YNvrJjl+epbKfjpjzRsKyD+ed6uSvKnPP42W/KOd0vNCLrZJceyx53z3mb5ck9va56yN34carLrqeCUwDoM9itw93hubSxDERrtzkm1ODHkgAPzvhKSe/lDEki82cNAzDyDJs4DYMw8gybOA2DMPIMtIS4aVhoF/NTbs+6ZRFViaOgfvOqtRGydo0x5/99dSVtzll5fmJpxwPfXGPk29Noabg7C6Axzf5dtznrnM1nr36MidfEHCH6wpO258w1ikLuml+eb1rW84b7Hablh3u1PxkCLr/AdxRsjiWXjrPdaP64H53Fl3OAV/XmO8kp+OlnW57BGcDxveVyCr3JVfrztS9ux/X2/W8Db4DyVm6In7zlFJ3oT+Ld9uZiWdiP3fRrzrcz+OX+f1n8LzkbNxjHnDPvpcefSSWvvLNM5yy95vKnHzftftj6a6cw/lrEntB75zq981JT21yysYVxJ8fZuM2DMPocdjAbRiGkWWkx1TS378+PLLcnQlW3kEQ5LySJiffcqAgwZadI+jSdN38C52yZ1e8mPB38W5AyV7tgi5vlT8Y7ZRddW7ihWZ6f9VdWyGVJpt4k9UXTvZj8Z70fNxUsLjwuCumDImlu2I2Cc4O/GC6O0tz/PLpsfQwXBPEhim/c/In3jaTVBFcwArgjAu/GUtXneiulhiv+Xh8HR250XWG+Mfsp/b4bqyb5rhmx1F/jDPpJekSGTQtjJjproh4X/kfEv7uquuud/KDFyfXBkEaBiQeA+JnPH/+vEucfLLtEXTtjJ8NGRw/Rj1/tVP2w6PcEyboxtpVTXbHbRiGkWXYwG0YhpFldMpUIiIbgWq8p/MWVZ3Y8S+6h62z55JT2AtyhH3a0Ok1jFPNMn2WXPIQBCGnx+tYsutB8iQfQUAjoemo/N1PycnvheTk8JHWh6Zj/byfklPg6dgeoo5M6R+mI/Ucjo37bFWt6kolhfv8Vd0+9cl1TtmBoJjBrvvOJeP+6uRvi8DQq2aQW1TMqB8u74qULrHr5LjVvJbAKUymQLq27OMHc/0psRum/CbhdpNm3+jk++/8+P+cjI6OCNqqgzZsgD33u0EGmkuWMHD2teQWF1HehaUIeh3w+0f8dPLVp/nuXreudO2K8eTWtHBa6TQKcnonPbU6nuAU6VI+3cGWoLnKoJuuIbdvEeXfqEiq3v85cLKTD9pxb73I/R9nT3ftpUUjCzjplJkUFBR1yXUwaH8tOM8tK9/mu0ROmj3DKeu/OLX9NLg8RXD1THBXSCwc4QYwuPxRt+2XnpzPpGO+QUFen6Tt3fEr/L0y+cpYunypW+/593/byY+8w48uFd+uncVMJYZhGFl
GZ++4FXhRRBS4T1UXxG8gItOB6QCFJF6MJSlE2LZwAYiQr0cyTEa3s0kadAAreA0UhjLadAjsuvO3IEIfLQtNhwhU7P0zgjBEh4fYHsKuXy4MvT0Q4b2VixCEoTq4x/dTASo2/iHaP4aE2h7J0tmB+7OqulVEjgReEpE1qvpqcIPoYL4AoJ8M6DgCcRcZ+q1Z5JWU0FJTzZZb5lGkfekvg5xt0qFjImdTKL1p0gbe5bUer6Psxpnk9S+h9WANW753Z2g6Jg24kMLcYhpb66jYvTi89pj9Lb89vn1XaDpOOekaevUqoamphvfeuLfH99NJo6+gML8fjS21VKxZFJqOVNDZZV23Rv/uEpHFeAGEX+34Vz79Kn1L9s3DnnbKrpju+3zmT9tNRxz7cz8+cQ5DOMhe+jOog190D4Xi2bcKpJBBevg6gtN2b53o2m2DU6vfvnW+U3b25W4w5fpHhsRinAxatC6p9ghGwwEY8rI/xTnohw/w4Dg3iPG0/TOAJvJKChjUheMStB9fu/gzTlnQvnnPg3c7ZUEfb4BhVatppY486JKOIPGRZ4J2+DHfTxgnG4CRy9qiKZWwPUkdD/3JfYEWtGPHT8v/Usm7Tn7rJW3zBXox6I3kdKyNW15gbfPrsXTpc+57q/j5BcmeL8Gp5vHvQIJLRjSPdZfinf2oa8deOKNtmeT+DLouteNH8B1CfFu9cO6dTj7o597VZSsOaeMWkSIR6duWBj4H/L1LtSVBU10LLerF8mrVFvaykyLSvxB6pLkxI3S01jfT2uxNimhtbgxNR11dhEi9ty50pKEpvOPS0JQRx6W2LuIfl5bw2qOuLkKkwdMRaQyvf7RqZpy3tRnSHqmiM3fcZcBiEWnb/g+q+nzHP0k9tXsbqeAtUFCUwQynVAanWwYttTVUsCR0HU37aql8xrsDVY0wlKNC0bG3KsK2Ob/1dLRGGBGSjpb9tRlxXHbubmXVq95MSo1EGBJSe+zZHWH7r+/xMpHwjksjDaxkeUYcl0xoj1TRmdBl6yEazqOLBKdTXzL/Bqfsphv8SCt3rHMfC9+ZEJxa3I9TpYu+M+0QH5nk7NW+GeKV8W7EipbP+qaeHPI4dVFyOoKPVR25FbXctNctC+oaD6Nu992MRiXpdpa/353Gfe0tjyXYEqa94bp/nbk5ENUo8SJyXSK/qi6Wjl+Vb8DDxYFcMaNT2D92n+lGao6fXh9k/HI3As7pBwNT85Nsj1HzP3TzI/zp1PGP4N9c664eeUa5HwA7Z2dyKwleM9GdTv7Vm31X1fbcVNvoI8WcSnLHJXiuxv+Pr6zwz4l4M0r8aprnRPwlI5J1F403hwSDGE/u47bVf1wxy8n3WXr40XfiMXdAwzCMLMMGbsMwjCzDBm7DMIwsQ1RT76ooIruBWqBLU+TjKO3Efo5W1Y/59ZiOjNbxUSf3YTpMxz+Djs5oaVdHu6hqt3yAikzYj+nITB22D9tHT9pHKvejqmYqMQzDyDZs4DYMw8gyunPg/thCVCHtx3Sk9vep3I/tw/bRU/aRyv10z8tJwzAMo/swU4lhGEaWYQO3YRhGltEtA7eITBGRShH5UER+kMR+NorIKhF5T0QOezEO02E6TIfpyHYd7ZIqv8KAr2IusA4YDRQAfwPGdXFfG4FS02E6TIfp6Ik6En264457EvChqq5X1SbgMWDqIX7THZgO02E6TEe262iX7hi4hwKbA/kt0e+6Qlusy79GY8GZDtNhOkxHT9LRLp2NORkWh4x1aTpMh+kwHT1NR3fccW8Fhgfyw6LfHTYaiHUJtMW6NB2mw3SYjp6iI+FOU/rBu4tfD4zCN+qP78J+ioC+gfQbwBTTYTpMh+noKToSfVJuKlHVFhGZBbyA92b2flVdfYiftUdSsS5Nh+kwHaYj23Ukwqa8G4ZhZBk2c9IwDCPLsIHbMAwjy7CB2zAMI8uwgdswDCPLsIHbMAwjy7CB2zAMI8uwgdswDCPL+H+2ihC0591
JagAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# plot sample images\n", "nplot = 10\n", "fig, axes = plt.subplots(nrows=1, ncols=nplot)\n", "\n", "for i in range(nplot):\n", " img = X_digits[i].reshape(8, 8)\n", " axes[i].imshow(img)\n", " axes[i].set_title(y_digits[i])\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# split train / test data\n", "n_samples = len(X_digits)\n", "n_train = int(0.4 * n_samples)\n", "\n", "X_train = X_digits[:n_train]\n", "y_train = y_digits[:n_train]\n", "X_test = X_digits[n_train:]\n", "y_test = y_digits[n_train:]\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "KNN score: 0.953661\n", "LogisticRegression score: 0.908248\n" ] } ], "source": [ "# do KNN classification\n", "knn = neighbors.KNeighborsClassifier()\n", "logistic = linear_model.LogisticRegression()\n", "\n", "print('KNN score: %f' % knn.fit(X_train, y_train).score(X_test, y_test))\n", "print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "* [Digits Classification Exercise](http://scikit-learn.org/stable/auto_examples/exercises/plot_digits_classification_exercise.html)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }