読者です 読者をやめる 読者になる 読者になる

Webデータレポート

株式会社ルーターのデータレポートブログです

RとH2Oを使って Deep Learning でキャラクター判定をやらせてみた

何かとDeepLearningがブームになっています。うちでもやってみました。

学習させること

画像認識させてみたいと思います。手始めに、ミッフィー、キティ、マイメロのキャラクターを判別させてみます。人間が見てもおばあちゃんくらいになると厳しいこの3者が区別できるのか?

結論

3つのキャラのうちどれ?って判定をさせると、6割くらいまでの精度にはなりました。普通は背景を取り除いて、キャラの顔だけを切り抜いてから分類させるとは思うのですが、背景込でキャラ判別させた割にはまぁまぁの精度かなと思います。キャラの顔だけをくり抜くことをさせるには、OpenCVとかを使うはずですが、OpenCVをいじりまくるのはまた今度ということで。

以下の図は、活性化関数の種類別に、横軸が繰り返し回数、縦軸が正答率です。20回のRectifierが一番正答率が高いという感じですね。キティは9割がた当たるのですが、ミッフィーマイメロは区別しにくいようです。単に輪郭を区別しただけという気もしますが、一次元のベクトルから二次元の形を当てるというのもまた一興。

f:id:webdatareport:20151229153030p:plain

ただ、画像認識でやりたいことがはっきりしている場合には、OpenCVで行けるところまでいくのが定石。

以下は、手順のメモです。

用意するもの

DeepLearning専用のフレームワークを使わず、ここは手軽にRとH2Oを使ってみます。Rであれば、統計な人たちにも馴染み深いですね。

通常はRはRstudioで手元のPCで使うことが多いかと思いますが、機械学習には時間がかかるので、Rをクラウドサーバに入れたいと思います。クラウドサーバだと、特定の時間だけ、サーバースペックをあげることができるので便利ですね。

centOSにRを入れる

wget "http://cran.ism.ac.jp/src/base/R-3/R-3.2.2.tar.gz"
tar xvzf R-3.2.2.tar.gz
cd R-3.2.2
yum install gcc-gfortran libgfortran
./configure --with-readline=no --with-x=no
yum install java-1.7.0-openjdk-devel
make
make install

centOSの RにH2Oを入れる

> yum install libcurl-devel #パッケージインストールするのに必要
>R
>>install.packages("h2o")
接続先を選択する画面
HTTPS CRAN mirror

1: 0-Cloud [https] 2: Austria [https]
3: China (Beijing 4) [https] 4: China (Hefei) [https]
5: Colombia (Cali) [https] 6: France (Lyon 2) [https]
7: Iceland [https] 8: Russia (Moscow 1) [https]
9: Switzerland [https] 10: UK (Bristol) [https]
11: UK (Cambridge) [https] 12: USA (CA 1) [https]
13: USA (KS) [https] 14: USA (MI 1) [https]
15: USA (TN) [https] 16: USA (TX) [https]
17: USA (WA) [https] 18: (HTTP mirrors) ←これ選択

新たに選択肢が増えるので48:Japan(Tokyo)を選択

ダウンロード完了

用意する画像

画像検索を使って、ミッフィー、キティ、マイメロの画像をそれぞれ落としてきます。イラストのこともあれば、グッズの写真ってこともあります。そういうのも一旦お構いなしに、インプットとしてみます。

実際に学習するのは、一次元のベクトル

画像を学習させるとは言うものの、機械学習にかけるのは、1次元のベクトルです。画像をグレースケールにして(この時点でマイメロの色がなくなり輪郭がミッフィーとおなじになる)、100×100のピクセルを1万の一次元の要素にしちゃいます。ホントは色付きでもっとたくさんのピクセル数でやりたいところですが、めちゃくちゃ時間がかかるので、この辺で許さないと厳しかったです。

こんなcsvができあがります。 f:id:webdatareport:20151229145947p:plain

一番左の列がラベルで、ここにキティかミッフィーマイメロが入ります。それ移行の列は1万個のピクセルの値ですね。0~255までのグレースケールの値が入ります。

このデータを、学習用と判定用の両方作ります。

h2oのパラメータ

以下のパラメータを指定しました。

パラメータ 説明 今回のモデル
x モデルの文字列変数を含むベクトル
100*100ピクセルだったら2列目から1001列目
2:10001
y モデル内の応答変数の名前。ラベルが付いている列(たぶん) 1
training_frame 学習用データフレーム
activation 活性化関数を選択
Tanh
TanhWithDropout
Rectifier
RectifierWithDropout
Maxout
MaxoutWithDropout
Tanh
TanhWithDropout
Rectifier
RectifierWithDropout
Maxout
MaxoutWithDropout|Tanh
TanhWithDropout
Rectifier
RectifierWithDropout
hidden 隠れレイヤーのユニット数をベクトルで並べる
入力次元の1/10 or 1/100ユニットが普通
例:rep(20,5) 20ユニット×5層
rep(1000,4)
epochs 繰り返し回数
デフォルトだと1回
10~40

