diff --git a/2_knn/knn_classification.ipynb b/2_knn/knn_classification.ipynb index 1b60650..40f693e 100644 --- a/2_knn/knn_classification.ipynb +++ b/2_knn/knn_classification.ipynb @@ -33,7 +33,8 @@ "4. 选取距离最小的`k`个点;\n", "5. 确定前`k`个点所在类别的出现频率;\n", "6. 返回前`k`个点中出现频率最高的类别作为测试数据的预测分类。\n", - "\n" + "\n", + "上述的处理过程,难点有哪些?" ] }, { @@ -45,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -121,12 +122,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 3. Simple Program" + "## 3. 最简单的程序实现" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -171,7 +172,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -193,14 +194,14 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Test Accuracy: 95.734597%\n" + "Test Accuracy: 96.682464%\n" ] } ], @@ -218,12 +219,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Complex Program" + "## 4. 通过类实现kNN程序" ] }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -277,15 +278,15 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "train accuracy: 0.986\n", - "test accuracy: 0.967\n" + "train accuracy: 98.568507 %\n", + "test accuracy: 96.682464 %\n" ] } ], @@ -296,19 +297,18 @@ "\n", "# knn classifier\n", "clf = KNN(k=3)\n", - "acc = clf.fit(x_train, y_train).score()\n", - "\n", - "print('train accuracy: {:.3}'.format(clf.score()))\n", + "train_acc = clf.fit(x_train, y_train).score() * 100.0\n", + "test_acc = clf.score(y_test, y_test_pred) * 100.0\n", "\n", - "y_test_pred = clf.predict(x_test)\n", - "print('test accuracy: {:.3}'.format(clf.score(y_test, y_test_pred)))" + "print('train accuracy: %f %%' % train_acc)\n", + "print('test accuracy: %f %%' % test_acc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. sklearn program" + "## 5. sklearn program" ] }, { @@ -426,13 +426,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 5. 深入思考\n", + "## 6. 深入思考\n", "\n", "* 如果输入的数据非常多,怎么快速进行距离计算?\n", " - kd-tree\n", " - Fast Library for Approximate Nearest Neighbors (FLANN)\n", "* 如何选择最好的`k`?\n", - " - https://zhuanlan.zhihu.com/p/143092725" + " - https://zhuanlan.zhihu.com/p/143092725\n", + "* kNN存在的问题?" ] }, {