Browse Source

Pre Merge pull request !12 from 王煜熙/master

pull/12/MERGE
王煜熙 Gitee 3 years ago
parent
commit
aad6ca4fd6
1 changed files with 23 additions and 9 deletions
  1. +23
    -9
      2_knn/knn_classification.ipynb

+ 23
- 9
2_knn/knn_classification.ipynb View File

@@ -176,14 +176,26 @@
"\n",
"\n",
"# 绘制结果\n",
"plt.scatter(x_train[:,0], x_train[:,1], c=y_train, marker='.')\n",
"plt.title(\"train data\")\n",
"plt.savefig(\"knn_train_data.pdf\")\n",
"for i in range(split_index):\n",
" if y_train[i] == 0:\n",
" plt.scatter(x_train[i,0],x_train[i,1],c = 0, marker='.')\n",
" else:\n",
" plt.scatter(x_train[i,0],x_train[i,1],c = 1, marker='^') \n",
"plt.rcParams['figure.figsize']=(12.0, 8.0)\n",
"mpl.rcParams['font.family'] = 'SimHei'\n",
"plt.title(\"训练数据\")\n",
"plt.savefig(\"fig-res-train.pdf\")\n",
"plt.show()\n",
"plt.scatter(x_test[:,0], x_test[:,1], c=y_test, marker='.')\n",
"plt.title(\"test data\")\n",
"plt.savefig(\"knn_test_data.pdf\")\n",
"plt.show()\n"
"\n",
"for i in range(data_size_all - split_index):\n",
" if y_test[i] == 0:\n",
" plt.scatter(x_test[i,0],x_test[i,1],c = 0, marker='.')\n",
" else:\n",
" plt.scatter(x_test[i,0],x_test[i,1],c = 1, marker='^')\n",
"plt.rcParams['figure.figsize']=(12.0, 8.0)\n",
"plt.title(\"测试数据\")\n",
"plt.savefig(\"fig-res-test.pdf\")\n",
"plt.show()"
]
},
{
@@ -475,7 +487,9 @@
"for i in range(nplot):\n",
" img = X_digits[i].reshape(8, 8)\n",
" axes[i].imshow(img)\n",
" axes[i].set_title(y_digits[i])\n"
" axes[i].set_title(y_digits[i])\n",
"fig.set_size_inches(16,9)\n",
"fig.savefig('fig-res-digits.pdf')"
]
},
{
@@ -574,7 +588,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
"version": "3.8.5"
}
},
"nbformat": 4,


Loading…
Cancel
Save