머신러닝

1-4 Iris flowers 데이터 Scikit k-means clustering 비지도학습

coding art 2020. 1. 4. 12:39
728x90




Setosa, Versicolor Virginica 로 구성된 Iris flowers 데이터에 대한 K-means 클러스터링 결과를 비교해 보자. Versicolor Viginica가 겹치는 인접 부위의 Classification 결과에 많은 오차가 있음을 알 수 있다. 두 종류의 데이터가 섞일 정도로 인접하게 되는 영역의 데이터들을 Support Vectors라고 하며 이러한 영역에서 Overfitting을 피하면서 Classification을 효율적으로 수행할 수 있는 SVM(Support Vector Machine) 기법을 참고하도록 한다.


K-means 클러스트링도 Classification 기법이지만 사전 데이터 학습이란 개념이 없이 통계학적인 평균, 분산 및 기하학적인 거리 요소를 사용하기 때문에 위와 같은 결과가 필연적으로 얻어질 수밖에 없다.

 

#Iris_data_kmeans_plot_03.py


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap


# ### Plotting the Iris data
# select setosa and versicolor
df = pd.read_csv('https://archive.ics.uci.edu/ml/'
        'machine-learning-databases/iris/iris.data', header=None)
y = df.iloc[0:150, 4].values
y = np.where(y == 'Iris-setosa', -1, 1)

# extract sepal length and petal length
X = df.iloc[0:150, [0, 2]].values

# plot data
plt.scatter(X[:50, 0], X[:50, 1],
            color='red', marker='o', label='setosa')
plt.scatter(X[50:100, 0], X[50:100, 1],
            color='blue', marker='x', label='versicolor')
plt.scatter(X[100:150, 0], X[100:150, 1],
            color='green', marker='d', label='virginica')

plt.xlabel('sepal length [cm]')
plt.ylabel('petal length [cm]')
plt.legend(loc='upper left')
plt.show()


#Kmeans processing
from sklearn.cluster import KMeans

km = KMeans(n_clusters=3, init='k-means++', n_init=10, max_iter=300,
            tol=1e-04,random_state=0)

y_km = km.fit_predict(X)

plt.scatter(X[y_km == 0, 0],X[y_km == 0, 1], s=50, c='lightgreen',
            marker='s',edgecolor='black', label='cluster 1')
plt.scatter(X[y_km == 1, 0], X[y_km == 1, 1], s=50, c='orange',
            marker='o',edgecolor='black',label='cluster 2')
plt.scatter(X[y_km == 2, 0], X[y_km == 2, 1],s=50, c='lightblue',
            marker='v', edgecolor='black',label='cluster 3')
plt.scatter(km.cluster_centers_[:, 0], km.cluster_centers_[:, 1],
            s=250, marker='*', c='red', edgecolor='black',
            label='centroids')
plt.xlabel('sepal length [cm]')
plt.ylabel('petal length [cm]')
plt.legend(loc='upper left')