You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

PCA_and_Logistic_Regression.ipynb 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Chaining a PCA and a logistic regression"
  8. ]
  9. },
  10. {
  11. "cell_type": "markdown",
  12. "metadata": {},
  13. "source": [
  14. "The PCA does an unsupervised dimensionality reduction, while the logistic regression does the prediction.\n",
  15. "\n",
  16. "We use a GridSearchCV to set the dimensionality of the PCA"
  17. ]
  18. },
  19. {
  20. "cell_type": "code",
  21. "execution_count": 3,
  22. "metadata": {},
  23. "outputs": [
  24. {
  25. "data": {
  26. "image/png": "\n",
  27. "text/plain": [
  28. "<Figure size 288x216 with 1 Axes>"
  29. ]
  30. },
  31. "metadata": {
  32. "needs_background": "light"
  33. },
  34. "output_type": "display_data"
  35. }
  36. ],
  37. "source": [
  38. "% matplotlib inline\n",
  39. "\n",
  40. "import numpy as np\n",
  41. "import matplotlib.pyplot as plt\n",
  42. "\n",
  43. "from sklearn import linear_model, decomposition, datasets\n",
  44. "from sklearn.pipeline import Pipeline\n",
  45. "from sklearn.model_selection import GridSearchCV\n",
  46. "\n",
  47. "logistic = linear_model.LogisticRegression()\n",
  48. "\n",
  49. "pca = decomposition.PCA()\n",
  50. "pipe = Pipeline(steps=[('pca', pca), ('logistic', logistic)])\n",
  51. "\n",
  52. "digits = datasets.load_digits()\n",
  53. "X_digits = digits.data\n",
  54. "y_digits = digits.target\n",
  55. "\n",
  56. "# Plot the PCA spectrum\n",
  57. "pca.fit(X_digits)\n",
  58. "\n",
  59. "plt.figure(1, figsize=(4, 3))\n",
  60. "plt.clf()\n",
  61. "plt.axes([.2, .2, .7, .7])\n",
  62. "plt.plot(pca.explained_variance_, linewidth=2)\n",
  63. "plt.axis('tight')\n",
  64. "plt.xlabel('n_components')\n",
  65. "plt.ylabel('explained_variance_')\n",
  66. "\n",
  67. "# Prediction\n",
  68. "n_components = [20, 40, 64]\n",
  69. "Cs = np.logspace(-4, 4, 3)\n",
  70. "\n",
  71. "# Parameters of pipelines can be set using ‘__’ separated parameter names:\n",
  72. "estimator = GridSearchCV(pipe,\n",
  73. " dict(pca__n_components=n_components,\n",
  74. " logistic__C=Cs))\n",
  75. "estimator.fit(X_digits, y_digits)\n",
  76. "\n",
  77. "plt.axvline(estimator.best_estimator_.named_steps['pca'].n_components,\n",
  78. " linestyle=':', label='n_components chosen')\n",
  79. "plt.legend(prop=dict(size=12))\n",
  80. "plt.show()"
  81. ]
  82. },
  83. {
  84. "cell_type": "code",
  85. "execution_count": 12,
  86. "metadata": {},
  87. "outputs": [
  88. {
  89. "name": "stdout",
  90. "output_type": "stream",
  91. "text": [
  92. "(1797, 64)\n"
  93. ]
  94. },
  95. {
  96. "data": {
  97. "text/plain": [
  98. "<Figure size 432x288 with 0 Axes>"
  99. ]
  100. },
  101. "metadata": {},
  102. "output_type": "display_data"
  103. },
  104. {
  105. "data": {
  106. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAECCAYAAADesWqHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAC8tJREFUeJzt3X+o1fUdx/HXazetlpK2WoRGZgwhguUPZFHEphm2wv2zRKFgsaF/bJFsULZ/Rv/1V7Q/RiBWCzKjawkjtpaSEUGr3Wu2TG2UGCnVLTTM/lCy9/44X4eJ637v3f187jnn/XzAwXO9x/P63Ht9ne/3e+73nLcjQgBy+c5kLwBAfRQfSIjiAwlRfCAhig8kRPGBhLqi+LaX237X9nu21xfOesz2iO3dJXNOy7vc9g7be2y/Y/uewnnn2X7D9ltN3gMl85rMAdtv2n6+dFaTd8D227Z32R4qnDXD9hbb+2zvtX1dwax5zdd06nLU9roiYRExqRdJA5LelzRX0lRJb0m6umDejZIWSNpd6eu7TNKC5vp0Sf8u/PVZ0rTm+hRJr0v6UeGv8beSnpL0fKXv6QFJF1fKekLSr5rrUyXNqJQ7IOljSVeUuP9u2OIvlvReROyPiBOSnpb0s1JhEfGKpMOl7v8seR9FxM7m+heS9kqaVTAvIuJY8+GU5lLsLC3bsyXdKmljqYzJYvtCdTYUj0pSRJyIiM8rxS+V9H5EfFDizruh+LMkfXjaxwdVsBiTyfYcSfPV2QqXzBmwvUvSiKRtEVEy72FJ90r6umDGmULSi7aHba8pmHOlpE8lPd4cymy0fUHBvNOtkrS51J13Q/FTsD1N0rOS1kXE0ZJZEXEyIq6VNFvSYtvXlMixfZukkYgYLnH/3+KGiFgg6RZJv7Z9Y6Gcc9Q5LHwkIuZL+lJS0eegJMn2VEkrJA2WyuiG4h+SdPlpH89u/q5v2J6iTuk3RcRztXKb3dIdkpYXirhe0grbB9Q5RFti+8lCWf8VEYeaP0ckbVXncLGEg5IOnrbHtEWdB4LSbpG0MyI+KRXQDcX/p6Qf2L6yeaRbJekvk7ymCWPb6hwj7o2IhyrkXWJ7RnP9fEnLJO0rkRUR90fE7IiYo87P7aWIuKNE1im2L7A9/dR1STdLKvIbmoj4WNKHtuc1f7VU0p4SWWdYrYK7+VJnV2ZSRcRXtn8j6e/qPJP5WES8UyrP9mZJP5Z0se2Dkv4QEY+WylNnq3inpLeb425J+n1E/LVQ3mWSnrA9oM4D+zMRUeXXbJVcKmlr5/FU50h6KiJeKJh3t6RNzUZpv6S7CmadejBbJmlt0ZzmVwcAEumGXX0AlVF8ICGKDyRE8YGEKD6QUFcVv/Dpl5OWRR553ZbXVcWXVPObW/UHSR553ZTXbcUHUEGRE3hs9/VZQTNnzhzzvzl+/LjOPffcceXNmjX2FysePnxYF1100bjyjh4d+2uIjh07pmnTpo0r79Chsb80IyLUnL03ZidPnhzXv+sVETHqN2bST9ntRTfddFPVvAcffLBq3vbt26vmrV9f/AVv33DkyJGqed2IXX0gIYoPJETxgYQoPpAQxQcSovhAQhQfSIjiAwm1Kn7NEVcAyhu1+M2bNv5Jnbf8vVrSattXl14YgHLabPGrjrgCUF6b4qcZcQVkMWEv0mneOKD2a5YBjEOb4rcacRURGyRtkPr/ZblAr2uzq9/XI66AjEbd4tcecQWgvFbH+M2ct1Kz3gBUxpl7QEIUH0iI4gMJUXwgIYoPJETxgYQoPpAQxQcSYpLOONSebDN37tyqeeMZEfb/OHz4cNW8lStXVs0bHBysmtcGW3wgIYoPJETxgYQoPpAQxQcSovhAQhQfSIjiAwlRfCAhig8k1GaE1mO2R2zvrrEgAOW12eL/WdLywusAUNGoxY+IVyTVfRUFgKI4xgcSYnYekNCEFZ/ZeUDvYFcfSKj
Nr/M2S3pN0jzbB23/svyyAJTUZmjm6hoLAVAPu/pAQhQfSIjiAwlRfCAhig8kRPGBhCg+kBDFBxLqi9l5CxcurJpXe5bdVVddVTVv//79VfO2bdtWNa/2/xdm5wHoChQfSIjiAwlRfCAhig8kRPGBhCg+kBDFBxKi+EBCFB9IqM2bbV5ue4ftPbbfsX1PjYUBKKfNufpfSfpdROy0PV3SsO1tEbGn8NoAFNJmdt5HEbGzuf6FpL2SZpVeGIByxnSMb3uOpPmSXi+xGAB1tH5Zru1pkp6VtC4ijp7l88zOA3pEq+LbnqJO6TdFxHNnuw2z84De0eZZfUt6VNLeiHio/JIAlNbmGP96SXdKWmJ7V3P5aeF1ASiozey8VyW5wloAVMKZe0BCFB9IiOIDCVF8ICGKDyRE8YGEKD6QEMUHEuqL2XkzZ86smjc8PFw1r/Ysu9pqfz/BFh9IieIDCVF8ICGKDyRE8YGEKD6QEMUHEqL4QEIUH0iI4gMJtXmX3fNsv2H7rWZ23gM1FgagnDbn6h+XtCQijjXvr/+q7b9FxD8Krw1AIW3eZTckHWs+nNJcGJgB9LBWx/i2B2zvkjQiaVtEMDsP6GGtih8RJyPiWkmzJS22fc2Zt7G9xvaQ7aGJXiSAiTWmZ/Uj4nNJOyQtP8vnNkTEoohYNFGLA1BGm2f1L7E9o7l+vqRlkvaVXhiActo8q3+ZpCdsD6jzQPFMRDxfdlkASmrzrP6/JM2vsBYAlXDmHpAQxQcSovhAQhQfSIjiAwlRfCAhig8kRPGBhJidNw7bt2+vmtfvav/8jhw5UjWvG7HFBxKi+EBCFB9IiOIDCVF8ICGKDyRE8YGEKD6QEMUHEqL4QEKti98M1XjTNm+0CfS4sWzx75G0t9RCANTTdoTWbEm3StpYdjkAami7xX9Y0r2Svi64FgCVtJmkc5ukkYgYHuV2zM4DekSbLf71klbYPiDpaUlLbD955o2YnQf0jlGLHxH3R8TsiJgjaZWklyLijuIrA1AMv8cHEhrTW29FxMuSXi6yEgDVsMUHEqL4QEIUH0iI4gMJUXwgIYoPJETxgYQoPpBQX8zOqz0LbeHChVXzaqs9y67293NwcLBqXjdiiw8kRPGBhCg+kBDFBxKi+EBCFB9IiOIDCVF8ICGKDyRE8YGEWp2y27y19heSTkr6irfQBnrbWM7V/0lEfFZsJQCqYVcfSKht8UPSi7aHba8puSAA5bXd1b8hIg7Z/r6kbbb3RcQrp9+geUDgQQHoAa22+BFxqPlzRNJWSYvPchtm5wE9os203AtsTz91XdLNknaXXhiActrs6l8qaavtU7d/KiJeKLoqAEWNWvyI2C/phxXWAqASfp0HJETxgYQoPpAQxQcSovhAQhQfSIjiAwlRfCAhR8TE36k98Xf6LebOnVszTkNDQ1Xz1q5dWzXv9ttvr5pX++e3aFF/v5wkIjzabdjiAwlRfCAhig8kRPGBhCg+kBDFBxKi+EBCFB9IiOIDCVF8IKFWxbc9w/YW2/ts77V9XemFASin7UCNP0p6ISJ+bnuqpO8WXBOAwkYtvu0LJd0o6ReSFBEnJJ0ouywAJbXZ1b9S0qeSHrf9pu2NzWCNb7C9xvaQ7bovXQMwZm2Kf46kBZIeiYj5kr6UtP7MGzFCC+gdbYp/UNLBiHi9+XiLOg8EAHrUqMWPiI8lfWh7XvNXSyXtKboqAEW1fVb/bkmbmmf090u6q9ySAJTWqvgRsUsSx+5An+DMPSAhig8kRPGBhCg+kBDFBxKi+EBCFB9IiOIDCfXF7Lza1qxZUzXvvvvuq5o3PDxcNW/lypVV8/ods/MAnBXFBxKi+EBCFB9IiOIDCVF8ICGKDyRE8YGEKD6Q0KjFtz3P9q7TLkdtr6uxOABljPqeexHxrqRrJcn2gKRDkrYWXheAgsa6q79U0vsR8UGJxQCoY6zFXyVpc4mFAKindfGb99RfIWnwf3ye2XlAj2g7UEOSbpG0MyI
+OdsnI2KDpA1S/78sF+h1Y9nVXy1284G+0Kr4zVjsZZKeK7scADW0HaH1paTvFV4LgEo4cw9IiOIDCVF8ICGKDyRE8YGEKD6QEMUHEqL4QEIUH0io1Oy8TyWN5zX7F0v6bIKX0w1Z5JFXK++KiLhktBsVKf542R6KiEX9lkUeed2Wx64+kBDFBxLqtuJv6NMs8sjrqryuOsYHUEe3bfEBVEDxgYQoPpAQxQcSovhAQv8BVOSY4UmSu60AAAAASUVORK5CYII=\n",
  107. "text/plain": [
  108. "<Figure size 288x288 with 1 Axes>"
  109. ]
  110. },
  111. "metadata": {
  112. "needs_background": "light"
  113. },
  114. "output_type": "display_data"
  115. }
  116. ],
  117. "source": [
  118. "# Compare the performance\n",
  119. "from sklearn.datasets import load_digits\n",
  120. "from sklearn.linear_model.logistic import LogisticRegression\n",
  121. "from sklearn import decomposition\n",
  122. "from sklearn.metrics import confusion_matrix\n",
  123. "from sklearn.metrics import accuracy_score\n",
  124. "import matplotlib.pyplot as plt\n",
  125. "\n",
  126. "\n",
  127. "# load digital data\n",
  128. "digits, dig_label = load_digits(return_X_y=True)\n",
  129. "print(digits.shape)\n",
  130. "\n",
  131. "# draw one digital\n",
  132. "plt.gray() \n",
  133. "plt.matshow(digits[0].reshape([8, 8])) \n",
  134. "plt.show() \n"
  135. ]
  136. },
  137. {
  138. "cell_type": "code",
  139. "execution_count": 9,
  140. "metadata": {},
  141. "outputs": [
  142. {
  143. "name": "stdout",
  144. "output_type": "stream",
  145. "text": [
  146. "accuracy train = 0.998608, accuracy_test = 0.897222\n"
  147. ]
  148. }
  149. ],
  150. "source": [
  151. "\n",
  152. "# calculate train/test data number\n",
  153. "N = len(digits)\n",
  154. "N_train = int(N*0.8)\n",
  155. "N_test = N - N_train\n",
  156. "\n",
  157. "# split train/test data\n",
  158. "x_train = digits[:N_train, :]\n",
  159. "y_train = dig_label[:N_train]\n",
  160. "x_test = digits[N_train:, :]\n",
  161. "y_test = dig_label[N_train:]\n",
  162. "\n",
  163. "# do logistic regression\n",
  164. "lr=LogisticRegression()\n",
  165. "lr.fit(x_train,y_train)\n",
  166. "\n",
  167. "pred_train = lr.predict(x_train)\n",
  168. "pred_test = lr.predict(x_test)\n",
  169. "\n",
  170. "# calculate train/test accuracy\n",
  171. "acc_train = accuracy_score(y_train, pred_train)\n",
  172. "acc_test = accuracy_score(y_test, pred_test)\n",
  173. "print(\"accuracy train = %f, accuracy_test = %f\" % (acc_train, acc_test))\n"
  174. ]
  175. },
  176. {
  177. "cell_type": "code",
  178. "execution_count": 19,
  179. "metadata": {},
  180. "outputs": [
  181. {
  182. "name": "stdout",
  183. "output_type": "stream",
  184. "text": [
  185. "accuracy train = 0.987474, accuracy_test = 0.894444\n"
  186. ]
  187. }
  188. ],
  189. "source": [
  190. "# do PCA with 'n_components=40'\n",
  191. "pca = decomposition.PCA(n_components=40)\n",
  192. "pca.fit(x_train)\n",
  193. "\n",
  194. "x_train_pca = pca.transform(x_train)\n",
  195. "x_test_pca = pca.transform(x_test)\n",
  196. "\n",
  197. "# do logistic regression\n",
  198. "lr=LogisticRegression()\n",
  199. "lr.fit(x_train_pca,y_train)\n",
  200. "\n",
  201. "pred_train = lr.predict(x_train_pca)\n",
  202. "pred_test = lr.predict(x_test_pca)\n",
  203. "\n",
  204. "# calculate train/test accuracy\n",
  205. "acc_train = accuracy_score(y_train, pred_train)\n",
  206. "acc_test = accuracy_score(y_test, pred_test)\n",
  207. "print(\"accuracy train = %f, accuracy_test = %f\" % (acc_train, acc_test))\n"
  208. ]
  209. },
  210. {
  211. "cell_type": "markdown",
  212. "metadata": {},
  213. "source": [
  214. "## References\n",
  215. "* [Pipelining: chaining a PCA and a logistic regression](https://scikit-learn.org/stable/auto_examples/compose/plot_digits_pipe.html)"
  216. ]
  217. }
  218. ],
  219. "metadata": {
  220. "kernelspec": {
  221. "display_name": "Python 3",
  222. "language": "python",
  223. "name": "python3"
  224. },
  225. "language_info": {
  226. "codemirror_mode": {
  227. "name": "ipython",
  228. "version": 3
  229. },
  230. "file_extension": ".py",
  231. "mimetype": "text/x-python",
  232. "name": "python",
  233. "nbconvert_exporter": "python",
  234. "pygments_lexer": "ipython3",
  235. "version": "3.5.2"
  236. },
  237. "main_language": "python"
  238. },
  239. "nbformat": 4,
  240. "nbformat_minor": 2
  241. }

机器学习越来越多应用到飞行器、机器人等领域,其目的是利用计算机实现类似人类的智能,从而实现装备的智能化与无人化。本课程旨在引导学生掌握机器学习的基本知识、典型方法与技术,通过具体的应用案例激发学生对该学科的兴趣,鼓励学生能够从人工智能的角度来分析、解决飞行器、机器人所面临的问题和挑战。本课程主要内容包括Python编程基础,机器学习模型,无监督学习、监督学习、深度学习基础知识与实现,并学习如何利用机器学习解决实际问题,从而全面提升自我的《综合能力》。

Contributors (1)