用java寫bp神經網絡(三)

孔子曰,吾日三省吾身。咱們若是跟程序打交道,除了一日三省吾身外,還要三日一省吾代碼。看代碼是否能夠更簡潔,更易懂,更容易擴展,更通用,算法是否能夠再優化,結構是否能夠再往上抽象。代碼在不斷的重構過程當中,更臻化境。佝僂者承蜩如是,大匠鑄劍亦復如是,藝雖小,其道一也。所謂苟日新,再日新,日日新。算法

本次對前兩篇文章代碼進行重構,主要重構函數接口體系,和權重矩陣的封裝。網絡

簡單函數

所說函數,是數學概念上的函數。數學上的函數,通常有一自變量$x$(輸入)和對應的值$y=f(x)$(輸出)。其中$x$能夠是個數字,一個向量,一個矩陣等等。咱們用泛型定義以下:app

public interface Function<I,O> {
  O valueAt(I x);
}

 I表明輸入類型,O表明輸出類型。函數

有的函數是可微的,好比神經網絡的激活函數。可微函數除了是一個函數,還可求出給定$x$處的導數,或者梯度。並且梯度類型與自變量類型一致。用泛型定義以下:優化

public interface DifferentiableFunction<I,O> extends Function<I,O> {
  I derivativeAt(I x);
}

同時,考慮到某些函數,在求得值和導數時,共同用到了一些中間變量,或者後一個能夠用到前一個的結果,咱們定義了PreCaculate接口。當咱們斷定一個函數實現了PreCaculate接口時,咱們首先調用它的PreCaculate接口,讓它預先計算出一些有用的中間變量,而後再調用其valueAt和derivativeAt求得其具體的值,這樣能夠節省一些操做步驟。定義以下:this

public interface PreCaculate<I> {
	void preCaculate(I x);
}

 基於上面的定義,咱們定義神經網絡的激活函數的類型爲:orm

public interface ActivationFunction extends DifferentiableFunction<DoubleMatrix, DoubleMatrix>

 即咱們激活函數是一個可微函數,輸入爲一個矩陣(netResult),輸出爲一個矩陣(finalResult)。blog

帶參函數

有些函數,除了自變量外,還有一些其它的係數,或者參數,咱們稱爲超參數。好比偏差函數,目標值爲參數,輸出值爲自變量。這類函數接口定義以下:接口

public interface ParamFunction<I,O,P> {
	O valueAt(I x,P param);
}

 相似的,定義其微分接口以下:get

public interface DifferentiableParamFunction<I, O, P> extends ParamFunction<I, O, P> {
	I derivativeAt(I x,P param);
}

 咱們的偏差函數定義以下:

public interface CostFunction extends DifferentiableParamFunction<DoubleMatrix,DoubleMatrix,DoubleMatrix>

 輸入,輸出,參數都爲矩陣。

組合矩陣

在神經網絡的概念中,每兩層之間有一個權重矩陣,偏置矩陣,若是輸入字向量也要調整,那麼還有一個字典矩陣。這些全部的矩陣隨着迭代過程不斷更新,以期使偏差函數達到最小。從廣義上來說,訓練樣本就是超參數,這些全部的矩陣爲自變量,偏差函數就是優化函數。那麼實質上,在調整權重矩陣時,自變量即這一系列的矩陣能夠展開拉長拼接成一個超長的向量而已,其內部的結構已可有可無。在jare的源碼中,是把這些權重矩陣的值存儲在一個長的double[]中,計算完畢後,再從這個doulbe[]中還原出各矩陣的結構。在這裏,咱們定義了一個類CompactDoubleMatrix名爲超矩陣來從更高一層封裝這些矩陣變量,使其對外表現出好像就是一個矩陣。

這個CompactDoubleMatrix的實現方式爲,在內部維護一個DoubleMatrix的有序列表List<DoubleMatrix>,而後再執行加減乘除操做時,會批量的對列表中的全部矩陣執行。這樣的封裝,咱們隨後會發現將簡化了咱們大量代碼。先把完整定義放上來。

public class CompactDoubleMatrix {
	List<DoubleMatrix> mats = new ArrayList<DoubleMatrix>();

	@SafeVarargs
	public CompactDoubleMatrix(List<DoubleMatrix>... matListArray) {
		super();
		this.append(matListArray);
	}

	public CompactDoubleMatrix(DoubleMatrix... matArray) {
		super();
		this.append(matArray);
	}

	public CompactDoubleMatrix() {
		super();
	}

	public CompactDoubleMatrix addi(CompactDoubleMatrix other) {
		this.assertSize(other);
		for (int i = 0; i < this.length(); i++)
			this.get(i).addi(other.get(i));
		return this;
	}

	public void subi(CompactDoubleMatrix other) {
		this.assertSize(other);
		for (int i = 0; i < this.length(); i++)
			this.get(i).subi(other.get(i));
	}

	public CompactDoubleMatrix add(CompactDoubleMatrix other) {
		this.assertSize(other);
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).add(other.get(i)));
		}
		return result;
	}

	public CompactDoubleMatrix sub(CompactDoubleMatrix other) {
		this.assertSize(other);
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).sub(other.get(i)));
		}
		return result;
	}

	public CompactDoubleMatrix mul(CompactDoubleMatrix other) {
		this.assertSize(other);
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).mul(other.get(i)));
		}
		return result;
	}

	public CompactDoubleMatrix muli(double d) {

		for (int i = 0; i < this.length(); i++) {
			this.get(i).muli(d);
		}
		return this;
	}

	public CompactDoubleMatrix mul(double d) {
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).mul(d));
		}
		return result;
	}

	public CompactDoubleMatrix dup() {
		CompactDoubleMatrix result = new CompactDoubleMatrix();
		for (int i = 0; i < this.length(); i++) {
			result.append(this.get(i).dup());
		}
		return result;
	}

	public double dot(CompactDoubleMatrix other) {
		double sum = 0;
		for (int i = 0; i < this.length(); i++) {
			sum += this.get(i).dot(other.get(i));
		}
		return sum;
	}

	public double norm() {
		double sum = 0;
		for (int i = 0; i < this.length(); i++) {
			double subNorm = this.get(i).norm2();
			sum += subNorm * subNorm;
		}
		return Math.sqrt(sum);
	}

	public void assertSize(CompactDoubleMatrix other) {
		assert (other != null && this.length() == other.length());
		for (int i = 0; i < this.length(); i++) {
			assert (this.get(i).sameSize(other.get(i)));
		}
	}

	@SuppressWarnings("unchecked")
	public void append(List<DoubleMatrix>... matListArray) {
		for (List<DoubleMatrix> list : matListArray) {
			this.mats.addAll(list);
		}
	}

	public void append(DoubleMatrix... matArray) {
		for (DoubleMatrix mat : matArray)
			this.mats.add(mat);
	}

	public int length() {
		return mats.size();
	}

	public DoubleMatrix get(int index) {
		return this.mats.get(index);
	}

	public DoubleMatrix getLast() {
		return this.mats.get(this.length() - 1);
	}
}

 以上介紹了對各抽象概念的封裝,下章介紹使用這些封裝如何簡化咱們的代碼。

相關文章
相關標籤/搜索