實現步驟
- 1.隨機初始化若干個智能體(神經網絡),讓每個智能體識別訓練數據,並按識別率對它們進行排序
- 2.在排序結果中隨機選擇一個作爲母本,再從識別率不低於母本的智能體中隨機選擇一個作爲父本
- 3.對神經網絡參數(權重)的每一位,隨機選取母本或父本在同一位置上的值,組成新的智能體
- 4.按照母本的排名對新智能體的參數進行突變,排名越靠後突變率越大(在1%~10%之間)
- 5.讓新的智能體識別訓練集並放入排行榜,同時移除排行榜的最後一位
- 6.重複2~5過程,讓識別率越來越高
這個過程類似於自然界的優勝劣汰:把神經網絡參數看作DNA,把參數的調整看作DNA的突變。當然,還可以把擁有不同隱藏層的神經網絡看作不同的物種,讓競爭過程更加多樣化;不過我們這裏只討論一種神經網絡的情況。
優點:可以解決很多沒有頭緒的問題。劣勢:訓練效率極低。
gitee地址:
https://gitee.com/ichiva/gnn.git
實現步驟 1.進化接口
/**
 * Genetic operators that act on a neural network's flattened parameter vector ("DNA").
 */
public interface Evolution {

    /**
     * Crossover: combines a mother DNA and a father DNA into a child DNA.
     *
     * @param mDna the mother's DNA
     * @param fDna the father's DNA
     * @return the child DNA
     */
    INDArray inheritance(INDArray mDna, INDArray fDna);

    /**
     * Mutation: perturbs some of the genes of {@code dna}.
     *
     * @param dna the DNA to mutate
     * @param v   the gene mutation rate
     * @param r   the mutation range
     * @return the mutated DNA
     */
    INDArray mutation(INDArray dna, double v, double r);

    /**
     * Substitution: replaces some genes with genes taken from elsewhere in the same DNA.
     *
     * @param dna the DNA to transform
     * @param v   the gene mutation rate
     * @return the transformed DNA
     */
    INDArray substitution(INDArray dna, double v);

    /**
     * Foreign genes: replaces some genes with values that do not come from the DNA itself.
     *
     * @param dna the DNA to transform
     * @param v   the gene mutation rate
     * @return the transformed DNA
     */
    INDArray other(INDArray dna, double v);

    /**
     * Tells whether two DNAs are homologous, i.e. compatible enough to be crossed over.
     *
     * @param mDna the mother's DNA
     * @param fDna the father's DNA
     * @return {@code true} if the two DNAs are homologous
     */
    boolean iSogeny(INDArray mDna, INDArray fDna);
}
一個比較通用的實現
public class MnistEvolution implements Evolution { private static final MnistEvolution instance = new MnistEvolution(); public static MnistEvolution getInstance() { return instance; } @Override public INDArray inheritance(INDArray mDna, INDArray fDna) { if(mDna == fDna) return mDna; long[] mShape = mDna.shape(); if(!iSogeny(mDna,fDna)){ throw new RuntimeException("非同源dna"); } INDArray nDna = Nd4j.create(mShape); NdIndexIterator it = new NdIndexIterator(mShape); while (it.hasNext()){ long[] next = it.next(); double val; if(Math.random() > 0.5){ val = fDna.getDouble(next); }else { val = mDna.getDouble(next); } nDna.putScalar(next,val); } return nDna; } @Override public INDArray mutation(INDArray dna, double v, double r) { long[] shape = dna.shape(); INDArray nDna = Nd4j.create(shape); NdIndexIterator it = new NdIndexIterator(shape); while (it.hasNext()) { long[] next = it.next(); if(Math.random() < v){ dna.putScalar(next,dna.getDouble(next) + ((Math.random() - 0.5) * r * 2)); }else { nDna.putScalar(next,dna.getDouble(next)); } } return nDna; } @Override public INDArray substitution(INDArray dna, double v) { long[] shape = dna.shape(); INDArray nDna = Nd4j.create(shape); NdIndexIterator it = new NdIndexIterator(shape); while (it.hasNext()) { long[] next = it.next(); if(Math.random() > v){ long[] tag = new long[shape.length]; for (int i = 0; i < shape.length; i++) { tag[i] = (long) (Math.random() * shape[i]); } nDna.putScalar(next,dna.getDouble(tag)); }else { nDna.putScalar(next,dna.getDouble(next)); } } return nDna; } @Override public INDArray other(INDArray dna, double v) { long[] shape = dna.shape(); INDArray nDna = Nd4j.create(shape); NdIndexIterator it = new NdIndexIterator(shape); while (it.hasNext()) { long[] next = it.next(); if(Math.random() > v){ nDna.putScalar(next,Math.random()); }else { nDna.putScalar(next,dna.getDouble(next)); } } return nDna; } @Override public boolean iSogeny(INDArray mDna, INDArray fDna) { long[] mShape = mDna.shape(); long[] fShape = 
fDna.shape(); if (mShape.length == fShape.length) { for (int i = 0; i < mShape.length; i++) { if (mShape[i] != fShape[i]) { return false; } } return true; } return false; } }
定義智能體配置接口
/**
 * Describes how to build an agent's neural network.
 */
