Writing a BP Neural Network in Java (Part 4)

Continuing from the previous part.

In parts (1) and (2), the program was structured around Net, Propagation, Trainer, Learner, and DataProvider. This part refactors that architecture.

Net

First, Net. After the activation and cost functions were redefined in the previous part, it looks roughly like this:

List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
List<DoubleMatrix> bs = new ArrayList<>();
List<ActivationFunction> activations = new ArrayList<>();
CostFunction costFunc;
CostFunction accuracyFunc;
int[] nodesNum;
int layersNum;

public CompactDoubleMatrix getCompact() {
	return new CompactDoubleMatrix(this.weights, this.bs);
}

The getCompact() method builds the corresponding super-matrix.
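As a quick illustration (a hypothetical snippet, not from the post; net and propResult stand for a network and a propagation result as defined later), the compact view lets the whole parameter set be updated in one expression:

// Illustrative only: update every weight matrix and bias vector at once.
// getCompact() is assumed to wrap references, so the in-place subi(...)
// writes straight back into the Net (this is how MomentAdaptLearner uses it).
CompactDoubleMatrix params = net.getCompact();        // weights + biases as one unit
CompactDoubleMatrix grads  = propResult.getCompact(); // matching gradient structure
params.subi(grads.mul(0.01));                         // plain SGD step with eta = 0.01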

DataProvider

DataProvider supplies the data.

public interface DataProvider {
    DoubleMatrix getInput();
    DoubleMatrix getTarget();
}

If the input is made up of vectors, the provider also carries a vector dictionary:

public interface DictDataProvider extends DataProvider {
	public DoubleMatrix getIndexs();
	public DoubleMatrix getDict();
}

Each column is one sample. getIndexs() returns the indices of the input vectors in the dictionary.
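As a concrete example, a minimal in-memory provider might look like this (my illustration; ArrayDataProvider is not a class from the series):

import org.jblas.DoubleMatrix;

// A minimal in-memory DataProvider (illustrative only). Each column of
// input and target is one sample, matching the column-per-sample
// convention used throughout the series.
public class ArrayDataProvider implements DataProvider {
	private final DoubleMatrix input;
	private final DoubleMatrix target;

	public ArrayDataProvider(DoubleMatrix input, DoubleMatrix target) {
		if (input.columns != target.columns)
			throw new IllegalArgumentException("input/target sample counts differ");
		this.input = input;
		this.target = target;
	}

	@Override
	public DoubleMatrix getInput() { return input; }

	@Override
	public DoubleMatrix getTarget() { return target; }
}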

I wrote a helper class, BatchDataProviderFactory, that splits the samples into minibatches.

int batchSize;
int dataLen;
DataProvider originalProvider;
List<Integer> endPositions;
List<DataProvider> providers;

public BatchDataProviderFactory(int batchSize, DataProvider originalProvider) {
	super();
	this.batchSize = batchSize;
	this.originalProvider = originalProvider;
	this.dataLen = this.originalProvider.getTarget().columns;
	this.initEndPositions();
	this.initProviders();
}

public BatchDataProviderFactory(DataProvider originalProvider) {
	this(4, originalProvider);
}

public List<DataProvider> getProviders() {
	return providers;
}

batchSize specifies how many batches the data is split into, getProviders() returns the generated minibatches, and originalProvider is the original data being split.
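The post does not list initEndPositions() and initProviders(); the following is a plausible reconstruction of the column-wise split, assuming jblas's DoubleMatrix.getColumns(int[]) (my code, not the original):

// My reconstruction of the splitting logic (not the original source).
private void initEndPositions() {
	endPositions = new ArrayList<>();
	int step = (dataLen + batchSize - 1) / batchSize; // ceil: samples per batch
	for (int end = step; end < dataLen; end += step)
		endPositions.add(end);
	endPositions.add(dataLen); // final (possibly smaller) batch
}

private void initProviders() {
	providers = new ArrayList<>();
	int start = 0;
	for (int end : endPositions) {
		int[] cols = new int[end - start];
		for (int i = 0; i < cols.length; i++)
			cols[i] = start + i;
		// getColumns(int[]) picks out this minibatch's sample columns.
		final DoubleMatrix in = originalProvider.getInput().getColumns(cols);
		final DoubleMatrix tg = originalProvider.getTarget().getColumns(cols);
		providers.add(new DataProvider() {
			@Override
			public DoubleMatrix getInput() { return in; }

			@Override
			public DoubleMatrix getTarget() { return tg; }
		});
		start = end;
	}
}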

Propagation

Propagation handles the network's forward and backward passes. The interface is defined as follows:

public interface Propagation {
	public PropagationResult propagate(Net net,DataProvider provider);
}

The propagate method runs the given data through the given network and returns the result.

BasePropagation implements this interface with plain backpropagation:

public class BasePropagation implements Propagation{

	// Forward pass over multiple samples (one column per sample).
	protected ForwardResult forward(Net net,DoubleMatrix input) {
		
		ForwardResult result = new ForwardResult();
		result.input = input;
		DoubleMatrix currentResult = input;
		int index = -1;
		for (DoubleMatrix weight : net.weights) {
			index++;
			DoubleMatrix b = net.bs.get(index);
			final ActivationFunction activation = net.activations
					.get(index);
			currentResult = weight.mmul(currentResult).addColumnVector(b);
			result.netResult.add(currentResult);

			// Save the activation derivative at the pre-activations for backprop.
			DoubleMatrix derivative = activation.derivativeAt(currentResult);
			result.derivativeResult.add(derivative);
			
			currentResult = activation.valueAt(currentResult);
			result.finalResult.add(currentResult);

		}

		result.netResult = null; // No longer needed.
		
		return result;
	}

	

	// Backward pass: gradients averaged over all samples.
	protected BackwardResult backward(Net net,DoubleMatrix target,
			ForwardResult forwardResult) {
		BackwardResult result = new BackwardResult();
		
		DoubleMatrix output = forwardResult.getOutput();
		DoubleMatrix outputDerivative = forwardResult.getOutputDerivative();
		
		result.cost = net.costFunc.valueAt(output, target);
		DoubleMatrix outputDelta = net.costFunc.derivativeAt(output, target).muli(outputDerivative);
		if (net.accuracyFunc != null) {
			result.accuracy=net.accuracyFunc.valueAt(output, target);
		}

		result.deltas.add(outputDelta);
		for (int i = net.layersNum - 1; i >= 0; i--) {
			DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1);

			// Weight gradient, averaged over all samples.
			DoubleMatrix layerInput = i == 0 ? forwardResult.input
					: forwardResult.finalResult.get(i - 1);
			DoubleMatrix gradient = pdelta.mmul(layerInput.transpose()).div(
					target.columns);
			result.gradients.add(gradient);
			// Bias gradient (row means average over samples).
			result.biasGradients.add(pdelta.rowMeans());

			// Delta for the previous layer; when i == 0 it is the input-layer
			// delta, i.e. the gradient w.r.t. the input, left unaveraged.
			DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
			if (i > 0)
				delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
			result.deltas.add(delta);
		}
		Collections.reverse(result.gradients);
		Collections.reverse(result.biasGradients);
		
		// Only the input-layer deltas are needed; drop the rest.
		DoubleMatrix inputDeltas=result.deltas.get(result.deltas.size()-1);
		result.deltas.clear();
		result.deltas.add(inputDeltas);
		
		return result;
	}

	@Override
	public PropagationResult propagate(Net net, DataProvider provider) {
		ForwardResult forwardResult=this.forward(net, provider.getInput());
		BackwardResult backwardResult=this.backward(net, provider.getTarget(), forwardResult);
		PropagationResult result=new PropagationResult(backwardResult);
		result.output=forwardResult.getOutput();
		return result;
	}
}
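For reference, the backward pass above implements the standard backprop recurrences, averaged over the m sample columns (f is a layer's activation function, C the cost):

\delta^{(L)} = \frac{\partial C}{\partial a^{(L)}} \odot f'(z^{(L)}), \qquad
\delta^{(l)} = (W^{(l+1)})^{\top}\, \delta^{(l+1)} \odot f'(z^{(l)})

\frac{\partial C}{\partial W^{(l)}} = \frac{1}{m}\, \delta^{(l)}\, (a^{(l-1)})^{\top}, \qquad
\frac{\partial C}{\partial b^{(l)}} = \operatorname{rowMeans}(\delta^{(l)})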

Our PropagationResult is, roughly:

public class PropagationResult {
	DoubleMatrix output; // output matrix: outputLen * sampleLength
	DoubleMatrix cost; // cost matrix: 1 * sampleLength
	DoubleMatrix accuracy; // accuracy matrix: 1 * sampleLength
	private List<DoubleMatrix> gradients; // weight gradient matrices
	private List<DoubleMatrix> biasGradients; // bias gradient matrices
	DoubleMatrix inputDeltas; // input-layer delta matrix: inputLen * sampleLength

	public CompactDoubleMatrix getCompact() {
		return new CompactDoubleMatrix(gradients, biasGradients);
	}
}

Another implementation of the interface is MiniBatchPropagation. Internally it propagates the samples in parallel, then merges the per-minibatch results; it builds on BatchDataProviderFactory and BasePropagation.
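The MiniBatchPropagation source is not listed in the post; its core idea might be sketched as follows (my reconstruction; how the averaged gradients are packed back into a PropagationResult is omitted):

import java.util.List;
import java.util.stream.Collectors;

// Sketch of MiniBatchPropagation's core (my reconstruction, not the original).
// Each minibatch is propagated in parallel with BasePropagation, then the
// per-batch gradients are averaged through their CompactDoubleMatrix views.
public class MiniBatchPropagationSketch {
	private final Propagation base = new BasePropagation();
	private final int batches = 4;

	public CompactDoubleMatrix averagedGradients(Net net, DataProvider provider) {
		List<DataProvider> minibatches =
				new BatchDataProviderFactory(batches, provider).getProviders();

		List<CompactDoubleMatrix> grads = minibatches.parallelStream()
				.map(p -> base.propagate(net, p).getCompact())
				.collect(Collectors.toList());

		// Equal-weight average; with equally sized batches this matches the
		// full-batch gradient (the last batch may be slightly smaller).
		CompactDoubleMatrix sum = grads.get(0).dup();
		for (int i = 1; i < grads.size(); i++)
			sum.addi(grads.get(i));
		return sum.mul(1.0 / grads.size());
	}
}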

Trainer

The Trainer interface is defined as:

public interface Trainer {
    public void train(Net net,DataProvider provider);
}

A simple implementation:

public class CommonTrainer implements Trainer {
	int ecophs;
	Learner learner;
	Propagation propagation;
	List<Double> costs = new ArrayList<>();
	List<Double> accuracys = new ArrayList<>();
	public void trainOne(Net net, DataProvider provider) {
		PropagationResult propResult = this.propagation
				.propagate(net, provider);
		learner.learn(net, propResult, provider);

		Double cost = propResult.getMeanCost();
		Double accuracy = propResult.getMeanAccuracy();
		if (cost != null)
			costs.add(cost);
		if (accuracy != null)
			accuracys.add(accuracy);
	}

	@Override
	public void train(Net net, DataProvider provider) {
		for (int i = 0; i < this.ecophs; i++) {
			System.out.println("echops:"+i);
			this.trainOne(net, provider);
		}

	}
}

It simply iterates ecophs times, with no early-stopping logic; each iteration uses the Learner to adjust the weights.

Learner

Learner adjusts the network weights after each propagation. The interface is defined as follows:

public interface Learner<N extends Net,P extends DataProvider> {
    public void learn(N net,PropagationResult propResult,P provider);
}

A simple implementation that adapts via a momentum factor and an adaptive learning rate:

public class MomentAdaptLearner<N extends Net, P extends DataProvider>
		implements Learner<N, P> {
	double moment = 0.7;
	double lmd = 1.05;
	double preCost = 0;
	double eta = 0.01;
	double currentEta = eta;
	double currentMoment = moment;
	CompactDoubleMatrix preGradient;

	public MomentAdaptLearner(double moment, double eta) {
		super();
		this.moment = moment;
		this.eta = eta;
		this.currentEta = eta;
		this.currentMoment = moment;
	}

	public MomentAdaptLearner() {

	}

	@Override
	public void learn(N net, PropagationResult propResult, P provider) {
		if (this.preGradient == null)
			init(net, propResult, provider);

		double cost = propResult.getMeanCost();
		this.modifyParameter(cost);
		System.out.println("current eta:" + this.currentEta);
		System.out.println("current moment:" + this.currentMoment);
		this.updateGradient(net, propResult, provider);

	}

	public void updateGradient(N net, PropagationResult propResult, P provider) {
		CompactDoubleMatrix netCompact = this.getNetCompact(net, propResult,
				provider);
		CompactDoubleMatrix gradCompact = this.getGradientCompact(net,
				propResult, provider);
		gradCompact = gradCompact.mul(currentEta * (1 - currentMoment)).addi(
				preGradient.mul(currentMoment));
		netCompact.subi(gradCompact);
		this.preGradient = gradCompact;
	}

	public CompactDoubleMatrix getNetCompact(N net,
			PropagationResult propResult, P provider) {
		return net.getCompact();
	}

	public CompactDoubleMatrix getGradientCompact(N net,
			PropagationResult propResult, P provider) {
		return propResult.getCompact();
	}

	public void modifyParameter(double cost) {
		// Note: the clamping branches take precedence, so while eta sits
		// outside [0.0001, 10] the adaptation rules below are skipped.
		if (this.currentEta > 10) {
			this.currentEta = 10;
		} else if (this.currentEta < 0.0001) {
			this.currentEta = 0.0001;
		} else if (cost < this.preCost) {
			// Cost decreased: grow the learning rate, restore full momentum.
			this.currentEta *= 1.05;
			this.currentMoment = moment;
		} else if (cost < 1.04 * this.preCost) {
			// Cost rose slightly (under 4%): shrink both eta and momentum.
			this.currentEta *= 0.7;
			this.currentMoment *= 0.7;
		} else {
			// Cost jumped: reset the learning rate and nearly drop the momentum.
			this.currentEta = eta;
			this.currentMoment = 0.1;
		}
		this.preCost = cost;
	}

	public void init(Net net, PropagationResult propResult, P provider) {
		PropagationResult pResult = new PropagationResult(net);
		preGradient = pResult.getCompact().dup();
	}

}

In the code above, you can see how CompactDoubleMatrix encapsulates the weight variables and keeps the code concise: it presents itself as a single super-matrix (or super-vector), completely hiding the internal structure.
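CompactDoubleMatrix itself was introduced in an earlier part and is not repeated here. As a reminder of the contract the learners rely on, here is a minimal re-implementation consistent with the operations used above; the real class may differ:

import java.util.ArrayList;
import java.util.List;
import org.jblas.DoubleMatrix;

// Minimal re-implementation, consistent with how the learners use the class
// (constructor over weight/bias lists, mul, addi, subi, dup, append).
// It holds *references* to the underlying matrices, which is why the in-place
// netCompact.subi(...) in updateGradient writes straight back into the Net.
public class CompactDoubleMatrix {
	private final List<DoubleMatrix> parts = new ArrayList<>();

	public CompactDoubleMatrix(List<DoubleMatrix> weights, List<DoubleMatrix> bs) {
		parts.addAll(weights);
		parts.addAll(bs);
	}

	private CompactDoubleMatrix() {
	}

	public CompactDoubleMatrix mul(double v) { // returns a new super-matrix
		CompactDoubleMatrix r = new CompactDoubleMatrix();
		for (DoubleMatrix m : parts)
			r.parts.add(m.mul(v));
		return r;
	}

	public CompactDoubleMatrix addi(CompactDoubleMatrix o) { // in place
		for (int i = 0; i < parts.size(); i++)
			parts.get(i).addi(o.parts.get(i));
		return this;
	}

	public CompactDoubleMatrix subi(CompactDoubleMatrix o) { // in place
		for (int i = 0; i < parts.size(); i++)
			parts.get(i).subi(o.parts.get(i));
		return this;
	}

	public CompactDoubleMatrix dup() { // deep copy, used to seed preGradient
		CompactDoubleMatrix r = new CompactDoubleMatrix();
		for (DoubleMatrix m : parts)
			r.parts.add(m.dup());
		return r;
	}

	public void append(DoubleMatrix m) { // used by DictMomentLearner
		parts.add(m);
	}
}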

Its subclass adds synchronized dictionary updates, and the code stays just as concise: the matrices that need adjusting are simply appended to the super-matrix, and the parent class adjusts them all in one go:

public class DictMomentLearner extends
		MomentAdaptLearner<Net, DictDataProvider> {

	public DictMomentLearner(double moment, double eta) {
		super(moment, eta);
	}

	public DictMomentLearner() {
		super();
	}

	@Override
	public CompactDoubleMatrix getNetCompact(Net net,
			PropagationResult propResult, DictDataProvider provider) {
		CompactDoubleMatrix result = super.getNetCompact(net, propResult,
				provider);
		result.append(provider.getDict());
		return result;
	}

	@Override
	public CompactDoubleMatrix getGradientCompact(Net net,
			PropagationResult propResult, DictDataProvider provider) {
		CompactDoubleMatrix result = super.getGradientCompact(net, propResult,
				provider);
		result.append(DictUtil.getDictGradient(provider, propResult));
		return result;
	}

	@Override
	public void init(Net net, PropagationResult propResult,
			DictDataProvider provider) {
		DoubleMatrix preDictGradient = DoubleMatrix.zeros(
				provider.getDict().rows, provider.getDict().columns);
		super.init(net, propResult, provider);
		this.preGradient.append(preDictGradient);
	}
}
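To close, here is how the pieces might be wired together (a hypothetical usage sketch: buildNet() stands in for the Net setup from earlier parts, and direct field access is assumed since CommonTrainer's excerpt shows no constructor or setters):

// Hypothetical end-to-end wiring (my sketch, not from the post).
DoubleMatrix inputs  = new DoubleMatrix(new double[][] {{0, 0, 1, 1},
                                                        {0, 1, 0, 1}});
DoubleMatrix targets = new DoubleMatrix(new double[][] {{0, 1, 1, 0}});

Net net = buildNet();                                       // hypothetical helper
DataProvider data = new ArrayDataProvider(inputs, targets); // sketch class from above

CommonTrainer trainer = new CommonTrainer();
trainer.ecophs = 500;                                       // epoch count
trainer.learner = new MomentAdaptLearner<Net, DataProvider>(0.7, 0.01);
trainer.propagation = new BasePropagation();                // or MiniBatchPropagation
trainer.train(net, data);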