dl4j(Deeplearning4j)使用遺傳神經網絡完成手寫數字識別

實現步驟

  • 1.隨機初始化若干個智能體(神經網絡),並讓智能體識別訓練數據,並對識別結果進行排序
  • 2.隨機在排序結果中選擇一個作為母本,並在比母本識別率更高的智能體中隨機選擇一個作為父本
  • 3.隨機選擇母本或父本同位的神經網絡參數(權重與偏置),組成新的智能體
  • 4.按照母本的排序對新智能體的參數進行調整,排序越靠後調整幅度越大(在1%~10%之間)
  • 5.讓新的智能體識別訓練集並放入排行榜,並移除排行榜最後一位
  • 6.重複2~5過程,讓識別率越來越高

這個過程就類似於自然界的優勝劣汰:將神經網絡參數看作DNA,參數的調整看作DNA的突變;當然,還可以把擁有不同隱藏層的神經網絡看作不同的物種,讓競爭過程更加多樣化。不過,我們這裏只討論一種神經網絡的情況。

優勢:可以解決很多沒有頭緒的問題;劣勢:訓練效率極低。

gitee地址:

https://gitee.com/ichiva/gnn.git

實現步驟 1. 進化接口

/**
 * Genetic operators applied to a network's flattened parameter array (its "DNA").
 *
 * <p>Every operator returns the resulting genome as an {@code INDArray}; whether the
 * input array is copied or reused is left to the implementation.
 */
public interface Evolution {

    /**
     * Checks whether two genomes are homologous, i.e. compatible enough to be crossed.
     *
     * @param mDna first (mother) genome
     * @param fDna second (father) genome
     * @return {@code true} when the two genomes can be combined
     */
    boolean iSogeny(INDArray mDna, INDArray fDna);

    /**
     * Crossover: combines a mother and a father genome into one offspring genome.
     *
     * @param mDna mother genome
     * @param fDna father genome
     * @return the offspring genome
     */
    INDArray inheritance(INDArray mDna, INDArray fDna);

    /**
     * Point mutation: perturbs some genes by a bounded random amount.
     *
     * @param dna genome to mutate
     * @param v   rate controlling how many genes are affected
     * @param r   perturbation range (maximum magnitude of the change)
     * @return the mutated genome
     */
    INDArray mutation(INDArray dna, double v, double r);

    /**
     * Substitution: replaces some genes with values taken from elsewhere in the same genome.
     *
     * @param dna genome to rearrange
     * @param v   rate controlling how many genes are affected
     * @return the resulting genome
     */
    INDArray substitution(INDArray dna, double v);

    /**
     * Exogenous insertion: replaces some genes with values coming from outside the genome.
     *
     * @param dna genome to modify
     * @param v   rate controlling how many genes are affected
     * @return the resulting genome
     */
    INDArray other(INDArray dna, double v);
}

一個比較通用的實現

/**
 * Stateless {@link Evolution} implementation operating gene-by-gene on ND4J arrays.
 *
 * <p>None of the operators modify their input arrays; each returns a freshly allocated genome.
 * In every operator, {@code v} is the per-gene probability that the operator touches that gene.
 */
public class MnistEvolution implements Evolution {

    private static final MnistEvolution instance = new MnistEvolution();

    /** Returns the shared instance (the class holds no state). */
    public static MnistEvolution getInstance() {
        return instance;
    }

