데이터 분석/머신러닝 독학하기

머신러닝 - MNIST 알아보기 (분류) - Python

Jerry Jun 2020. 11. 23. 21:49
728x90

머신러닝을 공부하는 사람이라면 한 번쯤 경험해봤거나 들어봤을 데이터셋이다. MNIST!

 

이 데이터 셋은 미국 고등학생들과 인구조사국 직원분들이 손으로 직접 쓴 숫자 데이터셋이다.

 

시작해보자.

 

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
mnist.keys()
---------------------------------------------
dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])

사이킷런을 통해서 mnist 데이터셋을 받아오는 과정이다.

데이터 내에 어떤 데이터가 있는지 keys() 로 확인해보았다.

 

 

 

1. 주요 데이터 확인하기

X, y = mnist['data'], mnist["target"]
print(X.shape, y.shape)
-------------------------------------
(70000, 784) (70000,)

data : 70000개의 이미지 데이터 / 데이터마다 784개의 feature 확인

target : 이미지 데이터가 어떤 숫자인지 알려줌.

 

 

 

import matplotlib.pyplot as plt

digit = X[0]
digit_image = digit.reshape(28, 28)

print(y[0])
plt.imshow(digit_image, cmap="binary")
plt.axis("off")
plt.show()
-------------------------------------
5

mnist['data'] 의 첫번째 데이터를 불러와서 사진을 출력했다. 5 처럼 보이는데 y 데이터를 보니 5로 나온다.

 

 

 

2. 훈련 / 테스트 데이터 만들기.

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

train 데이터에 60000개를 넣고 test 데이터에 10000개를 넣었습니다.

 

 

* MNIST 이진 분류기

y_train_5 = (y_train == 5)
y_test = (y_test == 5)

from sklearn.linear_model import SGDClassifier

sqd_clf = SGDClassifier(random_state=30)
sqd_clf.fit(X_train, y_train_5)

단순하게 숫자 5를 판단할 수 있는 훈련입니다. 5이면 true, 이외는 false 가 되는 것입니다.

훈련에서 사용한 모델은 SGDClassifier 이다.

 

SGDClassifier 는 "확률적 경사 하강법" 으로 큰 데이터셋을 가지고 있을 때 장점을 가지고 있다.

훈련 데이터를 하나씩 독립적으로 처리하기 때문이다.

 

코드를 보면 random_state 가 있는데 매개변수로 아무 숫자라도 상관이 없다.

 

sqd_clf.predict([digit])
-------------------------
array([False])

predict 는 훈련시킨 것에 대해 데이터를 넣어 예측하는 기능을 한다.

훈련을 한 sqd_clf 에 digit 를 넣었는데 False 가 나왔다. 정확성이 아직 뛰어나지 않나보다.

 

 


# 모델 성능 측정 방법 (1) - 교차 검증

from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone

skfolds = StratifiedKFold(n_splits = 3, random_state = 30, shuffle = True)

for train_index, test_index in skfolds.split(X_train, y_train_5):
    clone_clf = clone(sqd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = y_train_5[train_index]
    X_test_fold = X_train[test_index]
    y_test_fold = y_train_5[test_index]
    
    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    n_correct = sum(y_pred == y_test_fold)
    print(n_correct / len(y_pred))
-----------------------------------------------------------------
0.9659
0.95205
0.96675

모델 성능 측정 방법에는 많은 종류가 있는데 지금은 "교차 검증" 을 알아보았다.

교차 검증에 사용되는 StratifiedKFold 는 클래스별로 비율이 유지되도록 폴드를 만들기 위해 계층적 샘플링을 수행합니다. 여기에서 폴드란 비슷한 크기의 부분집합을 뜻합니다. for 문 반복문을 돌 때마다 객체를 복제하여 훈련 폴드로 훈련시키고 테스트 폴드로 예측을 실행합니다. 그 후, 올바른 예측의 수를 세어서 비율을 출력하였습니다.

 

비슷한 기능으로 coss_val_score() 함수를 이용해보겠습니다.

 

from sklearn.model_selection import cross_val_score

cross_val_score(sqd_clf, X_train, y_train_5, cv = 3, scoring = "accuracy")
--------------------------------------------------------
array([0.9577 , 0.95815, 0.9622 ])

모든 교차 검증 폴드에 대한 정확도가 95%를 넘은 것이 나타났습니다.

이 코드에서 cv 는 폴드의 수를 나타냅니다. 

이 코드에서 scoring 은 평가 지표를 설정하는 것이며 "accuracy" 는 정확도를 의미합니다.

 


# 모델 성능 측정 방법 (2) - 오차 행렬

오차 행렬은 훈련된 샘플에 대해 비교할 다른 샘플을 만들고 서로 비교하는 것입니다.

오차 행렬을 만들기 위해서는 실제 타겟과 비교할 수 있도록 예측값을 생성해야 합니다.

 

from sklearn.model_selection import cross_val_predict

y_train_pred = cross_val_predict(sqd_clf, X_train, y_train_5, cv = 3)
y_train_pred
-------------------------------------------------------------
array([ True, False, False, ...,  True, False, False])

이번 경우는 cross_val_predict() 함수를 사용하여 샘플을 만들었습니다.

아직까지 오차 행렬은 만들어지지 않았습니다.

 

from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5, y_train_pred)
-----------------------------------------------
array([[54023,   556],
       [ 1883,  3538]], dtype=int64)