public interface AgentConfig {

    /** @return the number of network inputs */
    int getInput();

    /** @return the number of network outputs */
    int getOutput();

    /** @return the multi-layer network configuration used to build the agent's network */
    MultiLayerConfiguration getMultiLayerConfiguration();
}
按手寫數字識別進行配置實現
public class MnistConfig implements AgentConfig { @Override public int getInput() { return 28 * 28; } @Override public int getOutput() { return 10; } @Override public MultiLayerConfiguration getMultiLayerConfiguration() { return new NeuralNetConfiguration.Builder() .seed((long) (Math.random() * Long.MAX_VALUE)) .updater(new Nesterovs(0.006, 0.9)) .l2(1e-4) .list() .layer(0, new DenseLayer.Builder() .nIn(getInput()) .nOut(1000) .activation(Activation.RELU) .weightInit(WeightInit.XAVIER) .build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer .nIn(1000) .nOut(getOutput()) .activation(Activation.SOFTMAX) .weightInit(WeightInit.XAVIER) .build()) .pretrain(false).backprop(true) .build(); } }
智能體基類
/**
 * Base agent: a {@link MultiLayerNetwork} built from an {@link AgentConfig} together with the
 * parameter vector ("DNA") associated with it.
 */
@Getter
public class Agent {
    private final AgentConfig config;
    private final INDArray dna;
    private final MultiLayerNetwork multiLayerNetwork;

    /**
     * Creates an agent whose parameters come from the network's own default initialisation.
     *
     * @param config the network configuration source
     */
    public Agent(AgentConfig config) {
        this(config, null);
    }

    /**
     * Creates an agent, optionally initialising the network from a given DNA.
     *
     * @param config the network configuration source
     * @param dna    the parameter vector to initialise the network with, or {@code null} to
     *               use the network's default initialisation
     */
    public Agent(AgentConfig config, INDArray dna) {
        this.config = config;
        this.multiLayerNetwork = new MultiLayerNetwork(config.getMultiLayerConfiguration());
        if (dna != null) {
            // NOTE(review): the `true` flag is presumably DL4J's cloneParametersArray option,
            // i.e. the network works on a copy of `dna` — confirm against the DL4J docs.
            multiLayerNetwork.init(dna, true);
            this.dna = dna;
        } else {
            multiLayerNetwork.init();
            // NOTE(review): params() likely returns the network's live parameter array, so
            // writing into this DNA would modify the network — confirm against the DL4J docs.
            this.dna = multiLayerNetwork.params();
        }
    }
}
手寫數字智能體實現類
/**
 * Handwritten-digit agent: an {@link Agent} with a generated unique name, a fitness score and
 * a validation score.
 */
@Getter
@Setter
public class MnistAgent extends Agent {