    /**
     * Uniform crossover: each child gene is copied from the father or the mother with equal probability.
     *
     * @param mDna mother genome
     * @param fDna father genome; must have the same shape as {@code mDna}
     * @return a new genome; never aliases either parent
     * @throws RuntimeException if the two genomes have different shapes
     */
    @Override
    public INDArray inheritance(INDArray mDna, INDArray fDna) {
        // Self-crossover: hand back a copy so the child never shares storage with a living parent.
        if (mDna == fDna) {
            return mDna.dup();
        }
        if (!iSogeny(mDna, fDna)) {
            throw new RuntimeException("非同源dna");
        }

        long[] shape = mDna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            INDArray donor = Math.random() > 0.5 ? fDna : mDna;
            nDna.putScalar(next, donor.getDouble(next));
        }
        return nDna;
    }

    /**
     * With probability {@code v} per gene, adds a uniform offset in {@code [-r, r)}; other genes are copied.
     *
     * @param dna source genome (not modified)
     * @param v   per-gene mutation probability
     * @param r   maximum magnitude of the offset
     * @return a new, mutated genome
     */
    @Override
    public INDArray mutation(INDArray dna, double v, double r) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            double gene = dna.getDouble(next);
            if (Math.random() < v) {
                // Previously the mutated value was written back into the *input* array, leaving
                // this slot of the child at 0 and corrupting the parent; it belongs in nDna.
                gene += (Math.random() - 0.5) * r * 2;
            }
            nDna.putScalar(next, gene);
        }
        return nDna;
    }

    /**
     * With probability {@code v} per gene, replaces it with the value at a random position of the same
     * genome; other genes are copied.
     *
     * @param dna source genome (not modified)
     * @param v   per-gene substitution probability
     * @return a new genome
     */
    @Override
    public INDArray substitution(INDArray dna, double v) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            // "< v" (as in mutation): the old "> v" rewrote 90-99% of genes for v in 1%-10%.
            if (Math.random() < v) {
                long[] tag = new long[shape.length];
                for (int i = 0; i < shape.length; i++) {
                    tag[i] = (long) (Math.random() * shape[i]);
                }
                nDna.putScalar(next, dna.getDouble(tag));
            } else {
                nDna.putScalar(next, dna.getDouble(next));
            }
        }
        return nDna;
    }

    /**
     * With probability {@code v} per gene, replaces it with a fresh uniform value in {@code [0, 1)};
     * other genes are copied.
     *
     * @param dna source genome (not modified)
     * @param v   per-gene replacement probability
     * @return a new genome
     */
    @Override
    public INDArray other(INDArray dna, double v) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            // "< v" (as in mutation): the old "> v" replaced nearly the whole genome with noise.
            double gene = Math.random() < v ? Math.random() : dna.getDouble(next);
            nDna.putScalar(next, gene);
        }
        return nDna;
    }

    /**
     * Two genomes are homologous when their shapes are identical.
     *
     * @param mDna first genome
     * @param fDna second genome
     * @return {@code true} if both arrays have the same rank and dimension sizes
     */
    @Override
    public boolean iSogeny(INDArray mDna, INDArray fDna) {
        return java.util.Arrays.equals(mDna.shape(), fDna.shape());
    }
}

定義智能體配置接口

/**
 * Describes how to build the neural network behind an agent.
 */
public interface AgentConfig {

    /**
     * Builds the network configuration for a new agent.
     *
     * @return a multi-layer network configuration
     */
    MultiLayerConfiguration getMultiLayerConfiguration();

    /**
     * Size of the network input.
     *
     * @return number of input features
     */
    int getInput();

    /**
     * Size of the network output.
     *
     * @return number of output units (e.g. classes)
     */
    int getOutput();
}

按手寫數字識別進行配置實現

/**
 * Network configuration for MNIST digit classification:
 * 784 inputs -> one dense ReLU hidden layer -> 10-way softmax output.
 */
public class MnistConfig implements AgentConfig {

    /** Width of the hidden layer; shared by layer 0's nOut and layer 1's nIn so they cannot drift apart. */
    private static final int HIDDEN_UNITS = 1000;

    /** @return 28 x 28 = 784 pixels per image */
    @Override
    public int getInput() {
        return 28 * 28;
    }

    /** @return 10 digit classes */
    @Override
    public int getOutput() {
        return 10;
    }

    /**
     * Builds a fresh configuration. A new random seed is drawn on every call, presumably so that
     * randomly initialised networks differ from each other.
     */
    @Override
    public MultiLayerConfiguration getMultiLayerConfiguration() {
        return new NeuralNetConfiguration.Builder()
                .seed((long) (Math.random() * Long.MAX_VALUE))
                .updater(new Nesterovs(0.006, 0.9))
                .l2(1e-4)
                .list()
                // hidden layer
                .layer(0, new DenseLayer.Builder()
                        .nIn(getInput())
                        .nOut(HIDDEN_UNITS)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                // output layer (softmax over the digit classes)
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(HIDDEN_UNITS)
                        .nOut(getOutput())
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .pretrain(false).backprop(true)
                .build();
    }
}

智能體基類

/**
 * An agent: a neural network together with its genome (the network's flattened parameter array).
 */
@Getter
public class Agent {

