{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Chaining a PCA and a logistic regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The PCA does an unsupervised dimensionality reduction, while the logistic regression does the prediction.\n", "\n", "We use a GridSearchCV to set the dimensionality of the PCA." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
 "# Line magic: IPython requires no space between '%' and the magic name.\n",
 "%matplotlib inline\n",
 "\n",
 "import numpy as np\n",
 "import matplotlib.pyplot as plt\n",
 "\n",
 "from sklearn import linear_model, decomposition, datasets\n",
 "from sklearn.pipeline import Pipeline\n",
 "from sklearn.model_selection import GridSearchCV\n",
 "\n",
 "# Two-step pipeline: PCA for dimensionality reduction, then logistic regression.\n",
 "logistic = linear_model.LogisticRegression()\n",
 "\n",
 "pca = decomposition.PCA()\n",
 "pipe = Pipeline(steps=[('pca', pca), ('logistic', logistic)])\n",
 "\n",
 "digits = datasets.load_digits()\n",
 "X_digits = digits.data\n",
 "y_digits = digits.target\n",
 "\n",
 "# Plot the PCA spectrum (explained variance per component, fitted on all samples)\n",
 "pca.fit(X_digits)\n",
 "\n",
 "plt.figure(1, figsize=(4, 3))\n",
 "plt.clf()\n",
 "plt.axes([.2, .2, .7, .7])\n",
 "plt.plot(pca.explained_variance_, linewidth=2)\n",
 "plt.axis('tight')\n",
 "plt.xlabel('n_components')\n",
 "plt.ylabel('explained_variance_')\n",
 "\n",
 "# Prediction: candidate values explored by the grid search\n",
 "n_components = [20, 40, 64]\n",
 "Cs = np.logspace(-4, 4, 3)\n",
 "\n",
 "# Parameters of pipelines can be set using '__' separated parameter names:\n",
 "estimator = GridSearchCV(pipe,\n",
 "                         dict(pca__n_components=n_components,\n",
 "                              logistic__C=Cs))\n",
 "estimator.fit(X_digits, y_digits)\n",
 "\n",
 "# Mark the n_components of the best pipeline on the spectrum plot\n",
 "plt.axvline(estimator.best_estimator_.named_steps['pca'].n_components,\n",
 "            linestyle=':', label='n_components chosen')\n",
 "plt.legend(prop=dict(size=12))\n",
 "plt.show()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1797, 64)\n" ] }, { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAECCAYAAADesWqHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAC8tJREFUeJzt3X+o1fUdx/HXazetlpK2WoRGZgwhguUPZFHEphm2wv2zRKFgsaF/bJFsULZ/Rv/1V7Q/RiBWCzKjawkjtpaSEUGr3Wu2TG2UGCnVLTTM/lCy9/44X4eJ637v3f187jnn/XzAwXO9x/P63Ht9ne/3e+73nLcjQgBy+c5kLwBAfRQfSIjiAwlRfCAhig8kRPGBhLqi+LaX237X9nu21xfOesz2iO3dJXNOy7vc9g7be2y/Y/uewnnn2X7D9ltN3gMl85rMAdtv2n6+dFaTd8D227Z32R4qnDXD9hbb+2zvtX1dwax5zdd06nLU9roiYRExqRdJA5LelzRX0lRJb0m6umDejZIWSNpd6eu7TNKC5vp0Sf8u/PVZ0rTm+hRJr0v6UeGv8beSnpL0fKXv6QFJF1fKekLSr5rrUyXNqJQ7IOljSVeUuP9u2OIvlvReROyPiBOSnpb0s1JhEfGKpMOl7v8seR9FxM7m+heS9kqaVTAvIuJY8+GU5lLsLC3bsyXdKmljqYzJYvtCdTYUj0pSRJyIiM8rxS+V9H5EfFDizruh+LMkfXjaxwdVsBiTyfYcSfPV2QqXzBmwvUvSiKRtEVEy72FJ90r6umDGmULSi7aHba8pmHOlpE8lPd4cymy0fUHBvNOtkrS51J13Q/FTsD1N0rOS1kXE0ZJZEXEyIq6VNFvSYtvXlMixfZukkYgYLnH/3+KGiFgg6RZJv7Z9Y6Gcc9Q5LHwkIuZL+lJS0eegJMn2VEkrJA2WyuiG4h+SdPlpH89u/q5v2J6iTuk3RcRztXKb3dIdkpYXirhe0grbB9Q5RFti+8lCWf8VEYeaP0ckbVXncLGEg5IOnrbHtEWdB4LSbpG0MyI+KRXQDcX/p6Qf2L6yeaRbJekvk7ymCWPb6hwj7o2IhyrkXWJ7RnP9fEnLJO0rkRUR90fE7IiYo87P7aWIuKNE1im2L7A9/dR1STdLKvIbmoj4WNKHtuc1f7VU0p4SWWdYrYK7+VJnV2ZSRcRXtn8j6e/qPJP5WES8UyrP9mZJP5Z0se2Dkv4QEY+WylNnq3inpLeb425J+n1E/LVQ3mWSnrA9oM4D+zMRUeXXbJVcKmlr5/FU50h6KiJeKJh3t6RNzUZpv6S7CmadejBbJmlt0ZzmVwcAEumGXX0AlVF8ICGKDyRE8YGEKD6QUFcVv/Dpl5OWRR553ZbXVcWXVPObW/UHSR553ZTXbcUHUEGRE3hs9/VZQTNnzhzzvzl+/LjOPffcceXNmjX2FysePnxYF1100bjyjh4d+2uIjh07pmnTpo0r79Chsb80IyLUnL03ZidPnhzXv+sVETHqN2bST9ntRTfddFPVvAcffLBq3vbt26vmrV9f/AVv33DkyJGqed2IXX0gIYoPJETxgYQoPpAQxQcSovhAQhQfSIjiAwm1Kn7NEVcAyhu1+M2bNv5Jnbf8vVrSattXl14YgHLabPGrjrgCUF6b4qcZcQVkMWEv0mneOKD2a5YBjEOb4rcacRURGyRtkPr/ZblAr2uzq9/XI66AjEbd4tcecQWgvFbH+M2ct1Kz3gBUxpl7QEIUH0iI4gMJUXwgIYoPJETxgYQoPpAQxQcSYpLOONSebDN37tyqeeMZEfb/OHz4cNW8lStXVs0bHBysmtcGW3wgIYoPJETxgYQoPpAQxQcSovhAQhQfSIjiAwlRfCAhig8k1GaE1mO2R2zvrrEgAO
W12eL/WdLywusAUNGoxY+IVyTVfRUFgKI4xgcSYnYekNCEFZ/ZeUDvYFcfSKjNr/M2S3pN0jzbB23/svyyAJTUZmjm6hoLAVAPu/pAQhQfSIjiAwlRfCAhig8kRPGBhCg+kBDFBxLqi9l5CxcurJpXe5bdVVddVTVv//79VfO2bdtWNa/2/xdm5wHoChQfSIjiAwlRfCAhig8kRPGBhCg+kBDFBxKi+EBCFB9IqM2bbV5ue4ftPbbfsX1PjYUBKKfNufpfSfpdROy0PV3SsO1tEbGn8NoAFNJmdt5HEbGzuf6FpL2SZpVeGIByxnSMb3uOpPmSXi+xGAB1tH5Zru1pkp6VtC4ijp7l88zOA3pEq+LbnqJO6TdFxHNnuw2z84De0eZZfUt6VNLeiHio/JIAlNbmGP96SXdKWmJ7V3P5aeF1ASiozey8VyW5wloAVMKZe0BCFB9IiOIDCVF8ICGKDyRE8YGEKD6QEMUHEuqL2XkzZ86smjc8PFw1r/Ysu9pqfz/BFh9IieIDCVF8ICGKDyRE8YGEKD6QEMUHEqL4QEIUH0iI4gMJtXmX3fNsv2H7rWZ23gM1FgagnDbn6h+XtCQijjXvr/+q7b9FxD8Krw1AIW3eZTckHWs+nNJcGJgB9LBWx/i2B2zvkjQiaVtEMDsP6GGtih8RJyPiWkmzJS22fc2Zt7G9xvaQ7aGJXiSAiTWmZ/Uj4nNJOyQtP8vnNkTEoohYNFGLA1BGm2f1L7E9o7l+vqRlkvaVXhiActo8q3+ZpCdsD6jzQPFMRDxfdlkASmrzrP6/JM2vsBYAlXDmHpAQxQcSovhAQhQfSIjiAwlRfCAhig8kRPGBhJidNw7bt2+vmtfvav/8jhw5UjWvG7HFBxKi+EBCFB9IiOIDCVF8ICGKDyRE8YGEKD6QEMUHEqL4QEKti98M1XjTNm+0CfS4sWzx75G0t9RCANTTdoTWbEm3StpYdjkAami7xX9Y0r2Svi64FgCVtJmkc5ukkYgYHuV2zM4DekSbLf71klbYPiDpaUlLbD955o2YnQf0jlGLHxH3R8TsiJgjaZWklyLijuIrA1AMv8cHEhrTW29FxMuSXi6yEgDVsMUHEqL4QEIUH0iI4gMJUXwgIYoPJETxgYQoPpBQX8zOqz0LbeHChVXzaqs9y67293NwcLBqXjdiiw8kRPGBhCg+kBDFBxKi+EBCFB9IiOIDCVF8ICGKDyRE8YGEWp2y27y19heSTkr6irfQBnrbWM7V/0lEfFZsJQCqYVcfSKht8UPSi7aHba8puSAA5bXd1b8hIg7Z/r6kbbb3RcQrp9+geUDgQQHoAa22+BFxqPlzRNJWSYvPchtm5wE9os203AtsTz91XdLNknaXXhiActrs6l8qaavtU7d/KiJeKLoqAEWNWvyI2C/phxXWAqASfp0HJETxgYQoPpAQxQcSovhAQhQfSIjiAwlRfCAhR8TE36k98Xf6LebOnVszTkNDQ1Xz1q5dWzXv9ttvr5pX++e3aFF/v5wkIjzabdjiAwlRfCAhig8kRPGBhCg+kBDFBxKi+EBCFB9IiOIDCVF8IKFWxbc9w/YW2/ts77V9XemFASin7UCNP0p6ISJ+bnuqpO8WXBOAwkYtvu0LJd0o6ReSFBEnJJ0ouywAJbXZ1b9S0qeSHrf9pu2NzWCNb7C9xvaQ7bovXQMwZm2Kf46kBZIeiYj5kr6UtP7MGzFCC+gdbYp/UNLBiHi9+XiLOg8EAHrUqMWPiI8lfWh7XvNXSyXtKboqAEW1fVb/bkmbmmf090u6q9ySAJTWqvgRsUsSx+5An+DMPSAhig8kRPGBhCg+kBDFBxKi+EBCFB9IiOIDCfXF7Lza1qxZUzXvvvvuq5o3PDxcNW/lypVV8/ods/MAnBXFBxKi+EBCFB9IiOIDCVF8ICGKDyRE8YGEKD6Q0KjFtz3P9q7TLkdtr6uxOABljPqeexHxrqRrJcn2gKRDkrYWXheAgsa6q79U0v
sR8UGJxQCoY6zFXyVpc4mFAKindfGb99RfIWnwf3ye2XlAj2g7UEOSbpG0MyI+OdsnI2KDpA1S/78sF+h1Y9nVXy1284G+0Kr4zVjsZZKeK7scADW0HaH1paTvFV4LgEo4cw9IiOIDCVF8ICGKDyRE8YGEKD6QEMUHEqL4QEIUH0io1Oy8TyWN5zX7F0v6bIKX0w1Z5JFXK++KiLhktBsVKf542R6KiEX9lkUeed2Wx64+kBDFBxLqtuJv6NMs8sjrqryuOsYHUEe3bfEBVEDxgYQoPpAQxQcSovhAQv8BVOSY4UmSu60AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Compare the performance\n", "from sklearn.datasets import load_digits\n", "from sklearn.linear_model.logistic import LogisticRegression\n", "from sklearn import decomposition\n", "from sklearn.metrics import confusion_matrix\n", "from sklearn.metrics import accuracy_score\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "# load digital data\n", "digits, dig_label = load_digits(return_X_y=True)\n", "print(digits.shape)\n", "\n", "# draw one digital\n", "plt.gray() \n", "plt.matshow(digits[0].reshape([8, 8])) \n", "plt.show() \n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy train = 0.998608, accuracy_test = 0.897222\n" ] } ], "source": [ "\n", "# calculate train/test data number\n", "N = len(digits)\n", "N_train = int(N*0.8)\n", "N_test = N - N_train\n", "\n", "# split train/test data\n", "x_train = digits[:N_train, :]\n", "y_train = dig_label[:N_train]\n", "x_test = digits[N_train:, :]\n", "y_test = dig_label[N_train:]\n", "\n", "# do logistic regression\n", "lr=LogisticRegression()\n", "lr.fit(x_train,y_train)\n", "\n", "pred_train = lr.predict(x_train)\n", "pred_test = lr.predict(x_test)\n", "\n", "# calculate train/test accuracy\n", "acc_train = accuracy_score(y_train, pred_train)\n", "acc_test = accuracy_score(y_test, pred_test)\n", "print(\"accuracy train = %f, accuracy_test = %f\" % (acc_train, acc_test))\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "accuracy train = 0.987474, accuracy_test = 0.894444\n" ] } ], "source": [ "# do PCA with 'n_components=40'\n", "pca = decomposition.PCA(n_components=40)\n", "pca.fit(x_train)\n", "\n", "x_train_pca = pca.transform(x_train)\n", "x_test_pca = pca.transform(x_test)\n", "\n", "# do logistic regression\n", 
"lr=LogisticRegression()\n", "lr.fit(x_train_pca,y_train)\n", "\n", "pred_train = lr.predict(x_train_pca)\n", "pred_test = lr.predict(x_test_pca)\n", "\n", "# calculate train/test accuracy\n", "acc_train = accuracy_score(y_train, pred_train)\n", "acc_test = accuracy_score(y_test, pred_test)\n", "print(\"accuracy train = %f, accuracy_test = %f\" % (acc_train, acc_test))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "* [Pipelining: chaining a PCA and a logistic regression](http://scikit-learn.org/stable/auto_examples/plot_digits_pipe.html)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" }, "main_language": "python" }, "nbformat": 4, "nbformat_minor": 2 }