이제 오차행렬이 만들어졌습니다.

array 에서 행은 실제 클래스를 나타내고 열은 예측한 클래스를 나타냅니다.

해석하면 첫번째 행은 '5가 아닌 이미지' 에 대한 것으로 54,023 개의 이미지를 5가 아니라고 분류(True Negative)하였고, 556 개를 5 라고 분류(False Positive)하였습니다. 

두번째 행은 '5인 이미지' 에 대한 것으로 1,883 개를 5 아님으로 잘못 분류(False Negative)하였고, 3,538 개를 정확히 5라고 분류(True Positive)한 경우입니다. 만약 오차 행렬이 아닌 완벽한 분류기라면 556과 1883의 위치는 0이 될 것입니다.

 

 

y_train_perfect_predictions = y_train_5
confusion_matrix(y_train_5, y_train_perfect_predictions)
--------------------------------------------
array([[54579,     0],
       [    0,  5421]], dtype=int64)

이것이 완벽한 분류기의 형태입니다.

여기에서 알아야 할 것은 분류기의 정밀도(precision) 입니다.

정밀도는 TP / (TP + FP) 로 계산하며, 재현율이라는 지표와 함께 자주 사용합니다.

 

재현율은 분류기가 정확하게 감지한 양성 샘플의 비율로 민감도(sensitivity) 또는 진짜 양성 비율(True positive rate) 라고도 합니다. 재현율은 TP / (TP + FN) 으로 계산합니다. 

 

 

사이킷런에서는 정밀도와 재현율을 계산하는 여러 함수를 지원합니다.

from sklearn.metrics import precision_score, recall_score
print(precision_score(y_train_5, y_train_pred)) # 3538 / (3538 + 556)
print(recall_score(y_train_5, y_train_pred))    # 3538 / (3538 + 1883)
--------------------------------------------------
0.8641914997557401
0.6526471130787678

precision_score( ) 함수로 정밀도를 계산하였고 recall_score( ) 함수로 재현율을 계산하였습니다.

여기에서는 정확도가 86.4% 로 나름 높게 나왔습니다. 하지만 전체 5인 이미지에서 65.2%만 인식했다고 나오네요.

 

일반적으로 정밀도와 재현율을 하나로 합쳐 F 점수(F score) 또는 조화 평균으로 쓰이기도 합니다. 

F = TP / ( TP + ( FN + FP) / 2 ) 로 계산하며 사이킷런에서는 f1_score 함수를 사용하여 계산합니다.

 

 

from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_pred)
---------------------------------
0.7436678928008408

f1_score 를 통해 나오는 점수는 정밀도와 재현율이 비슷한 정도에 비례합니다. 정밀도와 재현율이 비슷하면 그만큼 f1_score 점수가 커집니다. 하지만 항상 이 점수에 의존하면 안됩니다. 정밀도 점수가 중요한 경우가 있고 재현율 경우가 중요한 경우가 있기 때문입니다.

 

정밀도를 높이고 싶으면 재현율이 내려가고... 재현율을 높이고 싶으면 정밀도가 내려가는 현상을 정밀도/재현율 트레이드오프 라고 합니다. SGDClassifier 가 분류를 하는데 여러 임계 값이 있고 기준에 따라 양성, 음성으로 바뀝니다. 파이썬을 통해 여러 임계값을 보이려면 matplotlib 을 이용하면 됩니다.

 

y_scores = cross_val_predict(sqd_clf, X_train, y_train_5, cv = 3, method = 'decision_function')

from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "y-", label = "정밀도")
    plt.plot(thresholds, recalls[:-1], "g--", label = "재현율")
    
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.legend()
plt.show()

chart

정밀도에 옥에티가 있지만 이런 곡선이 나옵니다. 정밀도가 올라가면 재현율이 낮아지고...

이러한 트레이드오프를 볼 수 있습니다. 다음 타임에는 다중 분류를 알아보겠습니다.

300x250