    /** Configuration the network was built from. */
    private final AgentConfig config;
    /** The network's own flattened parameters, treated as this agent's DNA. */
    private final INDArray dna;
    /** The network built from {@link #config}. */
    private final MultiLayerNetwork multiLayerNetwork;

    /**
     * Creates an agent whose parameters are initialised by the network's default initialiser.
     *
     * @param config network configuration
     */
    public Agent(AgentConfig config){
        this(config,null);
    }

    /**
     * Creates an agent from the given configuration and, optionally, an explicit genome.
     *
     * @param config network configuration
     * @param dna    flattened parameters to load, or {@code null} to use default initialisation
     */
    public Agent(AgentConfig config, INDArray dna){
        this.config = config;
        this.multiLayerNetwork = new MultiLayerNetwork(config.getMultiLayerConfiguration());
        if (dna == null) {
            multiLayerNetwork.init();
        } else {
            // clone = true: per DL4J's init(params, clone) contract the network keeps its own copy.
            multiLayerNetwork.init(dna, true);
        }
        // Expose the network's actual parameter buffer in both cases, so getDna() always matches
        // what the network computes with (previously the caller's pre-clone array was stored here).
        this.dna = multiLayerNetwork.params();
    }
}

手寫數字智能體實現類

/**
 * MNIST agent: an {@link Agent} with a unique name plus its training-set and test-set accuracy.
 */
@Getter
@Setter
public class MnistAgent extends Agent {

    /** Shared configuration used by the static factories. */
    public static MnistConfig mnistConfig = new MnistConfig();

    /** Sequence used to number agent names. */
    private static final AtomicInteger SERIAL = new AtomicInteger(0);

    /** Unique display name of the form {@code agent-N}. */
    private String name;

    /** Fitness: accuracy on the training data. */
    private double score;

    /** Accuracy on the held-out test data. */
    private double validScore;

    /** Creates a new, randomly initialised MNIST agent. */
    public static MnistAgent newInstance(){
        return create(null);
    }

    /**
     * Creates an MNIST agent carrying the given genome.
     *
     * @param dna flattened network parameters, or {@code null} for default initialisation
     */
    public static MnistAgent create(INDArray dna){
        return new MnistAgent(mnistConfig, dna);
    }

    public MnistAgent(AgentConfig config) {
        this(config, null);
    }

    public MnistAgent(AgentConfig config, INDArray dna) {
        super(config, dna);
        this.name = "agent-" + SERIAL.incrementAndGet();
    }
}

手寫數字識別環境構建

/**
 * The environment: a fixed-size population of MNIST agents ranked by training accuracy.
 *
 * <p>New agents are bred from existing ones and compete for a place in the population.
 * The weakest agent is evicted whenever the population exceeds its capacity.
 */
@Slf4j
public class MnistEnv {
    /** Per-thread MNIST training-set iterator used for fitness evaluation. */
    private static final ThreadLocal<MnistDataSetIterator> tLocal = ThreadLocal.withInitial(() -> {
        try {
            return new MnistDataSetIterator(128, true, 0);
        } catch (IOException e) {
            throw new RuntimeException("mnist 文件讀取失敗", e);
        }
    });

    /** Per-thread MNIST test-set iterator used for validation. */
    private static final ThreadLocal<MnistDataSetIterator> testLocal = ThreadLocal.withInitial(() -> {
        try {
            return new MnistDataSetIterator(128, false, 0);
        } catch (IOException e) {
            throw new RuntimeException("mnist 文件讀取失敗", e);
        }
    });

    private static final MnistEvolution evolution = MnistEvolution.getInstance();

