머신러닝

1-2 Scikit-learn 에 의한 k-means++ clustering 비지도학습

coding art 2020. 1. 3. 20:40
728x90


고전적인 K-means 기법에 의해서 int=’random’ 조건을 사용하면서 클러스터링 수의 초기 값을 잘못 줄 경우 예기치 않은 결과가 얻어질 수 있는 사례이다. 앞서의 중심이 3개인 클러스터링 예제에서 클러스터링 수를 12로 계산한 결과를 관찰해 보자. 클러스터의 중심을 1 즉 하나로 두면 전체의 중심이 포착되며 2일 때에는 하나는 제대로 중심을 찾았으나 나머지는 전체의 중심점을 찾게 된다.


이와 같은 클러스터링 과정에서 초기 seed 점을 선정하는 방법에 random k-means++ 이 있다. random 은 말 그대로 데이터 포인트들 중에서 무작위로 선정하여 중심을 찾는 방법이다. 하지만 랜덤하게 초기 seed 를 잡아서 실패하는 사례가 있을 수 있어 좀 더 개선된 알고리듬이 k-means++ 이다. K-means++ 은 한꺼번에 무작위로 중심점들을 다 잡는 것이 아니라 일단 랜덤하게 첫 번째 중심점을 하나 설정 한 다음 나머지 점들과의 거리를 계산하여 가장 멀리 있는 점을 두 번째 중심점 후보로 잡는다. 그 다음은 이미 설정된 두 점에서 가장 먼점을 설정하는 이러한 방식으로 지정된 수만큼의 중심점들을 설정 한 후에 K-means 알고리듬이 적용되면 실패 가능성이 최소한으로 줄어들게 된다.

KMeans 루틴에서 init=’k-means++’ 로 두면 되고 결과는 init=’random’ 조건 사용 시와 거의 동일하다.

 

#k-means_data_plot_01.py

from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import time

start_time = time.time()
## Grouping objects by similarity using k-means
## K-means clustering using scikit-learn

X, y = make_blobs(n_samples=150,
                  n_features=2,
                  centers=3,
                  cluster_std=0.5,
                  shuffle=True,
                  random_state=0)

plt.scatter(X[:, 0], X[:, 1],
            c='white', marker='o',
            edgecolor='black', s=100)
plt.grid()
plt.tight_layout()
plt.show()

km = KMeans(n_clusters=3, init='k-means++', n_init=10, max_iter=300,
            tol=1e-04,random_state=0)

y_km = km.fit_predict(X)

plt.scatter(X[y_km == 0, 0],X[y_km == 0, 1], s=50, c='lightgreen',
            marker='s',edgecolor='black', label='cluster 1')
plt.scatter(X[y_km == 1, 0], X[y_km == 1, 1], s=50, c='orange',
            marker='o',edgecolor='black',label='cluster 2')
plt.scatter(X[y_km == 2, 0], X[y_km == 2, 1],s=50, c='lightblue',
            marker='v', edgecolor='black',label='cluster 3')
plt.scatter(km.cluster_centers_[:, 0], km.cluster_centers_[:, 1],
            s=250, marker='*', c='red', edgecolor='black',
            label='centroids')

plt.legend(scatterpoints=1)
plt.grid()
plt.tight_layout()
plt.show()

end_time = time.time()
print( "Completed in ", end_time - start_time , " seconds")