    /**
     * Configuration shared by the factory methods. {@code final} so that no caller can swap
     * the configuration for every subsequently created agent at runtime.
     */
    public static final MnistConfig mnistConfig = new MnistConfig();

    /** Counter that numbers agent names: agent-1, agent-2, ... */
    private static final AtomicInteger index = new AtomicInteger(0);

    private String name;

    /** Environment fitness score. */
    private double score;

    /** Validation score. */
    private double validScore;

    /** Creates an agent with default-initialised parameters. */
    public MnistAgent(AgentConfig config) {
        this(config, null);
    }

    /**
     * Creates an agent initialised from {@code dna} (or default-initialised if it is null).
     */
    public MnistAgent(AgentConfig config, INDArray dna) {
        super(config, dna);
        name = "agent-" + index.incrementAndGet();
    }

    /** Creates an agent with the shared MNIST configuration and default-initialised parameters. */
    public static MnistAgent newInstance() {
        return new MnistAgent(mnistConfig);
    }

    /** Creates an agent with the shared MNIST configuration, initialised from {@code dna}. */
    public static MnistAgent create(INDArray dna) {
        return new MnistAgent(mnistConfig, dna);
    }
}
手寫數字識別環境構建
@Slf4j public class MnistEnv { /** * 環境數據 */ private static final ThreadLocal<MnistDataSetIterator> tLocal = ThreadLocal.withInitial(() -> { try { return new MnistDataSetIterator(128, true, 0); } catch (IOException e) { throw new RuntimeException("mnist 文件讀取失敗"); } }); private static final ThreadLocal<MnistDataSetIterator> testLocal = ThreadLocal.withInitial(() -> { try { return new MnistDataSetIterator(128, false, 0); } catch (IOException e) { throw new RuntimeException("mnist 文件讀取失敗"); } }); private static final MnistEvolution evolution = MnistEvolution.getInstance(); /** * 環境承載上限 * * 超過上限AI會進行激烈競爭 */ private final int max; private Double maxScore,minScore; /** * 環境中的生命體 * * 新生代與歷史代共同排序,選出最適應環境的個體 */ //2個變量,一個隊列保存KEY的順序,一個MAP保存KEY對應的具體對象的數據 線程安全map private final TreeMap<Double,MnistAgent> lives = new TreeMap<>(); /** * 初始化環境 * * 1.向環境中初始化ai * 2.將初始化ai進行環境適應性測試,並排序 * @param max */ public MnistEnv(int max){ this.max = max; for (int i = 0; i < max; i++) { MnistAgent agent = MnistAgent.newInstance(); test(agent); synchronized (lives) { lives.put(agent.getScore(),agent); } log.info("初始化智能體 name = {} , score = {}",i,agent.getScore()); } synchronized (lives) { minScore = lives.firstKey(); maxScore = lives.lastKey(); } } /** * 環境適應性評估 * @param ai */ public void test(MnistAgent ai){ MultiLayerNetwork network = ai.getMultiLayerNetwork(); MnistDataSetIterator dataIterator = tLocal.get(); Evaluation eval = new Evaluation(ai.getConfig().getOutput()); try { while (dataIterator.hasNext()) { DataSet data = dataIterator.next(); INDArray output = network.output(data.getFeatures(), false); eval.eval(data.getLabels(),output); } }finally { dataIterator.reset(); } ai.setScore(eval.accuracy()); } /** * 遷移評估 * * @param ai */ public void validation(MnistAgent ai){ MultiLayerNetwork network = ai.getMultiLayerNetwork(); MnistDataSetIterator dataIterator = testLocal.get(); Evaluation eval = new Evaluation(ai.getConfig().getOutput()); try { while (dataIterator.hasNext()) { DataSet data = 
dataIterator.next(); INDArray output = network.output(data.getFeatures(), false); eval.eval(data.getLabels(),output); } }finally { dataIterator.reset(); } ai.setValidScore(eval.accuracy()); } /** * 進化 * * 每輪隨機建立ai並放入環境中進行優勝劣汰 * @param n 進化次數 */ public void evolution(int n){ BlockThreadPool blockThreadPool=new BlockThreadPool(2); for (int i = 0; i < n; i++) { blockThreadPool.execute(() -> contend(newLive())); } // for (int i = 0; i < n; i++) { // contend(newLive()); // } } /** * 競爭 * @param ai */ public void contend(MnistAgent ai){ test(ai); quality(ai); double score = ai.getScore(); if(score <= minScore){ UI.put("沒法生存",String.format("name = %s, score = %s", ai.getName(),ai.getScore())); return; } Map.Entry<Double, MnistAgent> lastEntry; synchronized (lives) { lives.put(score,ai); if (lives.size() > max) { MnistAgent lastAI = lives.remove(lives.firstKey()); UI.put("淘 汰 ",String.format("name = %s, score = %s", lastAI.getName(),lastAI.getScore())); } lastEntry = lives.lastEntry(); minScore = lives.firstKey(); } Double lastScore = lastEntry.getKey(); if(lastScore > maxScore){ maxScore = lastScore; MnistAgent agent = lastEntry.getValue(); validation(agent); UI.put("max驗證",String.format("score = %s,validScore = %s",lastScore,agent.getValidScore())); try { Warehouse.write(agent); } catch (IOException ex) { log.error("保存對象失敗",ex); } } } ArrayList<Double> scoreList = new ArrayList<>(100); ArrayList<Integer> avgList = new ArrayList<>(); private void quality(MnistAgent ai) { synchronized (scoreList) { scoreList.add(ai.getScore()); if (scoreList.size() >= 100) { double avg = scoreList.stream().mapToDouble(e -> e) .average().getAsDouble(); avgList.add((int) (avg * 1000)); StringBuffer buffer = new StringBuffer(); avgList.forEach(e -> buffer.append(e).append('\t')); UI.put("平均得分",String.format("aix100 avg = %s",buffer.toString())); scoreList.clear(); } } } /** * 隨機生成新智能體 * * 徹底隨機產生母本 * 隨機從比目標相同或更高評分中選擇父本 * * 基因進化在1%~10%之間進行,評分越高基於越穩定 */ public MnistAgent newLive(){ double r = 
Math.random(); //基因突變率 double v = r / 11 + 0.01; //母本 MnistAgent mAgent = getMother(r); //父本 MnistAgent fAgent = getFather(r); int i = (int) (Math.random() * 3); INDArray newDNA = evolution.inheritance(mAgent.getDna(), fAgent.getDna()); switch (i){ case 0: newDNA = evolution.other(newDNA,v); break; case 1: newDNA = evolution.mutation(newDNA,v,0.1); break; case 2: newDNA = evolution.substitution(newDNA,v); break; } return MnistAgent.create(newDNA); } /** * 父本只選擇比母本評分高的樣本 * @param r * @return */ private MnistAgent getFather(double r) { r += (Math.random() * (1-r)); return getMother(r); } private MnistAgent getMother(double r) { int index = (int) (r * max); return getMnistAgent(index); } private MnistAgent getMnistAgent(int index) { synchronized (lives) { Iterator<Map.Entry<Double, MnistAgent>> it = lives.entrySet().iterator(); for (int i = 0; i < index; i++) { it.next(); } return it.next().getValue(); } } }
主函數
/**
 * Entry point: builds an MNIST environment with 128 agents and evolves it for
 * {@link Integer#MAX_VALUE} rounds.
 */
@Slf4j
public class Program {
    public static void main(String[] args) {
        // java.time instead of the deprecated Date#toLocaleString.
        String startTime = java.time.LocalDateTime.now()
                .format(java.time.format.DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
        UI.put("開始時間", startTime);
        MnistEnv env = new MnistEnv(128);
        env.evolution(Integer.MAX_VALUE);
    }
}
運行截圖
gitee地址:
https://gitee.com/ichiva/gnn.git