grbl1.1+Arduino CNCシールドV3.5+bCNCを使用中。
BluetoothモジュールおよびbCNCのPendant機能でスマホからもワイヤレス操作可能。
その他、電子工作・プログラミング、機械学習などもやっています。
MacとUbuntuを使用。

CNCマシン全般について:
国内レーザー加工機と中国製レーザー加工機の比較
中国製レーザーダイオードについて
CNCミリングマシンとCNCルーターマシンいろいろ
その他:
利用例や付加機能など:
CNCルーター関係:



*CNCマシンの制作記録は2016/04/10〜の投稿に書いてあります。


2017年6月2日金曜日

線形回帰のプログラミング(その2):勾配降下法

前回、グラフ平面上の複数の点(データ群)に対して直線フィッティングをするために最小二乗法を試してみましたが、今回は勾配降下法によるプログラムです。
最小二乗法は、公式に数値を入れるとすぐに答えがでてくるような感じでしたが、勾配降下法は直線の式に動的に近づいていく求め方なので、コンピューターを利用した求め方に、より適しているような気がします。公式によって一発で答えを導き出すというよりも、経験的に答えを探しにいくという感じでしょうか。

黄直線:勾配降下法、赤直線:最小二乗法

とはいっても、やはり公式があるようで、前回同様求めたい直線の式を

y = m * x + b

とすれば、

Error(m, b) = 1/n * Σ((m * xi + b) - yi)^2

が、求めたい式と暫定的な式とのmとbのトータルの誤差を表しているようです。Σは、i=1からi=nまでの合計値(for文上では、i=0からi=n-1まで)。
式中の (m * xi + b) - yi は、m * xi + bで求めたyの値から、ある点のyiを差し引いているので、その差が誤差(Error)になるというのがわかります。二乗しているのは、前回の最小二乗法のように符号を+にするということですが、そもそも最終的な目的が、誤差の最小化なので二乗してあっても問題ないということらしいです。
1/nのnは点の数なので、Σで合計したエラー値をnで割ることで平均値が得られるという感じでしょうか。
最終的には、このError(誤差)が0に近づくようにmとbを探していけばいいので、そのためにはこのErrorのグラフ上の傾き(接線)を調べると、式中のΣ以降の部分である

((m * xi + b) - yi)^2

を微分して接線(x,y地点の傾き)を調べると

2*((m * xi + b) - yi)

となり、求めたいmとbについての偏微分(変化率)は

Δm = -2/n * Σ xi * ((m * xi + b) - yi)
Δb = -2/n * Σ (m * xi + b) - yi

となって、このΔmとΔbをプログラム上で加算(更新)していくことで、徐々に求めたい式に近づいていくという方法のようです。
暫定的な直線の式にxとyの値を代入して誤差を求めて、その誤差が0になるようにmとbを更新していくというイメージはなんとなく分かるのですが、実際計算上でどのようにそうなっていくのか?このあたりを今回もプログラミングを通して実感できないかという試みです。
数式だけだと分かりにくいのですが、プログラミングでは今回の値に前回の値を再代入して、それを繰り返して徐々にある目標の値に近づいていくという方法はよくあるので、今回もそんなイメージで考えています。

勾配降下法においては、一発で答えを導くわけではないので、学習率という変数が登場してくるようです。少しづつ求めたい値に近づくために、オーバーシュートしないように0.1や0.01を掛けて小刻みにステップさせる係数のようなものだと思います。
さきほどのΔmとΔbの偏微分の方程式を書くと(今回もProcessingを使用)、

float m = 0;
float b = 0;
float learningRate = 0.1;
float db = 0;
float dm = 0;
for(int i = 0; i < x.length; i++){
    dm += x[i] * ((m * x[i] + b) - y[i]);
    db += (m * x[i] + b) - y[i];
}

m -= 2.0/x.length * dm * learningRate;
b -= 2.0/x.length * db * learningRate;

こんな感じになります。ΔmとΔbはdmとdb、学習率はlearningRate=0.1、for文で点の個数分繰り返しΔmとΔbの加算処理をしています。最終的なmとbは、学習率0.1を掛けたΔmとΔbと今回のmとbとの差分を差し引いて、次回また計算し直して徐々に求めたいmとbに近づいていくというプロセスです。偏微分の方程式があるため、それに従って計算してしまえばいいので、今回は結構短いコードでmとbを求めていくことができます。
しかし、最後に求めたmとbの式をよく見ると、適当に決めた学習率(今回の場合0.1)を掛けているので、

m -= dm * learningRate;
b -= db * learningRate;

のように式中の2.0/x.lengthを消してしまっても、あまり大差なさそうです。今回は方程式のまま書いてみましたが、計算をシンプルにするには、learningRateに係数を含ませてもいいと思います。
あるいは、もっと式をシンプルにすれば、

float m = 0;
float b = 0;
float learningRate = 0.1;
for(int i = 0; i < x.length; i++){
    float error = y[i] - (m * x[i] + b);
    m += error * x[i] * learningRate;
    b += error * learningRate;
}

