LibSVMのcross validationオプションでprecision/recallを出力する

SVMの定番ツールのひとつであるlibsvmにはcross validationオプション(-v) があり,ユーザが指定したFoldのcross validationを実行してくれる.

実行例

% ./svm-train -v 2 heart_scale
*
optimization finished, #iter = 96
nu = 0.431885
obj = -45.653900, rho = 0.152916
nSV = 70, nBSV = 49
Total nSV = 70
*
optimization finished, #iter = 84
nu = 0.512665
obj = -57.742885, rho = 0.134158
nSV = 78, nBSV = 61
Total nSV = 78
Cross Validation Accuracy = 81.8519%


ただ,accuracy (正答率) しか出力してくれないため,各クラスのprecision (適合率), recall (再現率), F1値などを確認したい場合には使えない.せっかくlibsvm側でcross validation部分を自分で書くのは少し面倒なので,libsvm側でprecision, recallを計算して出力するよう変更してみた.

svm-train.cの中を見てみると,do_cross_validation()という関数があり,そこで精度を計算している模様なのでここをいじってみることにした.

precision, recall, F1値の計算方法

precision, recallは「各クラス」に対して計算されることに注意する.以下,2値分類の例で説明をする.precision, recallを計算するためには,TP, FP, TN, FNの4つを計算すればよい.それぞれ下記のとおり.

  • True Positive (TP) = "正しく" positive と分類
  • False Positive (FP) = "誤って" positive と分類 (= 本当はnegativeクラス)
  • True Negative (TN) = "正しく" negativeと分類
  • False Negative (FN) = "誤って" negativeと分類 (= 本当はpositiveクラス)

この "正しく","誤って" という覚え方をするようになってから一度も間違えることがなくなったので,これはお薦めの覚え方.

libsvmが出力してくれる accuracy とは,分類した数のうち,どれだけ正解したか,という数なので,

  • Accuracy = (TP + TN) / (TP + FP + TN + FN)

で計算できる.


さてprecisionは,分類したうち,どれだけ正しく分類できたかという指標で,

  • positiveクラスのprecision = TP / (TP + FP)
  • negativeクラスのprecision = TN / (TN + FN)

で計算することができる.


recallは,全てのサンプルのうち,どれだけ正しく分類できたかという指標で,

  • positiveクラスのrecall = TP / (TP + FN)
  • negativeクラスのrecall = TN / (TN + FP)

で計算することができる.positiveクラスを例に取ると,分母がTP+FNとなっており,真のpositiveクラスの事例数を表していることがわかる.


ついでによく使われる指標としてF1値というものがある.F1値の前身について以前書いた記事 (F値の前身はE値?) をご参照のこと.

これは,precisionとrecallの調和平均で計算され,二つの値が同じ値を取る時に算術平均と一致し,それ以外の場合は算術平均,幾何平均よりも小さな値になる.すなわちprecision, recall両者がバランスよく高い値を示しているかを確認するために用いることができる.ここでprecisionをP,recallをRと表現すると,

\frac{1}{\frac{1}{2} \frac{1}{P} \frac{1}{R}} = \frac{2 P R}{P+R}

で計算できる.計算する際には右辺を利用するが,左辺の方が調和平均 (逆数の平均の逆数) を覚えやすいので僕はこちらで覚えている.

なお,調和平均が算術平均,幾何平均よりも小さくなるのはJensenの不等式で証明できる.これも過去の記事 (相加平均≧相乗平均≧調和平均の証明 with Jensenの不等式) をご参照のこと.

svm-train.c のパッチ

前置きが長くなってしまったけれど,libsvmの該当部分を変更したパッチを用意してみた.
(2013-02-03 ume さんのご指摘で修正しました.ありがとうございます)

