Browse Source

Improve knn classification

savefigrue
bushuhui 3 years ago
parent
commit
53c89eb3a1
3 changed files with 41 additions and 29 deletions
  1. +41
    -29
      2_knn/knn_classification.ipynb
  2. BIN
      2_knn/knn_test_data.pdf
  3. BIN
      2_knn/knn_train_data.pdf

+ 41
- 29
2_knn/knn_classification.ipynb View File

@@ -62,7 +62,7 @@
"source": [ "source": [
"### 1.1 距离计算\n", "### 1.1 距离计算\n",
"\n", "\n",
"要度量空间中点距离的话,有好几种度量方式,比如常见的曼哈顿距离计算、欧式距离计算等等。不过通常 KNN 算法中使用的是欧式距离。这里只是简单说一下,拿二维平面为例,二维空间两个点的欧式距离计算公式如下:\n",
"要度量空间中点距离的话,有好几种度量方式,比如常见的曼哈顿距离计算、欧式距离计算等等。不过通常 kNN 算法中使用的是欧式距离。这里只是简单说一下,拿二维平面为例,二维空间两个点的欧式距离计算公式如下:\n",
"$$\n", "$$\n",
"d = \\sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}\n", "d = \\sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}\n",
"$$\n", "$$\n",
@@ -72,7 +72,7 @@
"d(p, q) = \\sqrt{ (p_1-q_1)^2 + (p_1-q_1)^2 + ... + (p_n-q_n)^2 } = \\sqrt{ \\sum_{i=1,n} (p_i-q_i)^2}\n", "d(p, q) = \\sqrt{ (p_1-q_1)^2 + (p_1-q_1)^2 + ... + (p_n-q_n)^2 } = \\sqrt{ \\sum_{i=1,n} (p_i-q_i)^2}\n",
"$$\n", "$$\n",
"\n", "\n",
"这样我们就明白了如何计算距离。kNN 算法最简单粗暴的就是将 `预测点` 与 `所有点` 距离进行计算,然后保存并排序,选出前面 k 个值看看哪些类别比较多。"
"kNN 算法最简单粗暴的就是将 `预测点` 与 `所有点` 距离进行计算,然后保存并排序,选出前面 k 个值看看哪些类别比较多。"
] ]
}, },
{ {
@@ -82,7 +82,7 @@
"\n", "\n",
"## 2. 机器学习的思维模型\n", "## 2. 机器学习的思维模型\n",
"\n", "\n",
"针对kNN方法的提出机器学习的思维模型,在给定问题的情况下,是如何思考并解决机器学习问题\n",
"针对kNN方法从原理、算法、到实现,可以得出机器学习的思维模型,在给定问题的情况下,是如何思考并解决机器学习问题\n",
"\n", "\n",
"![machine learning - methodology](images/ml_methodology.png)\n", "![machine learning - methodology](images/ml_methodology.png)\n",
"\n", "\n",
@@ -112,7 +112,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -146,7 +146,7 @@
"import numpy as np\n", "import numpy as np\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"\n", "\n",
"# data generation\n",
"# 生成模拟数据\n",
"np.random.seed(314)\n", "np.random.seed(314)\n",
"\n", "\n",
"data_size1 = 100\n", "data_size1 = 100\n",
@@ -158,7 +158,7 @@
"y2 = [1 for _ in range(data_size2)]\n", "y2 = [1 for _ in range(data_size2)]\n",
"\n", "\n",
"\n", "\n",
"# all sample data\n",
"# 合并生成全部数据\n",
"x = np.concatenate((x1, x2), axis=0)\n", "x = np.concatenate((x1, x2), axis=0)\n",
"y = np.concatenate((y1, y2), axis=0)\n", "y = np.concatenate((y1, y2), axis=0)\n",
"\n", "\n",
@@ -167,7 +167,7 @@
"x = x[shuffled_index]\n", "x = x[shuffled_index]\n",
"y = y[shuffled_index]\n", "y = y[shuffled_index]\n",
"\n", "\n",
"# split train & test\n",
"# 分割训练与测试数据\n",
"split_index = int(data_size_all*0.7)\n", "split_index = int(data_size_all*0.7)\n",
"x_train = x[:split_index]\n", "x_train = x[:split_index]\n",
"y_train = y[:split_index]\n", "y_train = y[:split_index]\n",
@@ -175,12 +175,14 @@
"y_test = y[split_index:]\n", "y_test = y[split_index:]\n",
"\n", "\n",
"\n", "\n",
"# plot data\n",
"# 绘制结果\n",
"plt.scatter(x_train[:,0], x_train[:,1], c=y_train, marker='.')\n", "plt.scatter(x_train[:,0], x_train[:,1], c=y_train, marker='.')\n",
"plt.title(\"train data\")\n", "plt.title(\"train data\")\n",
"plt.savefig(\"knn_train_data.pdf\")\n",
"plt.show()\n", "plt.show()\n",
"plt.scatter(x_test[:,0], x_test[:,1], c=y_test, marker='.')\n", "plt.scatter(x_test[:,0], x_test[:,1], c=y_test, marker='.')\n",
"plt.title(\"test data\")\n", "plt.title(\"test data\")\n",
"plt.savefig(\"knn_test_data.pdf\")\n",
"plt.show()\n" "plt.show()\n"
] ]
}, },
@@ -209,38 +211,31 @@
"import operator\n", "import operator\n",
"\n", "\n",
"def knn_distance(v1, v2):\n", "def knn_distance(v1, v2):\n",
" \"\"\"计算两个多维向量的距离\"\"\"\n",
" return np.sum(np.square(v1-v2))\n", " return np.sum(np.square(v1-v2))\n",
"\n", "\n",
"def knn_vote(ys):\n", "def knn_vote(ys):\n",
" method = 1\n",
" \n",
" # method 1\n",
" if method == 1:\n",
" vote_dict = {}\n",
" \"\"\"根据ys的类别,挑选类别最多一类作为输出\"\"\"\n",
" vote_dict = {}\n",
" for y in ys:\n", " for y in ys:\n",
" if y not in vote_dict.keys():\n", " if y not in vote_dict.keys():\n",
" vote_dict[y] = 1\n", " vote_dict[y] = 1\n",
" else:\n", " else:\n",
" vote_dict[y] += 1\n", " vote_dict[y] += 1\n",
" \n",
" method = 1\n",
" \n",
" # 方法1 - 使用排序的方法\n",
" if method == 1:\n",
" sorted_vote_dict = sorted(vote_dict.items(), \\\n", " sorted_vote_dict = sorted(vote_dict.items(), \\\n",
" #key=operator.itemgetter(1), \\\n", " #key=operator.itemgetter(1), \\\n",
" key=lambda x:x[1], \\\n", " key=lambda x:x[1], \\\n",
" reverse=True)\n", " reverse=True)\n",
" \n",
" return sorted_vote_dict[0][0]\n", " return sorted_vote_dict[0][0]\n",
" \n", " \n",
" # method 2\n",
" # 方法2 - 使用循环遍历找到类别最多的一类\n",
" if method == 2:\n", " if method == 2:\n",
" maxv = 0\n",
" maxk = 0\n",
" \n",
" vote_dict = {}\n",
" for y in ys:\n",
" if y not in vote_dict.keys():\n",
" vote_dict[y] = 1\n",
" else:\n",
" vote_dict[y] += 1\n",
" \n",
" maxv = maxk = 0 \n",
" for y in np.unique(ys):\n", " for y in np.unique(ys):\n",
" if maxv < vote_dict[y]:\n", " if maxv < vote_dict[y]:\n",
" maxv = vote_dict[y]\n", " maxv = vote_dict[y]\n",
@@ -248,6 +243,14 @@
" return maxk\n", " return maxk\n",
" \n", " \n",
"def knn_predict(x, train_x, train_y, k=3):\n", "def knn_predict(x, train_x, train_y, k=3):\n",
" \"\"\"\n",
" 针对给定的数据进行分类\n",
" 参数\n",
" x - 输入的待分类样本\n",
" train_x - 训练数据的样本\n",
" train_y - 训练数据的标签\n",
" k - 最近邻的样本个数\n",
" \"\"\"\n",
" dist_arr = [knn_distance(x, train_x[j]) for j in range(len(train_x))]\n", " dist_arr = [knn_distance(x, train_x[j]) for j in range(len(train_x))]\n",
" sorted_index = np.argsort(dist_arr)\n", " sorted_index = np.argsort(dist_arr)\n",
" top_k_index = sorted_index[:k]\n", " top_k_index = sorted_index[:k]\n",
@@ -255,8 +258,7 @@
" return knn_vote(ys)\n", " return knn_vote(ys)\n",
" \n", " \n",
"\n", "\n",
"#a = knn_predict(x_train[0], x_train, y_train)\n",
"\n",
"# 对每个样本进行分类\n",
"y_train_est = [knn_predict(x_train[i], x_train, y_train) for i in range(len(x_train))]\n", "y_train_est = [knn_predict(x_train[i], x_train, y_train) for i in range(len(x_train))]\n",
"print(y_train_est)" "print(y_train_est)"
] ]
@@ -275,6 +277,7 @@
} }
], ],
"source": [ "source": [
"# 计算训练数据的精度\n",
"n_correct = 0\n", "n_correct = 0\n",
"for i in range(len(x_train)):\n", "for i in range(len(x_train)):\n",
" if y_train_est[i] == y_train[i]:\n", " if y_train_est[i] == y_train[i]:\n",
@@ -298,6 +301,7 @@
} }
], ],
"source": [ "source": [
"# 计算测试数据的精度\n",
"y_test_est = [knn_predict(x_test[i], x_train, y_train, 3) for i in range(len(x_test))]\n", "y_test_est = [knn_predict(x_test[i], x_train, y_train, 3) for i in range(len(x_test))]\n",
"n_correct = 0\n", "n_correct = 0\n",
"for i in range(len(x_test)):\n", "for i in range(len(x_test)):\n",
@@ -317,7 +321,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -325,19 +329,26 @@
"import operator\n", "import operator\n",
"\n", "\n",
"class KNN(object):\n", "class KNN(object):\n",
"\n",
" def __init__(self, k=3):\n", " def __init__(self, k=3):\n",
" \"\"\"对象构造函数,参数为:\n",
" k - 近邻个数\"\"\"\n",
" self.k = k\n", " self.k = k\n",
"\n", "\n",
" def fit(self, x, y):\n", " def fit(self, x, y):\n",
" \"\"\"拟合给定的数据,参数为:\n",
" x - 样本的特征;y - 样本的标签\"\"\"\n",
" self.x = x\n", " self.x = x\n",
" self.y = y\n", " self.y = y\n",
" return self\n", " return self\n",
"\n", "\n",
" def _square_distance(self, v1, v2):\n", " def _square_distance(self, v1, v2):\n",
" \"\"\"计算两个样本点的特征空间距离,参数为:\n",
" v1 - 样本点1;v2 - 样本点2\"\"\"\n",
" return np.sum(np.square(v1-v2))\n", " return np.sum(np.square(v1-v2))\n",
"\n", "\n",
" def _vote(self, ys):\n", " def _vote(self, ys):\n",
" \"\"\"投票算法,参数为:\n",
" ys - k个近邻样本的类别\"\"\"\n",
" ys_unique = np.unique(ys)\n", " ys_unique = np.unique(ys)\n",
" vote_dict = {}\n", " vote_dict = {}\n",
" for y in ys:\n", " for y in ys:\n",
@@ -349,6 +360,7 @@
" return sorted_vote_dict[0][0]\n", " return sorted_vote_dict[0][0]\n",
"\n", "\n",
" def predict(self, x):\n", " def predict(self, x):\n",
" \n",
" y_pred = []\n", " y_pred = []\n",
" for i in range(len(x)):\n", " for i in range(len(x)):\n",
" dist_arr = [self._square_distance(x[i], self.x[j]) for j in range(len(self.x))]\n", " dist_arr = [self._square_distance(x[i], self.x[j]) for j in range(len(self.x))]\n",


BIN
2_knn/knn_test_data.pdf View File


BIN
2_knn/knn_train_data.pdf View File


Loading…
Cancel
Save