PA-regressionを試す

はじめに

PAアルゴリズムは回帰問題にも適用することができる、と書いてあったのでとりあえず試してみた。
パラメータの設定などがわかっていないので、要調査。

使用したデータ

結果

  • 学習データでクローズテスト
  • パラメータは適当に設定
    • C = 0.005
    • ε = 1.0
    • 学習回数は、データ全部を10ループ
  • グラフについて
    • 縦軸 : データの一番左の実数値
    • 横軸 : 事例番号(データの上から0,1,2,3,...)
  • 学習データの順番で学習した場合


  • 学習データをランダムな順に学習した場合

  • 意外と学習データの近いところにはいる
    • ぐじゃぐじゃではないが、、、
  • 最大値付近や最小値付近でずれが大きいように見える
    • スケールを調整する必要があるかもしれない
  • 学習の順番がランダムなほど学習データに近い値がでているように見える
    • やっぱり学習する順番はランダムな方がいいかもしれない
  • パラメータの設定の仕方がよくわからない
    • 論文のパラメータCの影響の部分を読む必要がある

用いたコード

  • PAのコードを直したもの
#! /usr/bin/perl
# Usage : perl PA-reg.pl train_file parameter_C parameter_epsilon < test_file
use strict;
use warnings;

#学習ファイル名
my $train_file = shift;
#パラメータ
my $C = shift;
my $epsilon = shift;

#訓練回数(学習データの個数*$loop個の学習を行う)
my $loop = 10;

#重みベクトル
my $w = {};

## 学習データの読み込み
my @x_list;
my @t_list;
open IN, $train_file;
while(<IN>){
    chomp;
    my @list = split(/\s+/, $_);
    push(@t_list, $list[0]);
    my $hash;
    for(my $i=1; $i<@list; $i++){
	my ($a, $b) = split(/:/,$list[$i]);
	$hash->{$a} = $b;
	$w->{$a} = 0;
    }
    push(@x_list, $hash);
}

## 訓練
while($loop--){
    for(my $i = 0; $i < @x_list; $i++){
	train($w, $x_list[$i], $t_list[$i]);
    }
}

## 推定
my $num = 0;
while(<>){
    chomp;
    my @list = split(/\s+/, $_);
    my $hash;
    for(my $i=1; $i<@list; $i++){
	my ($a, $b) = split(/:/,$list[$i]);
	$hash->{$a} = $b;
    }

    my $t = predict($w, $hash);
    print $num,"\t",$t,"\t",$list[0],"\t",($t-$list[0]),"\n";
    $num++;
}


##################################################
#予測
sub predict {
    my ($w, $x) = @_;
    
    my $y = 0;
    foreach my $f (keys %$x){
	if($w->{$f}){
	    $y += ($w->{$f} * $x->{$f});
	}
    }
    return $y;
}

#損失関数
sub loss {
    my ($w, $x, $t) = @_;

    my $y = 0;
    foreach my $f (keys %$x){
	if($w->{$f}){
	    $y += ($w->{$f} * $x->{$f});
	}
    }
    return 0.0 if(abs($y-$t) <= $epsilon);
    return abs($y-$t)-$epsilon;
}

#学習
sub train {
    my ($w, $x, $t) = @_;
    
    my $y = predict($w, $x);
    my $l = loss($w, $x, $t);
    my $sq_x = 0.0;
    foreach my $f (keys %$x){
	$sq_x += ($x->{$f} * $x->{$f});
    }

    ## PA
    #my $tau = $l / $sq_x;

    ## PA-I
    #my $tau = $l / $sq_x;
    #$tau = $C if($tau > $C);

    ## PA-II
    my $tau = $l / ($sq_x + 1.0 / (2 * $C));

    my $sign = 1;
    $sign = -1 if($t-$y < 0);

    # 更新
    foreach my $f(keys %$x){
	$w->{$f} += $sign * $tau * $x->{$f};
    }
}