読者です 読者をやめる 読者になる 読者になる

Web就活日記

愛と夢と人生について書きます

10秒で設定可能なlibsvmで機械学習を行う

Support Vector Machines (Information Science and Statistics)

Support Vector Machines (Information Science and Statistics)

libsvm

前回RでのSVMを簡単に紹介しましたが、今日はlibsvmを利用したirisの分類学習を行いたいと思います。libsvmは導入がめちゃくちゃ簡単なところが売りだと思います。zipをlibsvmサイトからdownloadして展開してgmakeで設定完了です。

設定
$ wget "http://www.csie.ntu.edu.tw/~cjlin/cgi-bin/libsvm.cgi?+http://www.csie.ntu.edu.tw/~cjlin/libsvm+zip"
$ unzip libsvm-3.12.zip
$ cd libsvm-3.12
$ gmake
$ ls
-rw-r--r-- 1 yuta yuta  1497  1月 31  2012 COPYRIGHT
-rw-r--r-- 1 yuta yuta 72186  4月  1 07:17 FAQ.html
-rw-r--r-- 1 yuta yuta   732  1月  2  2012 Makefile
-rw-r--r-- 1 yuta yuta  1087  9月 12  2010 Makefile.win
-rw-r--r-- 1 yuta yuta 27332  2月  3  2012 README
-rw-r--r-- 1 yuta yuta 27670  7月 12  2003 heart_scale
drwxr-xr-x 3 yuta yuta  4096  4月  1 07:18 java
drwxr-xr-x 2 yuta yuta  4096 10月 30  2011 matlab
drwxr-xr-x 2 yuta yuta  4096  3月 22 12:25 python
-rwxr-xr-x 1 yuta yuta 67413  8月 29 07:53 svm-predict
-rw-r--r-- 1 yuta yuta  5381  2月  5  2011 svm-predict.c
-rwxr-xr-x 1 yuta yuta 15650  8月 29 07:53 svm-scale
-rw-r--r-- 1 yuta yuta  7042  5月 28  2011 svm-scale.c
drwxr-xr-x 5 yuta yuta  4096  2月  3  2012 svm-toy
-rwxr-xr-x 1 yuta yuta 71912  8月 29 07:53 svm-train
-rw-r--r-- 1 yuta yuta  8891  5月 27  2011 svm-train.c
-rw-r--r-- 1 yuta yuta 63412 12月 26  2011 svm.cpp
-rw-r--r-- 1 yuta yuta   434  9月 12  2010 svm.def
-rw-r--r-- 1 yuta yuta  3129  2月  3  2012 svm.h
-rw-r--r-- 1 yuta yuta 93208  8月 29 07:53 svm.o
drwxr-xr-x 2 yuta yuta  4096  2月 24  2012 tools
drwxr-xr-x 2 yuta yuta  4096  3月 16 00:44 windows
scale,train,predictコマンド

gmakeを行うと以下のコマンドが生成されます。これらのコマンドに対して学習データ,学習Model,評価データを与える事によりSVMによる機械学習が実現できます。各種コマンドの使い方については以下のサイトが詳しいと思います。またコマンドに対してaliasを貼っておきます。

  • svm-scale : データの正規化を行うコマンド
  • svm-train : 学習データからModelを生成するコマンド
  • svm-predict:評価データとModelから分類とaccuracyを導きだすコマンド
alias svm-predict=/home/yuta/work/libsvm/libsvm-3.12/svm-predict
alias svm-scale=/home/yuta/work/libsvm/libsvm-3.12/svm-scale
alias svm-train=/home/yuta/work/libsvm/libsvm-3.12/svm-train
irisデータの学習

R言語には標準でirisのデータが備わっていましたが、libsvmのdataformatに従ったサンプルも以下にあります。libsvmに対しての学習にはiris.scaleのデータはiris.scaleを利用します。Iris Setosa、Iris Versicolour 、Iris Virginicaといった3つのアヤメ種別(class)と蕚片の長さ/幅、花びらの長さ/幅といった4つの特徴(attribute)を持つものです。種別に対しては1から3のラベルを貼り、各種特徴にも特徴番号を付け、-1〜1の正規化した特徴量を持たせています。以下にデータ取得から学習までのコマンドを記載します。svm-trainコマンドを実行するとiris.scale.modelというmodelファイルが生成されます。

$ wget "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale"

$ less iris.scale
1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667 
1 1:-0.666667 2:-0.166667 3:-0.864407 4:-0.916667 
1 1:-0.777778 3:-0.898305 4:-0.916667 
1 1:-0.833333 2:-0.0833334 3:-0.830508 4:-0.916667 
1 1:-0.611111 2:0.333333 3:-0.864407 4:-0.916667 
1 1:-0.388889 2:0.583333 3:-0.762712 4:-0.75 
1 1:-0.833333 2:0.166667 3:-0.864407 4:-0.833333 
1 1:-0.611111 2:0.166667 3:-0.830508 4:-0.916667 
1 1:-0.944444 2:-0.25 3:-0.864407 4:-0.916667 
1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 
1 1:-0.388889 2:0.416667 3:-0.830508 4:-0.916667 
1 1:-0.722222 2:0.166667 3:-0.79661 4:-0.916667