    /**
     * Ranking order: ascending score, with the unique agent name as tie-breaker.
     * The old TreeMap keyed by score silently dropped agents whose accuracy happened to be equal.
     * That shrank the population below {@code max} and made {@link #getMnistAgent} overrun.
     */
    private static final java.util.Comparator<MnistAgent> BY_SCORE =
            java.util.Comparator.comparingDouble(MnistAgent::getScore)
                    .thenComparing(MnistAgent::getName);

    /** Population capacity; above it the weakest agent is evicted. */
    private final int max;

    /** Lowest and highest scores currently in the population; written under the {@code lives} lock. */
    private volatile double maxScore, minScore;

    /** The population, weakest first. All access is guarded by {@code synchronized (lives)}. */
    private final java.util.TreeSet<MnistAgent> lives = new java.util.TreeSet<>(BY_SCORE);

    /**
     * Creates the environment and fills it with {@code max} randomly initialised, evaluated agents.
     *
     * @param max population capacity
     * @throws IllegalArgumentException if {@code max} is not positive
     */
    public MnistEnv(int max){
        if (max <= 0) {
            throw new IllegalArgumentException("max must be positive: " + max);
        }
        this.max = max;

        for (int i = 0; i < max; i++) {
            MnistAgent agent = MnistAgent.newInstance();
            test(agent);
            synchronized (lives) {
                lives.add(agent);
            }
            log.info("初始化智能體 name = {} , score = {}", agent.getName(), agent.getScore());
        }

        synchronized (lives) {
            minScore = lives.first().getScore();
            maxScore = lives.last().getScore();
        }
    }

    /**
     * Fitness evaluation: stores the agent's accuracy on the training set as its score.
     *
     * @param ai agent to evaluate
     */
    public void test(MnistAgent ai){
        ai.setScore(accuracy(ai, tLocal.get()));
    }

    /**
     * Transfer evaluation: stores the agent's accuracy on the test set as its validation score.
     *
     * @param ai agent to evaluate
     */
    public void validation(MnistAgent ai){
        ai.setValidScore(accuracy(ai, testLocal.get()));
    }

    /** Runs the agent's network over every batch of the iterator and returns its accuracy; always resets the iterator. */
    private static double accuracy(MnistAgent ai, MnistDataSetIterator dataIterator) {
        MultiLayerNetwork network = ai.getMultiLayerNetwork();
        Evaluation eval = new Evaluation(ai.getConfig().getOutput());
        try {
            while (dataIterator.hasNext()) {
                DataSet data = dataIterator.next();
                INDArray output = network.output(data.getFeatures(), false);
                eval.eval(data.getLabels(), output);
            }
        } finally {
            dataIterator.reset();
        }
        return eval.accuracy();
    }

    /**
     * Submits {@code n} breed-and-compete rounds to a two-thread pool.
     *
     * @param n number of rounds
     */
    public void evolution(int n){
        BlockThreadPool blockThreadPool = new BlockThreadPool(2);
        for (int i = 0; i < n; i++) {
            blockThreadPool.execute(() -> contend(newLive()));
        }
        // NOTE(review): the pool is never shut down here; confirm whether BlockThreadPool requires it.
    }

    /**
     * Evaluates a candidate and, if it beats the current weakest agent, inserts it into the population.
     * If the insertion produces a new best agent, that agent is validated and saved.
     *
     * @param ai candidate agent
     */
    public void contend(MnistAgent ai){
        test(ai);
        quality(ai);
        double score = ai.getScore();
        if (score <= minScore) {
            UI.put("沒法生存",String.format("name = %s,  score = %s", ai.getName(),ai.getScore()));
            return;
        }

        MnistAgent evicted = null;
        MnistAgent newBest = null;
        synchronized (lives) {
            lives.add(ai);
            if (lives.size() > max) {
                evicted = lives.pollFirst();
            }
            minScore = lives.first().getScore();
            // Decide "new best" under the lock so two workers can't both claim it or move maxScore backwards.
            MnistAgent top = lives.last();
            if (top.getScore() > maxScore) {
                maxScore = top.getScore();
                newBest = top;
            }
        }

        if (evicted != null) {
            UI.put("淘 汰 ",String.format("name = %s, score = %s", evicted.getName(),evicted.getScore()));
        }

        if (newBest != null) {
            validation(newBest);
            UI.put("max驗證",String.format("score = %s,validScore = %s",newBest.getScore(),newBest.getValidScore()));

            try {
                Warehouse.write(newBest);
            } catch (IOException ex) {
                log.error("保存對象失敗",ex);
            }
        }
    }

