Deep Learningの性能を見てみよう ~Iris編~

こんにちは、今井です。

前回は広告という「原因」から売上という「結果」に与える影響(因果効果)の推定を行いました。結果に対してその原因が与える因果効果の推定をしてモデルを作り、そのモデルから予測をするのが予測精度を高める最も良い方法にみえます。しかし、一般的に因果効果の推定と予測では目的が異なるため、この方法による予測精度が他の手法と比べて良いとは限りません。もちろんこの方法でも実務的に十分精度の高い予測はできますが、予測の部分をより高精度に行うための方法として機械学習を用いることがあります。

今回は機械学習の中でも 最近流行のDeep Learningの性能を簡単なデータセットを用いて調べてみたという話をします。

Deep Learningとは何かにつきましては、既にネットでも分かりやすい記事がたくさんありますのでそちらをご覧になってください。日本語の資料では、以下の資料がよくまとまっていて分かりやすいと思います。

ディープラーニングチュートリアル

Deep Learningは特徴量も自動に抽出してくれるので画像認識や音声認識のような分野で使われることが多いのですが、特徴量が既に決まっていて特徴量数が少ない場合の性能はどうなんだろうという疑問を持ったので調べてみました。

今回用いたデータは「Iris flower data set」と呼ばれるもので、アヤメの花びらの幅、長さ、がくの幅、長さ、アヤメの種類(setosa,versicolor,virginica)の情報が入っています。この花びらの幅、長さ、がくの幅、長さの4つの特徴量からアヤメの種類を予測して、その精度を比較します。今回はleave-one-out cross-validationという、標本から1つを抜き出してテスト事例とし、残り全てのデータを用いて学習を行った後テスト事例の予測をする、ということを全ての標本がテスト事例となるように繰り返して検証を行います。

Deep Learningの性能比較対象として以下の機械学習の手法を用いました。

  • 決定木
  • ランダムフォレスト
  • Extremely Randomized Trees (ERT)
  • サポートベクターマシーン
  • ニューラルネットワーク
  • ブースティング(adaboost/弱学習器は決定木)
  • バギング(弱学習器は決定木)

Deep Learningを含め、それぞれの手法はハイパーパラメータによって性能が変わりますが、今回は単純にRのパッケージのデフォルト値を用いています。

比較のためのRのコードは以下となります。

#Deep Learningを使うためにH2Oパッケージのインストール
install.packages("h2o", repos=(c("http://s3.amazonaws.com/h2o-release/h2o/master/1542/R", getOption("repos"))))

library(h2o)
localH2O <- h2o.init(ip = "localhost", port = 54321, startH2O = TRUE, nthreads=-1)
irisPath <- system.file("extdata", "iris.csv", package = "h2o")
irisdata <- h2o.importFile(localH2O, path = irisPath)

res.err.dl <- numeric(nrow(irisdata))

for(i in 1:nrow(irisdata)){
  #訓練データとテストデータに分割
  iris.train <- irisdata[-i,]
  iris.test <- irisdata[i,]

  #Deep Learningによる学習と予測
  res.dl <- h2o.deeplearning(x = 1:4, y = 5, data = iris.train, validation = iris.test, activation = "TanhWithDropout")

  #結果の格納-正解なら0, 不正解なら1の値が入る
  res.err.dl[i] <- res.dl@model$valid_class_error
}

#以下のライブラリはcranからインストールしておく
library(rpart)
library(randomForest)
library(extraTrees)
library(kernlab)
library(nnet)
library(adabag)
library(ipred)

iris.data<-read.csv(irisPath, header=F)
res.err.rf <- res.err.svm <- res.err.nn <- res.err.extratree <- res.err.cart <- res.err.adaboost <- res.err.bagging <- numeric(nrow(iris.data))

set.seed(123)
for(i in 1:nrow(iris.data)){
  iris.train <- iris.data[-i,]
  iris.test <- iris.data[i,]

  #決定木による学習と予測、結果の格納
  res.cart <- rpart(V5~.,data = iris.train)
  pred.cart <- predict(res.cart, iris.test[,-5], type = "class")
  res.err.cart[i] <- ifelse(pred.cart==iris.test[,5], 0, 1)

  #ランダムフォレストによる学習と予測、結果の格納
  res.forest <- randomForest(V5~.,data = iris.train)
  pred.forest <- predict(res.forest, iris.test[,-5])
  res.err.rf[i] <- ifelse(pred.forest==iris.test[,5], 0, 1)

  #Extremely Randomized Treesによる学習と予測、結果の格納
  res.extratree <- extraTrees(iris.train[,1:4],iris.train[,5])
  pred.extratree <- predict(res.extratree, iris.test[,1:4])
  res.err.extratree[i] <- ifelse(pred.extratree==iris.test[,5], 0, 1)

  #サポートベクターマシーンによる学習と予測、結果の格納
  res.svm <- ksvm(V5~.,data = iris.train)
  pred.svm <- predict(res.svm, iris.test[,-5])
  res.err.svm[i] <- ifelse(pred.svm==iris.test[,5], 0, 1)

  #ニューラルネットワークによる学習と予測、結果の格納
  res.nn <- nnet(V5~.,data = iris.train, size = 20)
  pred.nn <- predict(res.nn, iris.test[,-5], type = "class")
  res.err.nn[i] <- ifelse(pred.nn==iris.test[,5], 0, 1)

  #ブースティング(adaboost/弱学習器はrpartの決定木)による学習と予測、結果の格納
  res.adaboost <- boosting(V5~.,data = iris.train)
  pred.adaboost <- predict.boosting(res.adaboost, newdata=iris.test)
  res.err.adaboost[i] <- ifelse(pred.adaboost$class==iris.test[,5], 0, 1)

  #バギング(ipredパッケージ/弱学習器はrpartの決定木)による学習と予測、結果の格納
  res.bagging <- ipred::bagging(V5~.,data = iris.train)
  pred.bagging <- predict(res.bagging, iris.test[,-5], type = "class")
  res.err.bagging[i] <- ifelse(as.character(pred.bagging)==as.character(iris.test[,5]), 0, 1)
}

このコードによって得た結果を、各手法ごとにエラー率のプロットしたものが以下の図になります。(Deep Learningの結果はランダム性を含んでいるのでエラー率は毎回若干異なります)

DL_comparison_iris2

 

Deep Learningが他の手法を抑えてエラー率が一番小さい結果となりました。学習器がランダム性を含んでいるものもあるので、乱数を変えて行うと多少違う結果を得ることになりますが、今回のケースでは乱数を変えて行ってもDeep Learningのエラー率が一番小さくなることが多い結果となりました。

また、一般的にデータによってどの手法が良いかも変わってきますが、UCI machine learning repositoryのいくつかのデータで試したところでもDeep Learningが高い性能を示しました。もし興味を持たれましたら、上記のUCI machine learning repositoryのデータを用いてDeep Learningを試してみると面白いと思います。