|
- import numpy as np
- import random
- from matplotlib import pyplot as plt
- import cv2
- import os
-
def loss_function(present_center, pre_center):
    """Measure how far the cluster centers moved between two iterations.

    :param present_center: cluster centers of the current iteration
    :param pre_center: cluster centers of the previous iteration
    :return: sum of the squared element-wise differences between the two
    """
    current = np.array(present_center)
    previous = np.array(pre_center)
    shift = current - previous
    return np.sum(shift ** 2)
-
-
def classifer(intput_signal, center):
    """Label every pixel with the index of its nearest cluster center.

    Distance is the squared difference between pixel and center, summed over
    channels for multi-channel images. Pixels and centers are both compared
    as floats: this avoids wrap-around on unsigned integer images and keeps
    the fractional part of averaged centers (truncating them to integers can
    make distinct centers collapse onto the same value).

    :param intput_signal: image array of shape (rows, cols) or
        (rows, cols, channels)
    :param center: sequence of cluster centers; scalars for a single-channel
        image, one value per channel otherwise
    :return: float array of shape (rows, cols) holding the index of the
        closest center for each pixel (the lowest index wins ties)
    :raises ValueError: if ``center`` is empty
    """
    pixels = np.asarray(intput_signal, dtype=float)
    centers = np.asarray(center, dtype=float)
    if centers.ndim == 0 or len(centers) == 0:
        raise ValueError("center must contain at least one cluster center")

    # Give single-channel data an explicit channel axis so that one code
    # path serves both grayscale and colour images.
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    centers = centers.reshape(len(centers), -1)

    # One (rows, cols) distance map per center; looping over the few centers
    # keeps memory at rows * cols * k instead of rows * cols * k * channels.
    distances = np.stack([
        np.square(pixels - one_center).sum(axis=-1) for one_center in centers
    ])

    # argmin returns the first minimum, matching "first closest center wins".
    return distances.argmin(axis=0).astype(float)
-
-
def _load_warm_start(cache_path, is_gray, center_shape, label_shape):
    """Return (centers, labels) from a compatible cache file, else None."""
    if not os.path.exists(cache_path):
        return None
    # NOTE(security): allow_pickle=True lets a crafted .npz run arbitrary
    # code on load. It is kept so existing cache files stay readable; only
    # use cache files written by this program.
    with np.load(cache_path, allow_pickle=True) as cache:
        try:
            centers = np.asarray(cache['center_arr'], dtype=float)
            labels = np.asarray(cache['pixls_labels'])
            cached_is_gray = bool(cache['is_gray'])
        except (KeyError, TypeError, ValueError):
            # Missing keys or malformed arrays: ignore the cache.
            return None
    if cached_is_gray != is_gray:
        return None
    # A cache written for another cluster count or image size cannot be
    # reused: the labels would not index this image correctly.
    if centers.shape != center_shape or labels.shape != label_shape:
        return None
    return centers, labels


def k_means(input_signal, center_num, threshold):
    """Segment an image by k-means clustering of its pixel values.

    Handles single-channel (rows, cols) and multi-channel
    (rows, cols, channels) images.

    Initialisation: if ``./present_center.npz`` exists and was written for
    the same kind of image (gray/colour), the same number of centers and the
    same image size, its centers and labels are used as the starting point.
    Otherwise ``center_num`` distinct pixels are drawn at random as initial
    centers. When the iteration stops, the final centers and labels are
    written back to ``./present_center.npz``.

    :param input_signal: input image array
    :param center_num: number of cluster centers
    :param threshold: iteration stops once the summed squared movement of
        the centers between two iterations is less than or equal to this
    :return: label matrix produced by ``classifer`` for the final centers
    :raises ValueError: if ``center_num`` is smaller than 1 or larger than
        the number of pixels
    """
    cache_path = './present_center.npz'
    image = np.copy(input_signal)  # work on a copy of the input
    is_gray = image.ndim == 2
    input_row, input_col = image.shape[:2]

    center_shape = (center_num,) + image.shape[2:]
    warm_start = _load_warm_start(
        cache_path, is_gray, center_shape, (input_row, input_col))

    if warm_start is not None:
        present_center, pixls_labels = warm_start
    else:
        pixel_count = input_row * input_col
        if center_num < 1 or center_num > pixel_count:
            raise ValueError(
                "center_num must be between 1 and the number of pixels")
        # Draw distinct pixel positions; each flat index is split into its
        # (row, column) pair so rows and columns are both honoured.
        positions = random.sample(range(pixel_count), center_num)
        present_center = np.array(
            [image[divmod(position, input_col)] for position in positions],
            dtype=float)
        pixls_labels = classifer(image, present_center)

    print("Start Train")
    num = 0  # iteration counter
    while True:
        pre_center = present_center.copy()  # centers of the previous step
        # Move every center to the mean of the pixels assigned to it.
        for n in range(center_num):
            members = image[pixls_labels == n]
            if len(members) == 0:
                # An empty cluster has no mean; keep its previous center
                # instead of dividing by zero.
                continue
            present_center[n] = members.astype(int).sum(axis=0) / len(members)
        # Re-assign the pixels to the updated centers.
        pixls_labels = classifer(image, present_center)
        # How far did the centers move in this step?
        loss = loss_function(present_center, pre_center)
        num = num + 1
        print("Step:" + str(num) + " Loss:" + str(loss))
        # Stop once the centers have (almost) stopped moving.
        if loss <= threshold:
            np.savez(cache_path, center_arr=present_center,
                     is_gray=is_gray, pixls_labels=pixls_labels)
            break
    return pixls_labels
-
if __name__ == '__main__':
    # Demo: segment a sample image with k-means and show it next to the
    # source image.
    image_path = '10008.jpg'
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    # Alternative: read with cv2.IMREAD_GRAYSCALE to segment a grayscale
    # version (the colour conversion below must then be skipped).
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError("Could not read image: " + image_path)

    # OpenCV loads images as BGR, while matplotlib displays RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Keep the image small so the clustering stays fast.
    image = cv2.resize(image, (100, 100))

    k = 3
    threshold = 1
    labels = k_means(image, k, threshold)

    plt.subplot(1, 2, 1)
    plt.title("Soucre Image")
    plt.imshow(image, cmap="gray")
    plt.subplot(1, 2, 2)
    plt.title("Segamenting Image with k-means\n" + "k=" + str(k) + " threshold=" + str(threshold))
    # Scale the labels by the actual cluster count rather than a fixed 3.
    plt.imshow(labels / k)
    plt.show()
|