Continuing from the previous post.
In parts (一) and (二), the program's architecture consisted of Net, Propagation, Trainer, Learner, and DataProvider. This post refactors that architecture.
Net
First, Net. After the previous post redefined the activation and cost functions, it now looks roughly like this:
List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
List<DoubleMatrix> bs = new ArrayList<>();
List<ActivationFunction> activations = new ArrayList<>();
CostFunction costFunc;
CostFunction accuracyFunc;
int[] nodesNum;
int layersNum;

public CompactDoubleMatrix getCompact() {
    return new CompactDoubleMatrix(this.weights, this.bs);
}
The getCompact() method produces the corresponding super-matrix.
DataProvider
DataProvider supplies the data.
public interface DataProvider {
    DoubleMatrix getInput();
    DoubleMatrix getTarget();
}
If the inputs are vectors, there is also a vector dictionary:
public interface DictDataProvider extends DataProvider {
    public DoubleMatrix getIndexs();
    public DoubleMatrix getDict();
}
Each column is one sample. getIndexs() returns the indices of the input vectors within the dictionary.
I wrote a handy class, BatchDataProviderFactory, that splits the samples into minibatches.
int batchSize;
int dataLen;
DataProvider originalProvider;
List<Integer> endPositions;
List<DataProvider> providers;

public BatchDataProviderFactory(int batchSize, DataProvider originalProvider) {
    super();
    this.batchSize = batchSize;
    this.originalProvider = originalProvider;
    this.dataLen = this.originalProvider.getTarget().columns;
    this.initEndPositions();
    this.initProviders();
}

public BatchDataProviderFactory(DataProvider originalProvider) {
    this(4, originalProvider);
}

public List<DataProvider> getProviders() {
    return providers;
}
batchSize specifies how many batches the samples are split into, getProviders() returns the generated minibatches, and originalProvider is the original data being split.
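As a quick usage sketch (the anonymous provider and demo class here are illustrative; only BatchDataProviderFactory and DataProvider come from the post):

import java.util.List;
import org.jblas.DoubleMatrix;

public class BatchSplitDemo {
    public static void main(String[] args) {
        // Eight samples with three features each; columns are samples.
        final DoubleMatrix input = DoubleMatrix.rand(3, 8);
        final DoubleMatrix target = DoubleMatrix.rand(1, 8);
        DataProvider all = new DataProvider() {
            public DoubleMatrix getInput() { return input; }
            public DoubleMatrix getTarget() { return target; }
        };
        // Split into four minibatches of two samples each.
        List<DataProvider> batches = new BatchDataProviderFactory(4, all).getProviders();
        System.out.println(batches.size()); // expected: 4
    }
}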
Propagation
Propagation is responsible for the network's forward and backward passes. The interface is defined as follows:
public interface Propagation {
    public PropagationResult propagate(Net net, DataProvider provider);
}
The propagate() method runs the given data through the given network and returns the result.
BasePropagation implements this interface with straightforward backpropagation:
public class BasePropagation implements Propagation {

    // Forward pass over multiple samples.
    protected ForwardResult forward(Net net, DoubleMatrix input) {
        ForwardResult result = new ForwardResult();
        result.input = input;
        DoubleMatrix currentResult = input;
        int index = -1;
        for (DoubleMatrix weight : net.weights) {
            index++;
            DoubleMatrix b = net.bs.get(index);
            final ActivationFunction activation = net.activations.get(index);
            currentResult = weight.mmul(currentResult).addColumnVector(b);
            result.netResult.add(currentResult);

            // Derivative at the pre-activation values.
            DoubleMatrix derivative = activation.derivativeAt(currentResult);
            result.derivativeResult.add(derivative);
            currentResult = activation.valueAt(currentResult);
            result.finalResult.add(currentResult);
        }
        result.netResult = null; // No longer needed.
        return result;
    }

    // Backward pass; gradients are averaged over all samples.
    protected BackwardResult backward(Net net, DoubleMatrix target,
            ForwardResult forwardResult) {
        BackwardResult result = new BackwardResult();
        DoubleMatrix output = forwardResult.getOutput();
        DoubleMatrix outputDerivative = forwardResult.getOutputDerivative();
        result.cost = net.costFunc.valueAt(output, target);
        DoubleMatrix outputDelta = net.costFunc.derivativeAt(output, target)
                .muli(outputDerivative);
        if (net.accuracyFunc != null) {
            result.accuracy = net.accuracyFunc.valueAt(output, target);
        }
        result.deltas.add(outputDelta);
        for (int i = net.layersNum - 1; i >= 0; i--) {
            DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1);

            // Weight gradient, averaged over all samples.
            DoubleMatrix layerInput = i == 0 ? forwardResult.input
                    : forwardResult.finalResult.get(i - 1);
            DoubleMatrix gradient = pdelta.mmul(layerInput.transpose()).div(
                    target.columns);
            result.gradients.add(gradient);
            // Bias gradient.
            result.biasGradients.add(pdelta.rowMeans());

            // Delta for the previous layer; when i == 0 this is the input-layer
            // error, i.e. the gradient for adjusting the input, not averaged.
            DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
            if (i > 0)
                delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
            result.deltas.add(delta);
        }
        Collections.reverse(result.gradients);
        Collections.reverse(result.biasGradients);

        // Only the input-layer deltas are kept; the rest are not needed.
        DoubleMatrix inputDeltas = result.deltas.get(result.deltas.size() - 1);
        result.deltas.clear();
        result.deltas.add(inputDeltas);
        return result;
    }

    @Override
    public PropagationResult propagate(Net net, DataProvider provider) {
        ForwardResult forwardResult = this.forward(net, provider.getInput());
        BackwardResult backwardResult = this.backward(net, provider.getTarget(),
                forwardResult);
        PropagationResult result = new PropagationResult(backwardResult);
        result.output = forwardResult.getOutput();
        return result;
    }
}
The PropagationResult we define is, slightly abridged:
public class PropagationResult {
    DoubleMatrix output;      // Output matrix: outputLen * sampleLength
    DoubleMatrix cost;        // Cost matrix: 1 * sampleLength
    DoubleMatrix accuracy;    // Accuracy matrix: 1 * sampleLength
    private List<DoubleMatrix> gradients;     // Weight gradient matrices
    private List<DoubleMatrix> biasGradients; // Bias gradient matrices
    DoubleMatrix inputDeltas; // Input-layer delta matrix: inputLen * sampleLength

    public CompactDoubleMatrix getCompact() {
        return new CompactDoubleMatrix(gradients, biasGradients);
    }
}
Another class implementing this interface is MiniBatchPropagation. Internally it propagates the samples in parallel and then combines the per-minibatch results, building on the BatchDataProviderFactory and BasePropagation classes.
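MiniBatchPropagation's source isn't listed in the post; the following is only a minimal sketch of the parallel part, with a hypothetical mergeResults() standing in for the combination step (which would concatenate the per-batch outputs and average their gradients, weighted by batch size):

import java.util.List;
import java.util.stream.Collectors;

public class MiniBatchPropagation implements Propagation {
    int batchSize = 4;
    Propagation base = new BasePropagation();

    @Override
    public PropagationResult propagate(Net net, DataProvider provider) {
        // Split the samples and propagate each minibatch in parallel.
        List<DataProvider> batches =
                new BatchDataProviderFactory(batchSize, provider).getProviders();
        List<PropagationResult> results = batches.parallelStream()
                .map(batch -> base.propagate(net, batch))
                .collect(Collectors.toList());
        return mergeResults(results);
    }

    // Hypothetical merge step, not shown in the post.
    private PropagationResult mergeResults(List<PropagationResult> results) {
        throw new UnsupportedOperationException("sketch only");
    }
}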
Trainer
The Trainer interface is defined as:
public interface Trainer {
    public void train(Net net, DataProvider provider);
}
A simple implementation:
public class CommonTrainer implements Trainer {
    int ecophs; // Number of training epochs to run.
    Learner learner;
    Propagation propagation;
    List<Double> costs = new ArrayList<>();
    List<Double> accuracys = new ArrayList<>();

    public void trainOne(Net net, DataProvider provider) {
        PropagationResult propResult = this.propagation.propagate(net, provider);
        learner.learn(net, propResult, provider);

        Double cost = propResult.getMeanCost();
        Double accuracy = propResult.getMeanAccuracy();
        if (cost != null)
            costs.add(cost);
        if (accuracy != null)
            accuracys.add(accuracy);
    }

    @Override
    public void train(Net net, DataProvider provider) {
        for (int i = 0; i < this.ecophs; i++) {
            System.out.println("epoch:" + i);
            this.trainOne(net, provider);
        }
    }
}
It simply iterates ecophs times, with no smart stopping criterion; each iteration uses the Learner to adjust the weights.
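No constructor or setters are shown for CommonTrainer; since its fields have default visibility, hypothetical wiring from a class in the same package could look like this (net and provider are assumed to be built elsewhere):

CommonTrainer trainer = new CommonTrainer();
trainer.ecophs = 100;                        // run 100 epochs
trainer.propagation = new BasePropagation();
trainer.learner = new MomentAdaptLearner(0.7, 0.01); // defined below
trainer.train(net, provider);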
Learner
The Learner adjusts the network weights according to each propagation result. The interface is defined as follows:
public interface Learner<N extends Net, P extends DataProvider> {
    public void learn(N net, PropagationResult propResult, P provider);
}
A simple implementation that adapts via a momentum factor and an adaptive learning rate:
public class MomentAdaptLearner<N extends Net, P extends DataProvider>
        implements Learner<N, P> {
    double moment = 0.7;
    double lmd = 1.05; // Not referenced in the code shown; 1.05 is hard-coded below.
    double preCost = 0;
    double eta = 0.01;
    double currentEta = eta;
    double currentMoment = moment;
    CompactDoubleMatrix preGradient;

    public MomentAdaptLearner(double moment, double eta) {
        super();
        this.moment = moment;
        this.eta = eta;
        this.currentEta = eta;
        this.currentMoment = moment;
    }

    public MomentAdaptLearner() {
    }

    @Override
    public void learn(N net, PropagationResult propResult, P provider) {
        if (this.preGradient == null)
            init(net, propResult, provider);

        double cost = propResult.getMeanCost();
        this.modifyParameter(cost);
        System.out.println("current eta:" + this.currentEta);
        System.out.println("current moment:" + this.currentMoment);
        this.updateGradient(net, propResult, provider);
    }

    public void updateGradient(N net, PropagationResult propResult, P provider) {
        CompactDoubleMatrix netCompact = this.getNetCompact(net, propResult, provider);
        CompactDoubleMatrix gradCompact = this.getGradientCompact(net, propResult, provider);

        // Blend the new gradient with the previous one, then descend.
        gradCompact = gradCompact.mul(currentEta * (1 - currentMoment))
                .addi(preGradient.mul(currentMoment));
        netCompact.subi(gradCompact);
        this.preGradient = gradCompact;
    }

    public CompactDoubleMatrix getNetCompact(N net, PropagationResult propResult,
            P provider) {
        return net.getCompact();
    }

    public CompactDoubleMatrix getGradientCompact(N net, PropagationResult propResult,
            P provider) {
        return propResult.getCompact();
    }

    public void modifyParameter(double cost) {
        // Clamp the learning rate, then adapt it against the previous cost.
        if (this.currentEta > 10) {
            this.currentEta = 10;
        } else if (this.currentEta < 0.0001) {
            this.currentEta = 0.0001;
        } else if (cost < this.preCost) {
            this.currentEta *= 1.05;
            this.currentMoment = moment;
        } else if (cost < 1.04 * this.preCost) {
            this.currentEta *= 0.7;
            this.currentMoment *= 0.7;
        } else {
            this.currentEta = eta;
            this.currentMoment = 0.1;
        }
        this.preCost = cost;
    }

    public void init(Net net, PropagationResult propResult, P provider) {
        PropagationResult pResult = new PropagationResult(net);
        preGradient = pResult.getCompact().dup();
    }
}
In the code above you can see how CompactDoubleMatrix encapsulates the weight variables and keeps the code concise: it presents itself as a single super-matrix (a super-vector), completely hiding the internal structure. The update in updateGradient() therefore reads as plain momentum descent over everything at once: g = currentEta * (1 - currentMoment) * gradient + currentMoment * gPrev, then W = W - g, applied uniformly to all weights and biases (and, below, the dictionary).
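CompactDoubleMatrix itself comes from the earlier posts and isn't reproduced here; what follows is only a minimal sketch of the operations this post relies on (mul, addi, subi, append, dup), assuming each simply forwards to every wrapped matrix. Note that the constructor keeps references to the same DoubleMatrix instances, so in-place operations update the network directly:

import java.util.ArrayList;
import java.util.List;
import org.jblas.DoubleMatrix;

public class CompactDoubleMatrix {
    final List<DoubleMatrix> matrices = new ArrayList<>();

    public CompactDoubleMatrix(List<DoubleMatrix> a, List<DoubleMatrix> b) {
        matrices.addAll(a); // shared references, not copies
        matrices.addAll(b);
    }

    private CompactDoubleMatrix() {
    }

    // Element-wise scale into a new compact matrix.
    public CompactDoubleMatrix mul(double s) {
        CompactDoubleMatrix r = new CompactDoubleMatrix();
        for (DoubleMatrix m : matrices)
            r.matrices.add(m.mul(s));
        return r;
    }

    // In-place element-wise add; shapes must match pairwise.
    public CompactDoubleMatrix addi(CompactDoubleMatrix o) {
        for (int i = 0; i < matrices.size(); i++)
            matrices.get(i).addi(o.matrices.get(i));
        return this;
    }

    // In-place element-wise subtract.
    public CompactDoubleMatrix subi(CompactDoubleMatrix o) {
        for (int i = 0; i < matrices.size(); i++)
            matrices.get(i).subi(o.matrices.get(i));
        return this;
    }

    public void append(DoubleMatrix m) {
        matrices.add(m);
    }

    // Deep copy of every wrapped matrix.
    public CompactDoubleMatrix dup() {
        CompactDoubleMatrix r = new CompactDoubleMatrix();
        for (DoubleMatrix m : matrices)
            r.matrices.add(m.dup());
        return r;
    }
}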
Meanwhile, a subclass adds synchronized dictionary updates, and the code stays just as concise: it simply appends the matrices that need adjusting to the super-matrix, and the parent class adjusts them all uniformly:
public class DictMomentLearner extends
        MomentAdaptLearner<Net, DictDataProvider> {

    public DictMomentLearner(double moment, double eta) {
        super(moment, eta);
    }

    public DictMomentLearner() {
        super();
    }

    @Override
    public CompactDoubleMatrix getNetCompact(Net net,
            PropagationResult propResult, DictDataProvider provider) {
        CompactDoubleMatrix result = super.getNetCompact(net, propResult, provider);
        result.append(provider.getDict());
        return result;
    }

    @Override
    public CompactDoubleMatrix getGradientCompact(Net net,
            PropagationResult propResult, DictDataProvider provider) {
        CompactDoubleMatrix result = super.getGradientCompact(net, propResult, provider);
        result.append(DictUtil.getDictGradient(provider, propResult));
        return result;
    }

    @Override
    public void init(Net net, PropagationResult propResult,
            DictDataProvider provider) {
        DoubleMatrix preDictGradient = DoubleMatrix.zeros(
                provider.getDict().rows, provider.getDict().columns);
        super.init(net, propResult, provider);
        this.preGradient.append(preDictGradient);
    }
}
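DictUtil.getDictGradient() isn't shown either. Purely as an illustration of the idea (not the author's implementation), here is a sketch that assumes one dictionary index per sample and package-level access to PropagationResult.inputDeltas, scattering the input-layer deltas back onto the dictionary columns they came from:

import org.jblas.DoubleMatrix;

public class DictUtil {
    public static DoubleMatrix getDictGradient(DictDataProvider provider,
            PropagationResult propResult) {
        DoubleMatrix dict = provider.getDict();
        DoubleMatrix indexs = provider.getIndexs();
        DoubleMatrix deltas = propResult.inputDeltas; // inputLen * sampleLength
        DoubleMatrix gradient = DoubleMatrix.zeros(dict.rows, dict.columns);
        for (int sample = 0; sample < deltas.columns; sample++) {
            int column = (int) indexs.get(0, sample);
            // Accumulate this sample's delta onto its dictionary column.
            gradient.putColumn(column,
                    gradient.getColumn(column).addi(deltas.getColumn(sample)));
        }
        return gradient.divi(deltas.columns); // average over samples
    }
}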