    ArrayList<Double> scoreList = new ArrayList<>(100);
    ArrayList<Integer> avgList = new ArrayList<>();

    /** Records each candidate's score and publishes the running history of 100-candidate averages (x1000). */
    private void quality(MnistAgent ai) {
        synchronized (scoreList) {
            scoreList.add(ai.getScore());
            if (scoreList.size() >= 100) {
                double avg = scoreList.stream().mapToDouble(e -> e)
                        .average().getAsDouble();
                avgList.add((int) (avg * 1000));
                StringBuilder buffer = new StringBuilder();
                avgList.forEach(e -> buffer.append(e).append('\t'));
                UI.put("平均得分",String.format("aix100 avg = %s",buffer.toString()));
                scoreList.clear();
            }
        }
    }

    /**
     * Breeds a new agent.
     *
     * <p>The mother is picked at a uniformly random rank {@code r} (0 = weakest, approaching 1 = strongest).
     * The father is picked at a rank no lower than the mother's. The child's genes are then perturbed
     * at a rate between 1% and about 10%: the higher the mother ranks, the more stable (less mutated)
     * the genes, as the original design intends.
     */
    public MnistAgent newLive(){
        double r = Math.random();
        // Gene-change rate in (0.01, ~0.10]; previously r / 11 + 0.01 mutated the best agents the most.
        double v = (1 - r) / 11 + 0.01;
        MnistAgent mAgent = getMother(r);
        MnistAgent fAgent = getFather(r);

        int operator = (int) (Math.random() * 3);
        INDArray newDNA = evolution.inheritance(mAgent.getDna(), fAgent.getDna());
        switch (operator){
            case 0:
                newDNA = evolution.other(newDNA,v);
                break;
            case 1:
                newDNA = evolution.mutation(newDNA,v,0.1);
                break;
            case 2:
                newDNA = evolution.substitution(newDNA,v);
                break;
        }
        return MnistAgent.create(newDNA);
    }

    /**
     * Picks a father ranked at or above the mother's rank {@code r}.
     *
     * @param r mother's rank fraction in [0, 1)
     */
    private MnistAgent getFather(double r) {
        r += (Math.random() * (1 - r));
        return getMother(r);
    }

    /**
     * Picks the agent at rank fraction {@code r} (0 = weakest).
     *
     * @param r rank fraction in [0, 1)
     */
    private MnistAgent getMother(double r) {
        int index = (int) (r * max);
        return getMnistAgent(index);
    }

    /** Returns the agent at the given position in ascending-score order, clamped to the population bounds. */
    private MnistAgent getMnistAgent(int index) {
        synchronized (lives) {
            int target = Math.max(0, Math.min(index, lives.size() - 1));
            Iterator<MnistAgent> it = lives.iterator();
            for (int i = 0; i < target; i++) {
                it.next();
            }
            return it.next();
        }
    }
}

主函數

/**
 * Entry point: records the start time, builds a 128-agent environment and evolves it indefinitely.
 */
@Slf4j
public class Program {

    /** Number of agents kept in the environment. */
    private static final int POPULATION = 128;

    /** Formatter for the start timestamp (replaces the deprecated Date#toLocaleString). */
    private static final java.time.format.DateTimeFormatter START_FORMAT =
            java.time.format.DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    public static void main(String[] args) {
        UI.put("開始時間", java.time.LocalDateTime.now().format(START_FORMAT));
        MnistEnv env = new MnistEnv(POPULATION);
        env.evolution(Integer.MAX_VALUE);
    }
}

運行截圖

gitee地址:

https://gitee.com/ichiva/gnn.git
相關文章
相關標籤/搜索