Rによるknn法の実装方法を調べてみた
caret::knn3を使ってkNN法の勉強をしていたときに、どのように実装しているのか気になりました。methodsやgetAnywhereで探してもメインのソースコードが見あたらないため、とりあえずネットで他の実装例を探しておおまかな理解で済まそうと思います。
「R knn source code」などでググると色々でてきますが、私は以下のコードでkNN法のロジックを確認しました。わかりやすさ重視のためかなり遅い実装です。
benchR/knn.R at master · allr/benchR · GitHub
リンクの実装方法とcaret::knn3を比較しました。なお、結果が比較できればよいのでtraindata = testdataにしています。
# compare knn3 with linked implementation. library(caret) # test with a small subset of iris set.seed(2019) td = iris[sample(1:nrow(iris), 20),] pred = myknn("Species", traindata = td, testdata = td, k = 3) print(pred) print(table(pred, td$Species)) # caret::knn3 knn3_fit <- knn3(Species ~ ., data = td, k = 3) pred_knn3 <- predict(knn3_fit, td, type = "class") confusionMatrix(data = pred_knn3, reference = td$Species)$table
実行結果は一致します。
> print(table(pred, td$Species)) pred setosa versicolor virginica setosa 9 0 0 versicolor 0 6 1 virginica 0 0 4 > confusionMatrix(data = pred_knn3, reference = td$Species)$table Reference Prediction setosa versicolor virginica setosa 9 0 0 versicolor 0 6 1 virginica 0 0 4
ちなみにset.seed()の値を変えても両者の結果は一致します。
# compare knn3 with linked implementation. library(caret) # test with a small subset of iris set.seed(2020) td = iris[sample(1:nrow(iris), 20),] pred = myknn("Species", traindata = td, testdata = td, k = 3) print(pred) print(table(pred, td$Species)) # caret::knn3 knn3_fit <- knn3(Species ~ ., data = td, k = 3) pred_knn3 <- predict(knn3_fit, td, type = "class") confusionMatrix(data = pred_knn3, reference = td$Species)$table
> print(table(pred, td$Species)) pred setosa versicolor virginica setosa 5 0 0 versicolor 0 10 0 virginica 0 0 5 > confusionMatrix(data = pred_knn3, reference = td$Species)$table Reference Prediction setosa versicolor virginica setosa 5 0 0 versicolor 0 10 0 virginica 0 0 5
以上