인공지능&머신러닝

[머신러닝]K-최근접 이웃(KNN모델)원리 및 코드 예시

zzheng 2024. 6. 24. 22:54

K-최근접 이웃(K-Nearest Neighbors, KNN) 알고리즘은 가장 간단하고 직관적인 지도 학습 알고리즘 중 하나로, 분류와 회귀 문제 모두에 사용됩니다. 

KNN이란?

KNN은 새로운 데이터 포인트의 클래스를 예측할 때, 그 포인트와 가장 가까운 K개의 데이터 포인트의 클래스를 참고합니다. 여기서 "가까움"은 거리 계산을 통해 측정됩니다. 

거리 계산: 새로운 데이터 포인트와 모든 훈련 데이터 포인트 간의 거리를 계산합니다. 일반적으로 유클리드 거리(Euclidean distance)를 사용하지만, 맨해튼 거리(Manhattan distance) 등 다른 거리 척도도 사용될 수 있습니다.

 

알고리즘

  1. 학습데이터가 주어짐 : 데이터를 클래스별로 저장해 놓음
  2. 분류할 새로운 데이터가 들어옴
    1. 입력 데이터와 가장 가까운 k개의 학습 데이터를 찾음
    2. 찾은 k개의 점이 속한 그룹 중에서 데이터가 가장 많은 그룹을 입력데이터의 그룹으로 정함(k는 항상 홀수, 3-10사이의 값이 최적)

 

KNN의 주요 특징

  1. 비모수적 알고리즘: KNN은 데이터의 분포에 대해 어떠한 가정을 하지 않으며, 모델 학습 단계에서 실제로는 학습을 하지 않고 데이터를 저장만 합니다.
  2. 단순성: 알고리즘이 매우 단순하여 이해하고 구현하기 쉽습니다.
  3. 연산 비용: 새로운 데이터 포인트를 예측할 때 저장된 모든 데이터 포인트와의 거리를 계산해야 하므로, 데이터셋이 클 경우 계산 비용이 큽니다.

 

데이터 정규화 변환

  • Minmax scaling

 

코드 예시

knn_scores = []
n_values = range(1, 21)
for k in n_values:
    knn_classifier = KNeighborsClassifier(n_neighbors= k)
    knn_classifier.fit(x_train_minmax, y_train)
    knn_scores.append(knn_classifier.score(x_test_minmax, y_test))
  • 시각화하기
plt.figure( figsize=(15, 3))
colors = rainbow(np.linspace(0, 1, len(n_values)) )
plt.bar(n_values, knn_scores, color = colors)

for i, n  in zip(range(len(n_values)), n_values):
    plt.text(n, knn_scores[i], knn_scores[i])
plt.xticks([n for n in n_values])
plt.xlabel('Number of Neighbors (K)')
plt.ylabel('Scores')
plt.title('K Neighbors Classifier scores for different K values')
plt.show()

  • 성능점수 출력
max_idx = np.argmax(knn_scores)
print("The score for K Neighbors Classifier is {}% with {} nieghbors.".format(knn_scores[max_idx]*100,
                                                                              n_values[max_idx]) )