孔子曰,吾日三省吾身。咱們若是跟程序打交道,除了一日三省吾身外,還要三日一省吾代碼。看代碼是否能夠更簡潔,更易懂,更容易擴展,更通用,算法是否能夠再優化,結構是否能夠再往上抽象。代碼在不斷的重構過程當中,更臻化境。佝僂者承蜩如是,大匠鑄劍亦復如是,藝雖小,其道一也。所謂苟日新,再日新,日日新。算法
本次對前兩篇文章代碼進行重構,主要重構函數接口體系,和權重矩陣的封裝。網絡
所說函數,是數學概念上的函數。數學上的函數,通常有一自變量$x$(輸入)和對應的值$y=f(x)$(輸出)。其中$x$能夠是個數字,一個向量,一個矩陣等等。咱們用泛型定義以下:app
public interface Function<I,O> { O valueAt(I x); }
I表明輸入類型,O表明輸出類型。函數
有的函數是可微的,好比神經網絡的激活函數。可微函數除了是一個函數,還可求出給定$x$處的導數,或者梯度。並且梯度類型與自變量類型一致。用泛型定義以下:優化
public interface DifferentiableFunction<I,O> extends Function<I,O> { I derivativeAt(I x); }
同時,考慮到某些函數,在求得值和導數時,共同用到了一些中間變量,或者後一個能夠用到前一個的結果,咱們定義了PreCaculate接口。當咱們斷定一個函數實現了PreCaculate接口時,咱們首先調用它的PreCaculate接口,讓它預先計算出一些有用的中間變量,而後再調用其valueAt和derivativeAt求得其具體的值,這樣能夠節省一些操做步驟。定義以下:this
public interface PreCaculate<I> { void preCaculate(I x); }
基於上面的定義,咱們定義神經網絡的激活函數的類型爲:orm
public interface ActivationFunction extends DifferentiableFunction<DoubleMatrix, DoubleMatrix>
即咱們激活函數是一個可微函數,輸入爲一個矩陣(netResult),輸出爲一個矩陣(finalResult)。blog
有些函數,除了自變量外,還有一些其它的係數,或者參數,咱們稱爲超參數。好比偏差函數,目標值爲參數,輸出值爲自變量。這類函數接口定義以下:接口
public interface ParamFunction<I,O,P> { O valueAt(I x,P param); }
相似的,定義其微分接口以下:get
public interface DifferentiableParamFunction<I, O, P> extends ParamFunction<I, O, P> { I derivativeAt(I x,P param); }
咱們的偏差函數定義以下:
public interface CostFunction extends DifferentiableParamFunction<DoubleMatrix,DoubleMatrix,DoubleMatrix>
輸入,輸出,參數都爲矩陣。
在神經網絡的概念中,每兩層之間有一個權重矩陣,偏置矩陣,若是輸入字向量也要調整,那麼還有一個字典矩陣。這些全部的矩陣隨着迭代過程不斷更新,以期使偏差函數達到最小。從廣義上來說,訓練樣本就是超參數,這些全部的矩陣爲自變量,偏差函數就是優化函數。那麼實質上,在調整權重矩陣時,自變量即這一系列的矩陣能夠展開拉長拼接成一個超長的向量而已,其內部的結構已可有可無。在jare的源碼中,是把這些權重矩陣的值存儲在一個長的double[]中,計算完畢後,再從這個doulbe[]中還原出各矩陣的結構。在這裏,咱們定義了一個類CompactDoubleMatrix名爲超矩陣來從更高一層封裝這些矩陣變量,使其對外表現出好像就是一個矩陣。
這個CompactDoubleMatrix的實現方式爲,在內部維護一個DoubleMatrix的有序列表List<DoubleMatrix>,而後再執行加減乘除操做時,會批量的對列表中的全部矩陣執行。這樣的封裝,咱們隨後會發現將簡化了咱們大量代碼。先把完整定義放上來。
public class CompactDoubleMatrix { List<DoubleMatrix> mats = new ArrayList<DoubleMatrix>(); @SafeVarargs public CompactDoubleMatrix(List<DoubleMatrix>... matListArray) { super(); this.append(matListArray); } public CompactDoubleMatrix(DoubleMatrix... matArray) { super(); this.append(matArray); } public CompactDoubleMatrix() { super(); } public CompactDoubleMatrix addi(CompactDoubleMatrix other) { this.assertSize(other); for (int i = 0; i < this.length(); i++) this.get(i).addi(other.get(i)); return this; } public void subi(CompactDoubleMatrix other) { this.assertSize(other); for (int i = 0; i < this.length(); i++) this.get(i).subi(other.get(i)); } public CompactDoubleMatrix add(CompactDoubleMatrix other) { this.assertSize(other); CompactDoubleMatrix result = new CompactDoubleMatrix(); for (int i = 0; i < this.length(); i++) { result.append(this.get(i).add(other.get(i))); } return result; } public CompactDoubleMatrix sub(CompactDoubleMatrix other) { this.assertSize(other); CompactDoubleMatrix result = new CompactDoubleMatrix(); for (int i = 0; i < this.length(); i++) { result.append(this.get(i).sub(other.get(i))); } return result; } public CompactDoubleMatrix mul(CompactDoubleMatrix other) { this.assertSize(other); CompactDoubleMatrix result = new CompactDoubleMatrix(); for (int i = 0; i < this.length(); i++) { result.append(this.get(i).mul(other.get(i))); } return result; } public CompactDoubleMatrix muli(double d) { for (int i = 0; i < this.length(); i++) { this.get(i).muli(d); } return this; } public CompactDoubleMatrix mul(double d) { CompactDoubleMatrix result = new CompactDoubleMatrix(); for (int i = 0; i < this.length(); i++) { result.append(this.get(i).mul(d)); } return result; } public CompactDoubleMatrix dup() { CompactDoubleMatrix result = new CompactDoubleMatrix(); for (int i = 0; i < this.length(); i++) { result.append(this.get(i).dup()); } return result; } public double dot(CompactDoubleMatrix other) { double sum = 0; for (int i = 0; i < this.length(); i++) { sum += this.get(i).dot(other.get(i)); } return sum; } public double norm() { double sum = 0; for (int i = 0; i < this.length(); i++) { double subNorm = this.get(i).norm2(); sum += subNorm * subNorm; } return Math.sqrt(sum); } public void assertSize(CompactDoubleMatrix other) { assert (other != null && this.length() == other.length()); for (int i = 0; i < this.length(); i++) { assert (this.get(i).sameSize(other.get(i))); } } @SuppressWarnings("unchecked") public void append(List<DoubleMatrix>... matListArray) { for (List<DoubleMatrix> list : matListArray) { this.mats.addAll(list); } } public void append(DoubleMatrix... matArray) { for (DoubleMatrix mat : matArray) this.mats.add(mat); } public int length() { return mats.size(); } public DoubleMatrix get(int index) { return this.mats.get(index); } public DoubleMatrix getLast() { return this.mats.get(this.length() - 1); } }
以上介紹了對各抽象概念的封裝,下章介紹使用這些封裝如何簡化咱們的代碼。