グレースケールに変換してベクトルに変換するRスクリプト

ImageMagicとRubyとかの組み合わせでもできちゃうけど、Rでもここまでできるってことで。

library(raster)

#グレースケール変換関数
as.grayscale.array<-function(file,flag){
  image<-brick(file)
  image.rgb<-getValues(image)
  rgbcol<-ncol(image.rgb)
  #元画像がカラーであればRGB値をグレースケールに変換
  if (rgbcol>1){
    image.bw<-image.rgb[,1]*0.21+image.rgb[,2]*0.72+image.rgb[,3]*0.07
    return (c(flag,image.bw))  
  #元画像が白黒であればそのまま    
  }else{
    return (c(flag,image.rgb))
  }  
}
#各画像のパスをキャラごとに格納
kity <- list.files("C:/rooter/pic_resize100/",pattern = "kity",full.names=T)
miffy <- list.files("C:/rooter/pic_resize100/",pattern = "miffy",full.names=T)
mymelo <- list.files("C:/rooter/pic_resize100/",pattern = "mymelo",full.names=T)

#評価用データ
pre_kity <- list.files("C:/rooter/pic_eva/",pattern="kity",full.names=T)
pre_miffy <- list.files("C:/rooter/pic_eva/",pattern="miffy",full.names=T)
pre_mymelo <- list.files("C:/rooter/pic_eva/",pattern="mymelo",full.names=T)

#3つのキャラクターを1つのベクトルに結合
res.dl<-rbind(t(sapply(kity, as.grayscale.array , "kity")),
              t(sapply(miffy, as.grayscale.array , "miffy")),
              t(sapply(mymelo, as.grayscale.array , "mymelo")) )

#評価用データのベクトル生成
pred.dl<-rbind(t(sapply(pre_kity, as.grayscale.array , "kity")),
              t(sapply(pre_miffy, as.grayscale.array , "miffy")),
              t(sapply(pre_mymelo, as.grayscale.array , "mymelo")) )

#学習用・評価用データをデータフレーム化
res.dl<-data.frame(res.dl)
pred.dl<-data.frame(pred.dl)

#出力
write.csv(res.dl,"C:/rooter/training_data.csv",row.names=FALSE, quote=TRUE)
write.csv(pred.dl,"C:/rooter/prediction_data.csv",row.names=FALSE, quote=TRUE)

学習と評価のRスクリプト

library(h2o)

#あらかじめJava SE 7をダウンロードしておく
localH2O <- h2o.init(ip="localhost" , nthreads = -1)

#データファイルをH2Oクラスに変換
#学習用
target<-read.csv("training_data.csv")
target<-data.frame(target)

act<-"TanhWithDropout"
epo<-40

start<-proc.time()

#学習
res.dl<-h2o.deeplearning(x = 2:10001,
                         y = 1,
                         training_frame = as.h2o(target),
                         activation = act,
                         hidden = rep(1000,4),
                         epochs = epo,
)

end<-proc.time()

#評価用
pred.dt<-read.csv("prediction_data.csv")

#予測
prediction<-h2o.predict(object = res.dl , newdata = as.h2o(pred.dt))

#精度を算出
pred<-as.data.frame(prediction)
result.ar<-t(rbind(pred.dt[,1],pred[,1]))


#正答率を計算
kity.num<-0
miffy.num<-0
mymelo.num<-0
num<-0
for(i in 1:nrow(result.ar)){
  if(result.ar[i,1]==result.ar[i,2]){
    if(result.ar[i,1]==1){
        kity.num<-kity.num+1
    }
    if(result.ar[i,1]==2){
        miffy.num<-miffy.num+1
    }
    if(result.ar[i,1]==3){
        mymelo.num<-mymelo.num+1
    }
    num<-num+1
  }
}

kity.res<-format(kity.num/sum(result.ar[,1]==1), nsmall=5)
miffy.res<-format(miffy.num/sum(result.ar[,1]==2), nsmall=5)
mymelo.res<-format(mymelo.num/sum(result.ar[,1]==3), nsmall=5)

res<-format(num/nrow(result.ar),nsmall=5)

#出力
pred.ar<-data.frame(chara=c("kity","miffy","mymelo","all","time"), pre=c(kity.res,miffy.res,mymelo.res,res,as.character(end-start)[3]))
write.table(pred.ar,paste("log/",act,"/epo",epo,".txt",sep=""))
write.csv(pred,paste("log/",act,"/epo",epo,".csv",sep=""))