とするだけでもいいのかもしれません。
式中のerrorは、実際のy[i]の値からm*x[i]+bで求めたyの値(もともとy=m*x+bという関係なので)を差し引いた誤差です。
結局は、learingRateで挙動を調整することができるので、ここまで変えてしまっても問題なさそうです。


以下が今回のプログラム(Processing/Processing.js)。前回同様に、以下の空のグラフ上の任意位置をクリックすると点が追加されます。複数の点(2点以上)に対する直線が現れます。 赤の直線は前回の最小二乗法、黄色の直線が今回の勾配降下法によるものです。 点の数や位置によっては、微妙に最小二乗法と勾配降下法の結果がずれることがあります。求め方の違いからくるためなのでしょうか?ブラウザリロードで初期化します。

表示が変な場合(線が点滅するなど)は、前回(一つ前の投稿)のページ上のプログラムと干渉しあっているために起こると思われます。 そうならないようにするためには、この投稿だけのページに切り替えるといいと思います。ここをクリックでこの投稿だけのページへ移動



以下が、全体のコード。
前回と少し違うのは、画面サイズ400x400を一旦0〜1.0に正規化してあります。つまり座標(200,200)の位置をクリックすれば、プログラム内部では(0.5,0.5)という値に相当します。0~1.0で計算しておいて(そうしないと計算の際に数値が溢れてしまったので)、その後画面表示するときは、再度400x400のフォーマットに変換し直しています。

float[] x = new float[0];
float[] y = new float[0];
float m_gd = 0; //勾配降下法のmの変数
float b_gd = 1; //勾配降下法のbの変数
float m_ls = 0; //最小二乗法にmの変数
float b_ls = 0; //最小二乗法にbの変数

void setup(){
  size(400,400);
  rectMode(CORNERS);
}

void draw(){
  background(51); 
  stroke(255);
  fill(255);
  for(int i = 0; i < x.length; i++){//点の描画
    ellipse(x[i]*width,y[i]*height,4,4);
  }
  
  if(x.length > 1){//点が2点以上の場合
    leastSquare();
    gradientDescend();
  }
}

void gradientDescend(){//勾配降下法
  float learningRate = 0.1;
  float db = 0;
  float dm = 0;
  for(int i = 0; i < x.length; i++){  
    dm += x[i] * ((m_gd * x[i] + b_gd) - y[i]);
    db += (m_gd * x[i] + b_gd) - y[i];
  }
  
  m_gd -= 2.0/x.length * dm * learningRate;
  b_gd -= 2.0/x.length * db * learningRate;

  stroke(255,255,0);
  drawLine(m_gd, b_gd);  //勾配降下法による直線描画
}

void leastSquare(){//最小二乗法
  float xsum = 0;
  float ysum = 0;
  for(int i = 0; i < x.length; i++){
    xsum += x[i];
    ysum += y[i];
  }  
  float xmean = xsum / x.length;
  float ymean = ysum / y.length;  
  float xy = 0;
  float xx = 0;
  for(int i = 0; i < x.length; i++){
    xy += (x[i] - xmean) * (y[i] - ymean);
    xx += (x[i] - xmean) * (x[i] - xmean);    
  }  
  m_ls = xy / xx;
  b_ls = ymean - m_ls * xmean;
  
  stroke(255,0,0);
  drawLine(m_ls, b_ls); //最小二乗法による直線描画
}

void drawLine(float M, float B){//mとbによる直線描画
  float x1 = 0;
  float y1 = M * x1 + B;
  float x2 = 1;
  float y2 = M * x2 + B;  
  line(x1 * width, y1 * height, x2 * width, y2 * height);
}

void mousePressed(){//クリックするごとに新たな座標値を配列に追加
  x = append(x, 1.0 * mouseX / width); //0〜1.0に正規化して配列に格納
  y = append(y, 1.0 * mouseY / height);
}


学習率:learningRateは0.1に設定してありますが、0.01に落としてもいいかもしれません。ただその場合、動きはゆっくりになります。
点の数が少なすぎたり、互いに点が近すぎると、最小二乗法による直線(赤)と勾配降下法による直線(黄)とに微妙な違いがでてくるようです。学習率(単位ステップ数)を調整すればいいのかもしれませんが、そのへんについては検討中。

今回の最小二乗法と勾配降下法については、数学的に証明可能なレベルまで理解したり、ここで登場してくる方程式を自ら導き出せるほど完全に理解しているというわけではないのですが、プログラミングにおけるライブラリのように、ライブラリをつくるのは大変だけれども、ある程度の仕組みを理解して使いこなすということはできると思います。画像認識のOpenCVあるいはPID制御やFFT(高速フーリエ変換)などのライブラリもそんな感じです。まったく理解していないとライブラリを使うこともできませんが、そこそこ理解していれば、これまで計算できなかったことが可能になるので、その程度は勉強しておいたほうがいいのかもしれません。もしくは、完全な理解を得るには時間がかかってしまうので(そもそもその分野の専門家でもないので)、深い理解はとりあえずはスキップし、まずは使い方を身につけて、使っているうちに徐々にその仕組がわかってくるという順番のほうがいいのかもしれません(知識欲のために理解しようとしているわけではないので)。

続き:
CourseraのMachine LearningコースとDeep Learningコース

0 件のコメント:

コメントを投稿

人気の投稿