$ svm-train iris.scale 
*
optimization finished, #iter = 12
nu = 0.092569
obj = -5.138435, rho = -0.062041
nSV = 11, nBSV = 8
*
optimization finished, #iter = 25
nu = 0.048240
obj = -2.538201, rho = 0.016007
nSV = 8, nBSV = 2
*
optimization finished, #iter = 36
nu = 0.447449
obj = -32.467419, rho = 0.105572
nSV = 46, nBSV = 42
Total nSV = 58

$ ls
drwxr-xr-x 2 yuta yuta 4096  8月 29 09:03 .
drwxr-xr-x 4 yuta yuta 4096  8月 29 08:57 ..
-rw-r--r-- 1 yuta yuta 6954  6月  8  2005 iris.scale
-rw-r--r-- 1 yuta yuta 3226  8月 29 09:03 iris.scale.model
Modelによる評価

予測はsvm-predictコマンドで行います。ここでは学習データで使ったデータをそのまま予測用データとして利用します。svm-predictの第一引数に評価データ、第二引数に学習で生成したmodel、第三引数にラベル付けの結果ファイルを指定します。標準出力でaccuracyが出てくるのでそれをファイルにリダイレクトしておきます。個々の結果ではaccuracyが97%と出ました。

$ svm-predict iris.scale iris.scale.model iris.scale.output > accuracy.txt
$ less accuracy.txt
Accuracy = 97.3333% (146/150) (classification)
K-分割交差検定

上の評価ではaccuracyが97%と出ていますが、学習データと評価データが同じでは本来のModel精度が分からないので、まずは学習データと評価データを分離します。ここでは一般的なCross Validationにしたがって150個のirisサンプルを50個ずつに分割します。50個のデータセットをA,B,Cというファイルで定義した場合、学習と評価を3パターンで試します。交差検定を"K-fold cross-validation"と呼びますが、ここではK=3となります。

  • 学習(A,B) 評価(C)
  • 学習(A,C) 評価(B)
  • 学習(B,C) 評価(A)

行を完全にrandomでshuffleして100行の学習データ/50行の評価データを作成します。

$ perl -MList::Util=shuffle -e 'print shuffle(<>)' < iris.scale | split -l 50
$ cat xaa xab | wc -l
$ cat xaa xab > 1.train.txt
$ cat xaa xac > 2.train.txt
$ cat xab xac > 3.train.txt
$ mv xac 1.predict.txt
$ mv xab 2.predict.txt
$ mv xaa 3.predict.txt

続いて学習と予測、accuracy算出までをやります。下の結果から分かるようにAccuracyの算出が94%、96%、98%といずれも高い数値が出ました。Cross Validationでもかなり高い確率で予測が出来ていると言えます。

$ svm-train 1.train.txt
$ svm-train 2.train.txt
$ svm-train 3.train.txt
*
optimization finished, #iter = 21
nu = 0.065188
obj = -2.237066, rho = 0.057124
nSV = 7, nBSV = 1
*
optimization finished, #iter = 30
nu = 0.558458
obj = -26.839488, rho = -0.084482
nSV = 39, nBSV = 34
*
optimization finished, #iter = 8
nu = 0.119403
obj = -4.403712, rho = 0.039139
nSV = 10, nBSV = 7
Total nSV = 49
*
optimization finished, #iter = 22
nu = 0.070101
obj = -2.437480, rho = 0.018715
nSV = 8, nBSV = 2
*
optimization finished, #iter = 17
nu = 0.443524
obj = -20.533308, rho = -0.040417
nSV = 30, nBSV = 28
*
optimization finished, #iter = 12
nu = 0.128258
obj = -4.819891, rho = -0.074851
nSV = 11, nBSV = 8
Total nSV = 44
*
optimization finished, #iter = 8
nu = 0.129629
obj = -4.619123, rho = 0.001523
nSV = 11, nBSV = 8
*
optimization finished, #iter = 9
nu = 0.067031
obj = -2.349066, rho = 0.175743
nSV = 9, nBSV = 3
*
optimization finished, #iter = 39
nu = 0.506626
obj = -27.356803, rho = 0.120588
nSV = 38, nBSV = 33
Total nSV = 50

$ svm-predict 1.predict.txt 1.train.txt.model 1.output.txt > 1.accuracy.txt
$ svm-predict 2.predict.txt 2.train.txt.model 2.output.txt > 2.accuracy.txt
$ svm-predict 3.predict.txt 3.train.txt.model 3.output.txt > 3.accuracy.txt

$ less 1.accuracy.txt 
Accuracy = 94% (47/50) (classification)
$ less 2.accuracy.txt 
Accuracy = 96% (48/50) (classification)
$ less 3.accuracy.txt 
Accuracy = 98% (49/50) (classification)

スポンサーリンク