MachineLearning

【5分でOK】K近傍法(knn法)のアルゴリズムとPythonを使った実装

K近傍法 knn法 アルゴリズム Python

 

 

K近傍法とは?

 

後ほど図解含めて書いていきます!!

 

【K近傍法】Pythonを使った実装

 

K近傍法の概略を掴んだところで、Pythonを使った実装をおこなっていきます。今回使ったのは、irisデータです。

※ブログ記事にするためにPythonファイルで書きましたが、通常はnotebook推奨です。

 

コード全体

import pandas as pd

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier

import matplotlib.pyplot as plt


def get_iris_data():
    iris = load_iris()
    X = pd.DataFrame(iris.data, columns=iris.feature_names)
    y = pd.DataFrame(iris.target, columns=['Species'])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=8)

    return X_train, X_test, y_train, y_test

def get_accuracy_dict(X_train, X_test, y_train, y_test):
    accuracy_dict = {}

    for k in range(1,80):
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(X_train, y_train)
        y_pred = knn.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        accuracy_dict[k] = acc
    print('knn法の最大値:', max(accuracy_dict.values()))

    return accuracy_dict

def plot_accuracy(accuracy_dict):
    x = list(accuracy_dict.keys())
    y = list(accuracy_dict.values())

    plt.plot(x, y)
    plt.show()

if __name__ == "__main__":
    # 学習とテストデータに分割したirisデータを取得
    X_train, X_test, y_train, y_test = get_iris_data()
    # 分割したirisデータからknn法の精度を取得
    accuracy_dict = get_accuracy_dict(X_train, X_test, y_train, y_test)
    # kと精度の関係をグラフ化する
    plot_accuracy(accuracy_dict)

以上がコードの全体像になります。

これだけだと投げやりなので、詳しく解説を入れていきますね。

STEP① : ライブラリのインポート

 

まずは、今回使うPythonライブラリのインポートをおこなっていきます。

import pandas as pd

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier

import matplotlib.pyplot as plt

K近傍法を使うには、`from sklearn.neighbors import KNeighborsClassifier`が必要になってきます。

「学習 → 予測」した後の精度評価には、accuracy_scoreを使っていきます。

 

accuracy_scoreとは、正答率(Accuracy)のことです。つまり、予測データと実際のデータが一致している割合ですね。

あとは、最終的に使う近傍数(=K)と精度の推移をグラフ化するので、matplotlibもインポートしています。

STEP② : irisデータを読み込んで、学習用とテスト用に分割する

 

get_iris_data()で、学習用とテスト用に分割したirisデータを取得します。

def get_iris_data():
    iris = load_iris()
    X = pd.DataFrame(iris.data, columns=iris.feature_names)
    y = pd.DataFrame(iris.target, columns=['Species'])

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=8)

    return X_train, X_test, y_train, y_test

「学習 : テスト = 8 : 2」にしており、random_stateは8で固定しました。

random_stateに数値を渡してあげると、毎回同じ学習結果になります。

逆にrandom_stateを指定しないと、学習データとテストデータの分割がランダムになるので、毎回の学習結果が変わってきます。

 

STEP③ : K近傍法を使って学習&予測する

 

get_accuracy_dict()を使ってあげると、K近傍法を使った学習と精度予測までおこなってくれます。

def get_accuracy_dict(X_train, X_test, y_train, y_test):
    accuracy_dict = {}

    for k in range(1,80):
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(X_train, y_train)
        y_pred = knn.predict(X_test)
        acc = accuracy_score(y_test, y_pred)
        accuracy_dict[k] = acc
    print('knn法の最大値:', max(accuracy_dict.values()))

    return accuracy_dict

少し詳しく書くと、以下のような処理をおこなっています。

  • `knn = KNeighborsClassifier(n_neighbors=k)` : クラス分けするときの近傍に使うデータ数kを指定する
  • `knn.fit(X_train, y_train)` : 指定したkに応じて、学習をおこなう
  • `y_pred = knn.predict(X_test)` : 学習したモデルの予測をおこなう
  • `acc = accuracy_score(y_test, y_pred)` : 正答率を計算する

以上の処理をおこなって、最終的にKごとの精度を格納した辞書を返しています。

つまり、「K=1 ⇒ 精度=0.96、K=2 ⇒ 精度=0.965」のようなデータが入っているということですね。

STEP④ : 近傍数Kに応じた精度を可視化する

 

最終的にplot_accuracy()を使って、近傍数Kと精度のグラフを出力していきます。

def plot_accuracy(accuracy_dict):
    x = list(accuracy_dict.keys())
    y = list(accuracy_dict.values())

    plt.plot(x, y)
    plt.show()

 

この関数を実行すると、以下のようなグラフが出力されるようになっています。

K近傍法

僕の手元で実装したときには、何回か最も良い精度になるKが存在したようです。

またirisデータは、元のデータ数が150しかないので、あまり近傍数Kを大きくすると良い精度が得られないですね。

元のデータ数が150、そのうち学習データは全体の8割なので120。そして、irisデータは3種類のクラス分けになっているので、1つのクラスにつき40ずつになっています。よって、グラフでも分かるとおりで、40くらいから精度が落ちてしまっています。

 

以上がK近傍法の実装でしたが、K=40で精度が落ちる理由も含めて、アルゴリズムの仕組みが説明しやすいですよね。

K近傍法は挙動が分かりやすいアルゴリズムなので、ロジックを説明する必要がある場合には、簡易分析的に使ってみても良いかもしれないですね。

 

K近傍法まとめ

 

というわけで、K近傍法の理論と実装について紹介してきました。

機械学習を勉強している方の参考になれば良いなと。

 

なおK近傍法について、もっと詳しく学習していきたい場合には、以下の書籍がオススメです。

通称「はじパタ」で、機械学習を勉強する人なら、誰もが使う書籍です。

少し難しいですが、しっかりと機械学習を勉強していくなら、読む価値ありですよ(`・ω・´)!

 

というわけで、以上です!