Rで単純パーセプトロンを実装してみた
機械学習の勉強を始めました。TJOさんの記事を見る限り、単純パーセプトロンから始めるのがよさそうです。ネットで記事を漁りながら、自分で実装してみることにしました。
単純パーセプトロンの数式的なところは上記のTJOさんの記事と以下のJundollさんの記事を参考にしました。
実装については、SAM猫さんの記事を参考にしました。
Rで単純パーセプトロンを組んでみる - About connecting the dots.
前提
- 識別関数をwxで表しています。
- wは要素を2つ持っており、w1が切片項(x1の係数)、w2がx2の係数です。
- xのうち、x1はw1を切片項にするために要素が全て1のベクトル、x2が説明変数(特徴量)です。
- 以下の関数は説明変数が1つの場合(y = ax + b)のみ対応しており、それ以上の場合の汎用性はありません。
library(stringr) library(tidyverse) simple_perceptron <- function(x, l, w, m) { # plot data xl <- as_tibble(cbind(x, l)) xl_m1 <- xl %>% filter(l == -1) xl_p1 <- xl %>% filter(l == 1) ggp <- ggplot() + geom_point(data = xl_m1, aes(x = x2, y = 0), color = "blue") + geom_point(data = xl_p1, aes(x = x2, y = 0), color = "green") # initialise counter n <- 0 while(n < nrow(x)) { n <- 0 for(i in 1:nrow(x)) { judge <- (l[i] * (w %*% x[i, ]) <= 0) if (judge) { # update weight parameter w <- w + m * x[i, ] * l[i] # add abline ggp <- ggp + geom_abline(intercept = w[1], slope = w[2], color = "grey") } else { n <- n + 1 } } print(str_c("w1 = ", w[1], " w2 = ", w[2])) } # draw graph formula <- tibble( label = str_c("y = ", w[1], " + ", w[2], "x") ) ggp <- ggp + geom_abline(intercept = w[1], slope = w[2], color = "red") + geom_text(data = formula, aes(x = Inf, y = Inf, label = label), vjust = "top", hjust = "right") print(ggp) print(str_c("RESULT : ", "w1 = ", w[1], " w2 = ", w[2])) }
これをもとに、実際に学習した結果が以下です。なお、SAM猫さんの結果と比較するために、インプットを同じにしております。
# input x1 <- c(1, 1, 1, 1, 1, 1) x2 <- c(3, 7, 1, 5, 4, 2) l <- c(-1, 1, -1, 1, 1, -1) w <- c(0, 0) m <- 0.5 # execute x <- cbind(x1, x2) simple_perceptron(x, l, w, m)
得られた識別関数は、切片が-7.5で係数が2.0です。SAM猫さんの結果を再現できました。
追加分析
上記の関数を使って、以下を試してみました。
- x2の要素の順番を入れ替える
- パラメータ更新のロジックを修正する
x2の要素の順番を入れ替える
以下のインプットでもう一度学習しなおしてみます。
# input x1 <- c(1, 1, 1, 1, 1, 1) x2 <- c(1, 2, 3, 4, 5, 7) # 要素の順番を変更 l <- c(-1, -1, -1, 1, 1, 1) # x2の変更に合わせて変更 w <- c(0, 0) m <- 0.5 # execute x <- cbind(x1, x2) simple_perceptron(x, l, w, m)
すると学習結果も以下のように変わります。
この結果から、上記のような単純な学習ロジックでは、使用データの並び順が学習結果に影響を与えることがわかります。
パラメータ更新のロジックを修正する①
次に、simple_perceptron()内のパラメータ更新のロジックを変えてみます。上記のロジックでは、例えば i =3 で誤判定が出た場合に、wを更新し、引き続き残りの i = 4 ~ 6で判定を行い、またi = 1に戻って i = 6まで判定を行います。余分な判定を除くために、一度誤判定が出た場合、wを更新したあとfor文をbreakで抜けて、i = 1 から判定し直すロジックにしてみました。
library(stringr) library(tidyverse) simple_perceptron <- function(x, l, w, m) { # plot data xl <- as_tibble(cbind(x, l)) xl_m1 <- xl %>% filter(l == -1) xl_p1 <- xl %>% filter(l == 1) ggp <- ggplot() + geom_point(data = xl_m1, aes(x = x2, y = 0), color = "blue") + geom_point(data = xl_p1, aes(x = x2, y = 0), color = "green") # initialise counter n <- 0 while(n < nrow(x)) { n <- 0 for(i in 1:nrow(x)) { judge <- (l[i] * (w %*% x[i, ]) <= 0) if (judge) { # update weight parameter w <- w + m * x[i, ] * l[i] # add abline ggp <- ggp + geom_abline(intercept = w[1], slope = w[2], color = "grey") ### escape loop ### break } else { n <- n + 1 } } print(str_c("w1 = ", w[1], " w2 = ", w[2])) } # draw graph formula <- tibble( label = str_c("y = ", w[1], " + ", w[2], "x") ) ggp <- ggp + geom_abline(intercept = w[1], slope = w[2], color = "red") + geom_text(data = formula, aes(x = Inf, y = Inf, label = label), vjust = "top", hjust = "right") print(ggp) print(str_c("RESULT : ", "w1 = ", w[1], " w2 = ", w[2])) }
変更箇所は真ん中あたりのbreakのところのみです。wの更新が終わり次第、最初から識別し直すように変更しています。
この状態で、一番最初と同じインプットで学習してみます。
# input x1 <- c(1, 1, 1, 1, 1, 1) x2 <- c(3, 7, 1, 5, 4, 2) l <- c(-1, 1, -1, 1, 1, -1) w <- c(0, 0) m <- 0.5 # execute x <- cbind(x1, x2) simple_perceptron(x, l, w, m)
結果を見ると、切片が-6.5で係数が2.0です。こちらの結果は上記2ついずれとも異なります。
パラメータ更新のロジックを修正する②
ロジック修正のもう一つのバージョンとして、解く問題自体を変更しました。上記のプログラムはJundollさんの記事で説明されている通り、識別関数を以下のように定義しています。
そして、正しく判定されているときの符号条件が以下のようになっています。
ここで、識別関数の符号を反対にしてみます。(もともとどちらを+1でどちらを-1にするかは任意です。)
この場合、正しく判定されているときの条件は以下のように符号が反転します。
これに併せて、損失関数と解く問題を再定義すると以下のようになります。
損失関数:
解く問題:
損失関数の最小化問題が、損失関数の最大化問題に変わりました。
このときのwの更新方法は以下になります。最大化問題のため、の前の符号が逆転しているのに注意が必要です。簡単な説明をこの記事の最後に残しておきます。
まとめると以下のようになります。
再定義した問題をもとに、simple_perceptron()のロジックを書き直すと以下のようになります。
library(stringr) library(tidyverse) simple_perceptron <- function(x, l, w, m) { # plot data xl <- as_tibble(cbind(x, l)) xl_m1 <- xl %>% filter(l == -1) xl_p1 <- xl %>% filter(l == 1) ggp <- ggplot() + geom_point(data = xl_m1, aes(x = x2, y = 0), color = "blue") + geom_point(data = xl_p1, aes(x = x2, y = 0), color = "green") # initialise counter n <- 0 while(n < nrow(x)) { n <- 0 for(i in 1:nrow(x)) { judge <- (l[i] * (w %*% x[i, ]) >= 0) if (judge) { # update weight parameter w <- w - m * x[i, ] * l[i] # add abline ggp <- ggp + geom_abline(intercept = w[1], slope = w[2], color = "grey") } else { n <- n + 1 } } print(str_c("w1 = ", w[1], " w2 = ", w[2])) } # draw graph formula <- tibble( label = str_c("y = ", w[1], " + ", w[2], "x") ) ggp <- ggp + geom_abline(intercept = w[1], slope = w[2], color = "red") + geom_text(data = formula, aes(x = Inf, y = Inf, label = label), vjust = "top", hjust = "right") print(ggp) print(str_c("RESULT : ", "w1 = ", w[1], " w2 = ", w[2])) }
judgeの判定条件の符号とパラメータ更新式が変わっています。
これをもとに、下記の条件で識別を行ってみます。
# input x1 <- c(1, 1, 1, 1, 1, 1) x2 <- c(3, 7, 1, 5, 4, 2) l <- c(-1, 1, -1, 1, 1, -1) w <- c(0, 0) m <- 0.5 # execute x <- cbind(x1, x2) simple_perceptron(x, l, w, m)
当初の問題定義では切片が-7.5、傾きが2.0でした。今回は切片が7.5、傾きが-2.0になっています。いずれも識別はできていますが、得られる識別関数が異なっております。
追加分析の結果から、学習結果は実装方法に依存することがわかりました。
補足:パラメータ更新の式について
手書きで恐縮ですが、更新式の導出イメージは以下になります。なお、冒頭のTJOさんのスライドp45あたりに同様のことが書かれています。
以上
- 作者: 金明哲
- 出版社/メーカー: 森北出版
- 発売日: 2017/03/25
- メディア: 単行本(ソフトカバー)
- この商品を含むブログを見る