最简单的K-means++算法原理和实践教程
发布日期:2021-06-29 16:00:31 浏览次数:2 分类:技术文章

本文共 3867 字,大约阅读时间需要 12 分钟。

在之前的最后我提到了这样的一个问题,你可以通过一些实验发现,K-means算法的最后聚类结果和初始化k个中心的位置有着极大的关系。

而我们在前文中提到过一些不同的初始化方法(前文中使用的是第一种初始化方法)。我们这里的K-mean++算法使用的初始化方法,实际是第三种:

  • 随机地选择第一个点,或取所有点的质心作为第一个点。然后,对于每个后继初始质心,选择离已经选取过的初始质心最远的点。使用这种方法,确保了选择的初始质心不仅是随机的,而且是散开的。但是,这种方法可能选中离群点。此外,求离当前初始质心集最远的点开销也非常大。为了克服这个问题,通常该方法用于点样本。由于离群点很少(多了就不是离群点了),它们多半不会在随机样本中出现。计算量也大幅减少。

我们知道之前的K-means算法思路是这样子的:

  1. 选取k个初始中心点 C=c1,...,ck .
  2. 对于每一个 i1,...,k , 将 Ci 设置为 X 中比所有
    ji
    都靠近的点 cj 的集合.
  3. 对于每一个 i1,...,k , 将 ci 设置为 Ci 中所有点的质心: ci=1|Ci|xCix
  4. 重复(2)(3),直到所有 C 值的变化小于给定阈值或者达到最大迭代次数。

现在的K-means++算法思路是这样子的

  1. X
    中随机选取一个中心点 c1 .
  2. 计算 X 中的每一个样本点与
    c1
    之间的距离,通过计算概率 D(x)2xXD(x)2 ( D(x) 表示每个样本点到最近中心的距离), 选出概率最大的值对应的点作为下一个中心 ci=xX
  3. 重复步骤(2),直到我们选择了所有k个中心
  4. 对k个初始化的中心,利用K-means算法计算最终的中心。

沿着上述算法思路,我们可以很快的给出对应的code

'''_k_init(data, k, x_squared_norms, random_state, n_local_trials=None)作用:根据k-means++初始化质心data : 输入数据k : 中心数x_squared_norms : 每个数据点的2范数的平方random_state : 随机数生成器,用于初始化中心n_local_trials :通过一种特别的方式对K-means聚类选择初始簇中心,从而加快收敛速度'''      def _k_init(self, x_squared_norms, random_state, n_local_trials=None):    n_samples, n_features = self.data.shape    centers = np.empty((self.k, n_features), dtype=self.data.dtype)    assert x_squared_norms is not None, 'x_squared_norms None in _k_init'    if n_local_trials is None:        # This is what Arthur/Vassilvitskii tried, but did not report        # specific results for other than mentioning in the conclusion        # that it helped.        n_local_trials = 2 + int(np.log(self.k))    # 随机的选择第一个中心    center_id = random_state.randint(n_samples)    if sp.issparse(self.data):        centers[0] = self.data[center_id].toarray()    else:        centers[0] = self.data[center_id]    # 初始化最近距离的列表,并计算当前概率    closest_dist_sq = euclidean_distances(        centers[0, np.newaxis], self.data, Y_norm_squared=x_squared_norms,        squared=True)#计算X与中心的距离的平方得到距离矩阵    current_pot = closest_dist_sq.sum()#距离矩阵的和    # 选择其余n_clusters-1点    for c in range(1, self.k):        # 通过概率的比例选择中心点候选点        # 离已经存在的中心最近的距离的平方        rand_vals = random_state.random_sample(n_local_trials) * current_pot        #将rand_vals插入原有序数组距离矩阵的累积求和矩阵中,并返回插入元素的索引值        candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq),                                        rand_vals)        # 计算离中心候选点的距离        distance_to_candidates = euclidean_distances(            self.data[candidate_ids], self.data, Y_norm_squared=x_squared_norms, squared=True)        # 决定哪个中心候选点是最好        best_candidate = None        best_pot = None        best_dist_sq = None        for trial in range(n_local_trials):            # Compute potential when including center candidate            new_dist_sq = np.minimum(closest_dist_sq,                                     distance_to_candidates[trial])            new_pot = new_dist_sq.sum()            # 如果是到目前为止最好的实验结果则存储该结果            if (best_candidate is None) or (new_pot < best_pot):                best_candidate = candidate_ids[trial]                best_pot = new_pot                best_dist_sq = new_dist_sq        # Permanently add best center candidate found in local tries        if sp.issparse(self.data):            centers[c] = self.data[best_candidate].toarray()        else:            centers[c] = self.data[best_candidate]        current_pot = best_pot        closest_dist_sq = best_dist_sq    return centers

以上代码作为K-means初始化部分,选用的是scikit-learn中的做法。

所以如果我们要使用K-means算法就很简单了,只要安装了scikit-learn,通过下面的代码就可以解决了

from sklearn.cluster import KMeansimport numpy as npimport matplotlib.pyplot as pltX = np.array([[1,3], [4, 3], [2, 4], [3, 1]])kmeans = KMeans(n_clusters=2, init='k-means++').fit_predict(X)plt.scatter(X[:, 0], X[:, 1], c=kmeans, s=100)plt.show()print(kmeans)

reference:

Arthur D, Vassilvitskii S. k-means++:the advantages of careful seeding[C]// Eighteenth Acm-Siam Symposium on Discrete Algorithms. Society for Industrial and Applied Mathematics, 2007:1027-1035.

转载地址:https://coordinate.blog.csdn.net/article/details/79804871 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:Iterative Visual Reasoning Beyond Convolutions论文笔记
下一篇:最简单的K-means算法原理和实践教程

发表评论

最新留言

很好
[***.229.124.182]2024年04月19日 10时33分50秒