Java分佈式神經網絡庫Deeplearning4j之上手實踐手寫數字圖像識別與模型訓練

環境的搭建能夠參考另外一篇文章。java

                    

  • 第一步運行MnistImagePipelineExampleSave代碼下載數據集,並進行訓練和保存

須要下載一個文件(windows默認保存在C:\Users\Administrator\AppData\Local\Temp\dl4j_Mnist)。文件存在git。若是網絡很差。建議手動下載並解壓。而後註釋掉代碼中的下載方法便可。如圖所示:git

                

訓練須要一段時間等待便可。時間長短取決於本身電腦配置。windows

  • 第二步運行MnistImagePipelineLoadChooser代碼。並選中一個手寫數字圖像。進行識別測試
package org.deeplearning4j.examples.dataexamples;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.swing.*;
import java.io.File;
import java.util.Arrays;
import java.util.List;

/**
 * 
 * 給定用戶一個文件選擇框來選中要測試的手寫數字圖像
 * 0-9數字 白色或者黑色背景進行識別
 */
public class MnistImagePipelineLoadChooser {
    private static Logger log = LoggerFactory.getLogger(MnistImagePipelineLoadChooser.class);


    /*
    Create a popup window to allow you to chose an image file to test against the
    trained Neural Network
    Chosen images will be automatically
    scaled to 28*28 grayscale
     */
    public static String fileChose(){
        JFileChooser fc = new JFileChooser();
        int ret = fc.showOpenDialog(null);
        if (ret == JFileChooser.APPROVE_OPTION)
        {
            File file = fc.getSelectedFile();
            String filename = file.getAbsolutePath();
            return filename;
        }
        else {
            return null;
        }
    }

    public static void main(String[] args) throws Exception{
        int height = 28;
        int width = 28;
        int channels = 1;

        List<Integer> labelList = Arrays.asList(0,1,2,3,4,5,6,7,8,9);

        // pop up file chooser
        String filechose = fileChose().toString();

        //LOAD NEURAL NETWORK

        // MnistImagePipelineExampleSave訓練並保存模型
        File locationToSave = new File("trained_mnist_model.zip");
        // 檢查保存的模型是否存在
        if(locationToSave.exists()){
            System.out.println("\n######存在保存的訓練模型######\n");
        }else{
            System.out.println("\n\n#######File not found!#######");
            System.out.println("This example depends on running ");
            System.out.println("MnistImagePipelineExampleSave");
            System.out.println("Run that Example First");
            System.out.println("#############################\n\n");


            System.exit(0);
        }

        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(locationToSave);

        log.info("*********TEST YOUR IMAGE AGAINST SAVED NETWORK********");

        // 選擇一個文件

        File file = new File(filechose);

        // 使用NativeImageLoader轉換爲數值矩陣

        NativeImageLoader loader = new NativeImageLoader(height, width, channels);

        // 獲得圖像並賦值INDArray

        INDArray image = loader.asMatrix(file);

        // 0-255
        // 0-1
        DataNormalization scaler = new ImagePreProcessingScaler(0,1);
        scaler.transform(image);
        // 傳遞到神經網絡 並獲得機率值
        INDArray output = model.output(image);

        log.info("## The FILE CHOSEN WAS " + filechose);
        log.info("## The Neural Nets Pediction ##");
        log.info("## list of probabilities per label ##");
        //log.info("## List of Labels in Order## ");
        //有序狀態
        log.info(output.toString());
        log.info(labelList.toString());

    }



}

 

  • 選擇圖片運行後的結果
######Saved Model Found######

o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
o.n.n.NativeOpsHolder - Number of threads used for NativeOps: 2
o.n.n.Nd4jBlas - Number of threads used for BLAS: 2
o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Windows 7]
o.n.l.a.o.e.DefaultOpExecutioner - Cores: [4]; Memory: [1.8GB];
o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [OPENBLAS]
o.d.n.m.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: NONE; inference: SEPARATE]
o.d.e.d.MnistImagePipelineLoadChooser - *********TEST YOUR IMAGE AGAINST SAVED NETWORK********
o.d.e.d.MnistImagePipelineLoadChooser - ## The FILE CHOSEN WAS C:\Users\Administrator\Desktop\93.png
o.d.e.d.MnistImagePipelineLoadChooser - ## The Neural Nets Pediction ##
o.d.e.d.MnistImagePipelineLoadChooser - ## list of probabilities per label ##
o.d.e.d.MnistImagePipelineLoadChooser - [0.00,  0.00,  0.00,  1.00,  0.00,  0.00,  0.00,  0.00,  0.00,  0.00]
o.d.e.d.MnistImagePipelineLoadChooser - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
圖中的數字爲: 3
數字的置信度爲:100.0%

Process finished with exit code 0

選擇的圖片爲:api

可見模型對黑白的手寫數字識別度還算是能夠的。網絡

相關資料。建議仍是去官網查閱。本博客只是進行上手實踐分佈式

https://deeplearning4j.org/cn/測試

相關文章
相關標籤/搜索