Gradient Checkingを実装した
Sponsored Links
皆さんこんにちは
お元気ですか?私は凍えています。
今日はGradient Checkingを紹介します。
Gradient Checking
機械学習の実装を行っていると、
勾配の計算をする箇所が出ることが多いです。(Back Propagationなど)
簡単な計算をするには適当に手で計算したものを実装すればよいですが
だんだん本当に合っているのかわからないとか、実装方針変えた場合でも等価になるのかなど
不安になってくる箇所が出てきます。(多分)
そんなときにGradient Checkingを使いましょう。
数式
をパラメータに基づく関数J(関数J、パラメータの内容は自分で定義、構築してください)
と定義すると以下がGradient Checkingを実装する式となります。
EPSILONは小さい定数です。大体0.0001程度が良いそうです。
プログラムとして上記式を構築するには、右を計算する関数と左を計算する関数を用意し、
それぞれ比較すれば良いことになります。
実装
今回は簡単な数式で実験しようと思います。
式:
微分式:
左辺(gradient_calc)と右辺(function)が
Source Code
# coding:utf-8 import numpy as np def function(x): return x ** 2 def gradient_calc(x): return 2 * x def gradient_miss_calc(x): return 3 * x def gradient_checking(forward, gradient, x): """ :params forward: forward function :params backward: gradient function :params x: data :return: correct(True) or fall(False) """ epsilon = 0.0001 gradient_checker = (forward(x + epsilon) - forward(x - epsilon)) / (2 * epsilon) diff = np.abs(gradient(x) - gradient_checker) if diff < 0.0001: print "correct" return True else: print "incorrect" return False if __name__ == '__main__': x = 3 print "correct pattern" gradient_checking(function, gradient_calc, x) print "incorrect pattern" gradient_checking(function, gradient_miss_calc, x)
標準出力
correct pattern correct incorrect pattern incorrect
最初は正しい微分の式を与えていますが、次の式には異なる微分式を与えています。
ちゃんと微分式が間違っていることを検出していることがわかります。
今回は、左辺と右辺はだいたい等しいとのことなので、一定以下の差であれば
正しいということにしています。