R言語でSVM(Support Vector Machine)による分類学習
- 作者: ネロクリスティアニーニ,ジョンショー‐テイラー,Nello Cristianini,John Shawe‐Taylor,大北剛
- 出版社/メーカー: 共立出版
- 発売日: 2005/03
- メディア: 単行本
- 購入: 8人 クリック: 135回
- この商品を含むブログ (41件) を見る
SVMとは
Support Vector Machineの略で教師あり学習に分類されます。線形、非線形の識別関数があり現在知られている多くの学習モデルの中では最も優れた識別能力があるとされています。いわゆる2値分類を解くための学習モデルであり、線形しきい素子を用いて分類器を構成します。訓練データにおける各データ点と距離が最大になるマージン最大化という基準で線形しきい素子のパラメータを学習させます。シンプルな例は与えられたデータ集合を全て線形に分離する事です。SVMはカーネルトリックという非線形の分離も可能としており、この部分でも優れた性能を発揮する事が分かっています。この記事ではR言語に備わっているデータを利用してSVMによる分類学習を行います。途中でNeuralNetwork、NaiveBayesとの比較も簡単に行います。
R言語でSVM
設定
R言語でSVMを利用するにはkernlabというパッケージを必要とします。最初にinstallします。またlibrary関数でkernlabを読み込みます。
$ sudo R > install.packages( "kernlab" ) > library( kernlab )学習データ/予測データを作成
R言語に標準で入ったテストデータにIrisというものがあります。Irisを辞書で調べてみると以下のようにアヤメのことを差しています。「 アヤメ科アヤメ属の単子葉植物の総称。アヤメ・ハナショウブ・カキツバタなど。一般にはジャーマンアイリス・ダッチアイリスなどの園芸種をいう。」(Yahoo!辞書から引用)。Irisデータは4行のデータであり、蕚片の長さ/幅、花びらの長さ/幅で種別を定義しているデータです。ここでは蕚片の長さ/幅と花びらの長さ/幅を説明変数、種別を目的変数と呼ぶ事にします。トレーニングデータの説明変数を学習させ、評価データの説明変数から目的変数がどれに分類されるかを評価します。まずはIrisデータの50%をトレーニングデータ、残りの50%を予測を行うデータに分類します。Irisのデータは150行のデータなのでそれぞれ75行分のデータが格納されます。
#irisのデータの行数を取得 > rowdata<-nrow(iris) #行数からランダムに行番号を抽出 > random_ids<-sample(rowdata,rowdata*0.5) > random_ids [1] 148 35 114 6 26 129 58 92 20 138 147 107 110 41 88 11 137 52 142 [20] 17 38 55 139 132 21 8 4 49 125 12 84 77 101 122 40 1 25 37 [39] 87 83 61 111 18 5 7 113 56 93 109 3 74 82 134 118 33 42 130 [58] 76 70 103 136 116 106 65 19 16 30 75 143 54 98 60 121 45 94 #学習データを作成 > iris_training<-iris[random_ids, ] > iris_training Sepal.Length Sepal.Width Petal.Length Petal.Width Species 148 6.5 3.0 5.2 2.0 virginica 35 4.9 3.1 1.5 0.2 setosa 114 5.7 2.5 5.0 2.0 virginica 6 5.4 3.9 1.7 0.4 setosa 26 5.0 3.0 1.6 0.2 setosa 129 6.4 2.8 5.6 2.1 virginica 58 4.9 2.4 3.3 1.0 versicolor 92 6.1 3.0 4.6 1.4 versicolor 20 5.1 3.8 1.5 0.3 setosa 138 6.4 3.1 5.5 1.8 virginica 147 6.3 2.5 5.0 1.9 virginica #予測データを作成 > iris_predicting<-iris[-random_ids, ] > iris_predicting Sepal.Length Sepal.Width Petal.Length Petal.Width Species 2 4.9 3.0 1.4 0.2 setosa 9 4.4 2.9 1.4 0.2 setosa 10 4.9 3.1 1.5 0.1 setosa 13 4.8 3.0 1.4 0.1 setosa 14 4.3 3.0 1.1 0.1 setosa 15 5.8 4.0 1.2 0.2 setosa 22 5.1 3.7 1.5 0.4 setosa 23 4.6 3.6 1.0 0.2 setosa 24 5.1 3.3 1.7 0.5 setosa 27 5.0 3.4 1.6 0.4 setosa学習および予測
作成したトレーニングデータをSVMに学習させて、評価データを入れて正解率を見てみます。ksvm関数で学習させたモデルに対して予測データを入れます。予測データの結果と正解をtableで比較します。setosaとversicolorは100%正解しています。virginicaをversicolorを間違えているのが2つ存在しているため、正解率は71/75 = 94%となります。非常に高い正解率と言えます。
#ksvm関数でトレーニングデータを学習 > iris_svm<-ksvm(Species ~., data=iris_training ) Using automatic sigma estimation (sigest) for RBF or laplace kernel > iris_svm Support Vector Machine object of class "ksvm" SV type: C-svc (classification) parameter : cost C = 1 Gaussian Radial Basis kernel function. Hyperparameter : sigma = 0.88455572069582 Number of Support Vectors : 40 Objective Function Value : -3.8445 -4.3933 -11.5324 Training error : 0.026667 #predict関数で予測データを評価 > result_predict<-predict(iris_svm, iris_predicting) > result_predict [1] setosa setosa setosa setosa setosa setosa [7] setosa setosa setosa setosa setosa setosa [13] setosa setosa setosa setosa setosa setosa [19] setosa setosa setosa setosa setosa versicolor [25] versicolor versicolor versicolor versicolor versicolor versicolor [31] versicolor versicolor versicolor versicolor virginica versicolor [37] versicolor virginica versicolor versicolor versicolor versicolor [43] versicolor versicolor versicolor versicolor versicolor versicolor [49] versicolor versicolor versicolor virginica virginica virginica [55] virginica virginica virginica virginica virginica versicolor [61] virginica virginica virginica virginica virginica virginica [67] virginica versicolor virginica virginica virginica virginica [73] virginica virginica virginica Levels: setosa versicolor virginica #予測結果と正解との比較 > table(result_predict,iris_predicting$Species) result_predict setosa versicolor virginica setosa 23 0 0 versicolor 0 26 2 virginica 0 2 22NeuralNetworkでの分類
R言語のnnetパッケージを利用してNeuralNetworkでも分類学習をしてみます。先に結果を書いてしまいますが、これによりSVMとの正解率を測定したかったのですが、75行のデータに対しては全く同じ精度となりました。正解率は71/75 = 94%となります。本格的に精度を検証するには学習データ、予測データともに数を増やさないといけません。
# パッケージインストール > install.packages( "nnet" ) # nnetパッケージを読み込み > library( nnet ) # nnet関数でNeuralNetworkに学習させる > iris_nnet<-nnet(Species ~ ., data = iris_training, size = 2, rang = .1, decay = 5e-4, maxit = 200) # 未分類のデータを予測する > result_predict_nnet<-predict(iris_nnet,iris_predicting,type="class") # 正解と比較 > table(result_predict_nnet,iris_predicting$Species) result_predict_nnet setosa versicolor virginica setosa 23 0 0 versicolor 0 26 2 virginica 0 2 22NaiveBayesでの分類
R言語のe1071パッケージを利用してnaiveBayesによる分類学習も行います。正解との比較ではSVM、NeuralNetworkよりも悪い結果となりました。正解率は69/75 = 92%となっています。
# e1071パッケージをインストール > install.packages( "e1071" ) # e1071の読み込み > library( e1071 ) # naiveBayesによる学習 > iris_nb<-naiveBayes(Species ~ ., data = iris_training) # 未分類のデータを予測する > result_predict_nb<-predict(iris_nb,iris_predicting,type="class") # 正解と比較 > table(result_predict_nb,iris_predicting$Species) result_predict_nb setosa versicolor virginica setosa 23 0 0 versicolor 0 24 2 virginica 0 4 22
SAPM判定分類学習
Iris以外にもkernlabパッケージにはSAPMメールのデータがあります。行数が4601あるのでIrisよりも正確な正解率が出そうです。SPAMデータをSVM、NeuralNetwork、naiveBayesのそれぞれのModelに掛けて予測結果の評価を行います。
SVM
(1349+783)/(1349+111+58+783) = 92.65%が正解率となりました。
> library(kernlab) > data(spam) > rowdata<-nrow(spam) > random_ids<-sample(rowdata,rowdata*0.5) > spam_training<-spam[random_ids,] > spam_predicting<-spam[-random_ids,] > spam_svm<-ksvm(type ~., data=spam_training ) > spam_predict<-predict(spam_svm,spam_predicting[,-58]) > table(spam_predict, spam_predicting[,58]) #結果 spam_predict nonspam spam nonspam 1349 111 spam 58 783NeuralNetwork
(1261+877)/(1261+50+113+877) = 92.91%が正解率となりました。
> library( nnet ) > spam_nn<-nnet(type ~., data=spam_training,size = 2, rang = .1, decay = 5e-4, maxit = 200 ) > spam_predict<-predict(spam_nn,spam_predicting[,-58],type="class") > table(spam_predict, spam_predicting[,58]) spam_predict nonspam spam nonspam 1261 50 spam 113 877NaiveBayes
(758+866)/(758+61+616+866) = 70.57%が正解率となりました。
> library( e1071 ) > spam_nn<-naiveBayes(type ~., data=spam_training) > spam_predict<-predict(spam_nn,spam_predicting[,-58],type="class") > table(spam_predict, spam_predicting[,58]) spam_predict nonspam spam nonspam 758 61 spam 616 866