# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext_format_version: '1.2'
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
#   language_info:
#     codemirror_mode:
#       name: ipython
#       version: 3
#     file_extension: .py
#     mimetype: text/x-python
#     name: python
#     nbconvert_exporter: python
#     pygments_lexer: ipython3
#     version: 3.5.2
# ---
# # Comparing different clustering algorithms on toy datasets
#
# This example shows characteristics of different clustering algorithms on
# datasets that are “interesting” but still in 2D. With the exception of the
# last dataset, the parameters of each of these dataset-algorithm pairs have
# been tuned to produce good clustering results. Some algorithms are more
# sensitive to parameter values than others.
#
# The last dataset is an example of a ‘null’ situation for clustering: the
# data is homogeneous, and there is no good clustering. For this example, the
# null dataset uses the same parameters as the dataset in the row above it,
# which represents a mismatch between the parameter values and the data
# structure.
#
# While these examples give some intuition about the algorithms, this
# intuition might not apply to very high-dimensional data.
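#
# As a quick illustration of that parameter sensitivity (a minimal standalone
# sketch, not part of the original scikit-learn example), the next cell runs
# DBSCAN on two-moons data with two arbitrary `eps` values and reports how
# many clusters each finds; the count typically changes with the parameter.

# +
from sklearn import cluster, datasets

X_demo, _ = datasets.make_moons(n_samples=500, noise=.05, random_state=0)
for eps in (.05, .3):  # two arbitrary eps values for the illustration
    labels = cluster.DBSCAN(eps=eps).fit_predict(X_demo)
    # ignore the noise label (-1) when counting clusters
    n_found = len(set(labels)) - (1 if -1 in labels else 0)
    print('eps=%.2f -> %d clusters' % (eps, n_found))
# -
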
# +
%matplotlib inline
import time
import warnings

import numpy as np
import matplotlib.pyplot as plt

from sklearn import cluster, datasets, mixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler
from itertools import cycle, islice

np.random.seed(0)
# ============
# Generate datasets. We choose the size big enough to see the scalability
# of the algorithms, but small enough to keep running times reasonable.
# ============
n_samples = 1500
noisy_circles = datasets.make_circles(n_samples=n_samples, factor=.5,
                                      noise=.05)
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.05)
blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
no_structure = np.random.rand(n_samples, 2), None
# Anisotropically distributed data
random_state = 170
X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X, transformation)
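# (np.dot applies the 2x2 linear map above, shearing the isotropic blobs
# into elongated, anisotropic clusters)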
aniso = (X_aniso, y)

# blobs with varied variances
varied = datasets.make_blobs(n_samples=n_samples,
                             cluster_std=[1.0, 2.5, 0.5],
                             random_state=random_state)
# ============
# Set up cluster parameters
# ============
plt.figure(figsize=(9 * 2 + 3, 12.5))
plt.subplots_adjust(left=.02, right=.98, bottom=.001, top=.96, wspace=.05,
                    hspace=.01)

plot_num = 1

default_base = {'quantile': .3,
                'eps': .3,
                'damping': .9,
                'preference': -200,
                'n_neighbors': 10,
                'n_clusters': 3}

datasets = [
    (noisy_circles, {'damping': .77, 'preference': -240,
                     'quantile': .2, 'n_clusters': 2}),
    (noisy_moons, {'damping': .75, 'preference': -220, 'n_clusters': 2}),
    (varied, {'eps': .18, 'n_neighbors': 2}),
    (aniso, {'eps': .15, 'n_neighbors': 2}),
    (blobs, {}),
    (no_structure, {})]
for i_dataset, (dataset, algo_params) in enumerate(datasets):
    # update parameters with dataset-specific values
    params = default_base.copy()
    params.update(algo_params)

    X, y = dataset

    # normalize dataset for easier parameter selection
    X = StandardScaler().fit_transform(X)

    # estimate bandwidth for mean shift
    bandwidth = cluster.estimate_bandwidth(X, quantile=params['quantile'])

    # connectivity matrix for structured Ward
    connectivity = kneighbors_graph(
        X, n_neighbors=params['n_neighbors'], include_self=False)
    # make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)
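    # (k-NN relations are not mutual, so the raw graph can be asymmetric;
    # see the standalone sketch after this cell)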
    # ============
    # Create cluster objects
    # ============
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    two_means = cluster.MiniBatchKMeans(n_clusters=params['n_clusters'])
    ward = cluster.AgglomerativeClustering(
        n_clusters=params['n_clusters'], linkage='ward',
        connectivity=connectivity)
    spectral = cluster.SpectralClustering(
        n_clusters=params['n_clusters'], eigen_solver='arpack',
        affinity="nearest_neighbors")
    dbscan = cluster.DBSCAN(eps=params['eps'])
    affinity_propagation = cluster.AffinityPropagation(
        damping=params['damping'], preference=params['preference'])
    average_linkage = cluster.AgglomerativeClustering(
        linkage="average", affinity="cityblock",
        n_clusters=params['n_clusters'], connectivity=connectivity)
    birch = cluster.Birch(n_clusters=params['n_clusters'])
    gmm = mixture.GaussianMixture(
        n_components=params['n_clusters'], covariance_type='full')

    clustering_algorithms = (
        ('MiniBatchKMeans', two_means),
        ('AffinityPropagation', affinity_propagation),
        ('MeanShift', ms),
        ('SpectralClustering', spectral),
        ('Ward', ward),
        ('AgglomerativeClustering', average_linkage),
        ('DBSCAN', dbscan),
        ('Birch', birch),
        ('GaussianMixture', gmm)
    )
    for name, algorithm in clustering_algorithms:
        t0 = time.time()

        # catch warnings related to kneighbors_graph
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="the number of connected components of the " +
                "connectivity matrix is [0-9]{1,2}" +
                " > 1. Completing it to avoid stopping the tree early.",
                category=UserWarning)
            warnings.filterwarnings(
                "ignore",
                message="Graph is not fully connected, spectral embedding" +
                " may not work as expected.",
                category=UserWarning)
            algorithm.fit(X)

        t1 = time.time()
        if hasattr(algorithm, 'labels_'):
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X)

        plt.subplot(len(datasets), len(clustering_algorithms), plot_num)
        if i_dataset == 0:
            plt.title(name, size=18)

        colors = np.array(list(islice(cycle(['#377eb8', '#ff7f00', '#4daf4a',
                                             '#f781bf', '#a65628', '#984ea3',
                                             '#999999', '#e41a1c', '#dede00']),
                                      int(max(y_pred) + 1))))
        plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])

        plt.xlim(-2.5, 2.5)
        plt.ylim(-2.5, 2.5)
        plt.xticks(())
        plt.yticks(())
        plt.text(.99, .01, ('%.2fs' % (t1 - t0)).lstrip('0'),
                 transform=plt.gca().transAxes, size=15,
                 horizontalalignment='right')
        plot_num += 1

plt.show()
# -
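
# As an aside (a minimal standalone sketch, not part of the original
# example), the symmetrization step above exists because k-nearest-neighbor
# relations are not mutual: `kneighbors_graph` can return an asymmetric
# matrix, and the example averages it with its transpose before passing it
# to the agglomerative models.

# +
import numpy as np
from sklearn.neighbors import kneighbors_graph

pts = np.array([[0.], [1.], [3.]])  # 1's nearest neighbor is 0, but 3's is 1
g = kneighbors_graph(pts, n_neighbors=1, include_self=False)
print((g.toarray() != g.toarray().T).any())          # True: edge 3->1 has no reverse
g_sym = 0.5 * (g + g.T)
print((g_sym.toarray() == g_sym.toarray().T).all())  # True: now symmetric
# -
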
# ## Reference
#
# * [Comparing different clustering algorithms on toy datasets](http://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html)

Machine learning is increasingly applied in fields such as aerial vehicles and robotics, with the goal of using computers to achieve human-like intelligence and thus make such equipment intelligent and autonomous. This course aims to guide students in mastering the basic concepts, typical methods, and techniques of machine learning, to spark interest in the discipline through concrete application cases, and to encourage students to analyze and solve the problems and challenges faced by aerial vehicles and robots from the perspective of artificial intelligence. The main contents include Python programming fundamentals; machine learning models; the basics and implementation of unsupervised learning, supervised learning, and deep learning; and how to apply machine learning to real-world problems, thereby comprehensively improving students' overall abilities.