--- svm-train.c	2011-09-15 05:22:31.000000000 +0900
+++ svm-train-new.c	2011-09-15 05:25:39.000000000 +0900
@@ -122,11 +122,17 @@
 void do_cross_validation()
 {
 	int i;
-	int total_correct = 0;
+	// int total_correct = 0;
 	double total_error = 0;
 	double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
 	double *target = Malloc(double,prob.l);
 
+	//** To caluculate precision/recall for each class **/
+	int tp = 0;
+	int fp = 0;
+	int tn = 0;
+	int fn = 0;
+
 	svm_cross_validation(&prob,&param,nr_fold,target);
 	if(param.svm_type == EPSILON_SVR ||
 	   param.svm_type == NU_SVR)
@@ -150,11 +156,46 @@
 	}
 	else
 	{
-		for(i=0;i<prob.l;i++)
-			if(target[i] == prob.y[i])
-				++total_correct;
-		printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
+
+	  for(i=0;i<prob.l; i++) {
+
+	    if(prob.y[i] == 1) { // True label = +1
+	      if(target[i] == prob.y[i]) {
+		tp++;
+	      } else {
+		fn++;
+	      }
+	    } else { // True label = -1
+	      if (target[i] == prob.y[i]) {
+		tn++;
+	      } else {
+		fp++;
+	      }
+	    }
+	  }
+
+	  printf("Cross Validation Accuracy = %g%%\n",100.0 * ((double)(tp + tn) / (double)(tp + fp + tn + fn)) );
+
+	  // Precision and recall
+	  double pos_prec   = ((double)tp/(double)(tp + fp));
+	  double pos_rec    = ((double)tp/(double)(tp + fn));
+	  double pos_f1     = (2 * pos_prec * pos_rec) / (pos_prec + pos_rec);
+
+	  double neg_prec   = ((double)tn/(double)(tn + fn));
+	  double neg_rec    = ((double)tn/(double)(tn + fp));
+	  double neg_f1     = (2 * neg_prec * neg_rec) / (neg_prec + neg_rec);
+	    
+	  printf("Positive (+1) class:\n");
+	  printf("  precision = %g\n", pos_prec );
+	  printf("     recall = %g\n", pos_rec );
+	  printf("   F1 value = %g\n\n", pos_f1 );
+
+	  printf("Negative (-1) class:\n");
+	  printf("  precision = %g\n", neg_prec );
+	  printf("     recall = %g\n", neg_rec );
+	  printf("   F1 value = %g\n\n", neg_f1 );
 	}
+
 	free(target);
 }

パッチはここからもダウンロードできる.

パッチの当て方

libsvm-3.1のsvm-train.cと同じディレクトリにsvm-train.patchをコピーし,

% cd libsvm-3.1
% ls
COPYRIGHT  Makefile      README       java    python         svm-scale.c  svm-train.c  svm.def  tools
FAQ.html   Makefile.win  heart_scale  matlab  svm-predict.c  svm-toy      svm.cpp      svm.h    windows
% wget http://sleepyheads.jp/software/svm-train.patch
% patch < svm-train.patch

あとは再びmakeすれば,パッチの当たったソースでコンパイルしてくれる.パッチ後のsvm-trainの動作は以下のようになる.
(2013-02-02 パッチ修正によってprecisionとrecallが反対になっていたのを修正しました)

先ほどと同じデータに対する実行例

% ./svm-train -v 2 heart_scale
*
optimization finished, #iter = 96
nu = 0.431885
obj = -45.653900, rho = 0.152916
nSV = 70, nBSV = 49
Total nSV = 70
*
optimization finished, #iter = 84
nu = 0.512665
obj = -57.742885, rho = 0.134158
nSV = 78, nBSV = 61
Total nSV = 78
Cross Validation Accuracy = 81.8519%
Positive (+1) class:
  precision = 0.81982
     recall = 0.758333
   F1 value = 0.787879

Negative (-1) class:
  precision = 0.81761
     recall = 0.866667
   F1 value = 0.841424

このようにクラス毎にprecision, recall, F1値を出力してくれるようになった.

機械学習の評価ツールは予測値とテストデータを入力として計算することが多いと思うが,こんな風にソースにちょっと手を加えるだけで,目当ての値が利用できる,という好例だったので紹介してみた.

ソースを修正する際に否が応でもソースコードを読むことになるので,既存の実装に手を加えてみるというのは,あまりコードを読む習慣がない僕のような人間にはちょうど良いのかもしれない.(アルゴリズム本体ではないので,今回は最適化部分とか全く読まずに済んでしまったのだけれど..)

パッチのつくり方

生まれて初めてパッチというものを書いてみたのでメモ.実は単なるdiffコマンドだったということを知って驚愕.

% diff -u svm-train.c svm-train-new.c > svm-train.patch
  (diff -c でも可)