Deeplearning4j 手寫體數字識別

最近這幾年,深度學習很火,包括本身在內的不少對機器學習仍是隻知其一;不知其二的小白也開始用深度學習作些應用。因爲小白的等級不高,算法本身寫不出來,因此就用了開源庫。Deep Learning的開源庫有多,若是以語言來劃分的話,就有Python系列的tensowflow,theano,keras,C/C++系列的Caffe,還有Lua系列的torch等等。但我們公司是用Java爲主,大部分項目最終也是作成一個Java Web的服務,因此我最終選擇了Deeplearning4j。java

    Deeplearning4j是國外創業公司Skymind的產品。目前最新的版本更新到了0.7.2。源碼所有公開並託管在github上(https://github.com/deeplearning4j/deeplearning4j)。從這個庫的名字上能夠看出,它就是轉爲Java程序員寫的Deep Learning庫。其實這個庫吸引人的地方不只僅在於它支持Java,更爲重要的是它能夠支持Spark。因爲Deep Learning模型的訓練須要大量的內存,並且原始數據的存儲有時候也須要很大的外存空間,因此若是能夠利用集羣來處理即是最好不過了。固然,除了Deeplearning4j之外,還有一些Deep Learning的庫能夠支持Spark,好比yahoo/CaffeOnSpark,AMPLab/SparkNet以及Intel最近開源的BigDL。這些庫我本身都沒怎麼用過,因此就很少說了,這裏重點說說Deeplearning4j的使用。python

    通常開始使用別人的代碼庫,都會先跑一些demo,或者說Hello World的例子,就好像學習一門編程語言同樣,第一行代碼都是打印Hello World。Deep Learning的Hello World的例子通常是兩個,一個是Mnist數據集的分類,另外一個就是Word2Vec找類似詞。因爲Word2Vec並非嚴格意義上的深度神經網絡,所以這裏就用Lenet網絡處理Mnist數據集來做爲Deep Learning的Hello World。Mnist是開源的28x28的黑白手寫體數字圖片集(http://yann.lecun.com/exdb/mnist/),其中包含6W張訓練圖片和1W張測試圖片。至於Lenet的相關結構描述,能夠參考這個連接:http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf。下面就詳細講述下,利用Deeplearning4j如何進行建模、訓練和預測評估。git

    首先,咱們創建一個maven項目。而後在pom文件里加入Deeplearning4j的一些相關依賴。最主要的有三個:deeplearning4j-core,datavec,nd4j。deeplearning4j-core是神經網絡結構實現的代碼,nd4j是用於作張量運算的庫,經過JavaCPP來調用編譯好的C++庫(可選:ATAL, MKL, 和OpenBLAS),datavec則主要負責數據的ETL。具體可見代碼:程序員

<properties>  
  <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>  
  <nd4j.version>0.7.1</nd4j.version>  
  <dl4j.version>0.7.1</dl4j.version>  
  <datavec.version>0.7.1</datavec.version>  
  <scala.binary.version>2.10</scala.binary.version>  
</properties>  
<dependencies>  
<dependency>  
    <groupId>org.nd4j</groupId>  
    <artifactId>nd4j-native</artifactId>   
    <version>${nd4j.version}</version>  
</dependency>  
<dependency>  
    <groupId>org.deeplearning4j</groupId>  
    <artifactId>dl4j-spark_2.11</artifactId>  
    <version>${dl4j.version}</version>  
</dependency>  
     <dependency>  
          <groupId>org.datavec</groupId>  
          <artifactId>datavec-spark_${scala.binary.version}</artifactId>  
          <version>${datavec.version}</version>  
    </dependency>  
      <dependency>  
   <groupId>org.deeplearning4j</groupId>  
   <artifactId>deeplearning4j-core</artifactId>  
   <version>${dl4j.version}</version>  
</dependency>  
</dependencies>  
  1.     這些依賴裏面有和Spark相關的,主要是跑Spark要用到。不過沒有關係,先引進來便可。

    接着,咱們解釋下面的代碼。咱們先要定義一些具體的參數,好比分類的個數(outputNum),mini-batch的數量(batchSize)等等,具體在圖中已經作了註釋。須要說明的是MnistDataSetIterator這個迭代器類。這個類實際上是一個讀取二進制Mnist數據集的high-level的封裝。經過debug咱們能夠發現,其中包括從網絡中下載Mnist數據集,讀取數據和標註,再構建迭代器的過程。在源碼中,默認將下載的文件放在系統的user.home目錄下,具體每一個人不一樣會有所不一樣。因爲我本身所處的環境網絡不咋的,因此頗有可能在利用這種high-level的接口的時候,由於下載Mnist數據失敗而拋出異常,最終沒法訓練。因此,你們能夠先自行下載好這些數據,而後按照源碼的要求,放到相應的目錄下並根據源碼正確命名文件,那這樣就依然能夠利用這種high-level的接口。具體須要參考的是MnistDataFetcher類中相關代碼。github

int nChannels = 1;      //black & white picture, 3 if color image
        int outputNum = 10;     //number of classification
        int batchSize = 64;     //mini batch size for sgd
        int nEpochs = 10;       //total rounds of training
        int iterations = 1;     //number of iteration in each traning round
        int seed = 123;         //random seed for initialize weights

        log.info("Load data....");
        DataSetIterator mnistTrain = null;
        DataSetIterator mnistTest = null;

        mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

 

當咱們正確讀取數據後,咱們須要定義具體的神經網絡結構,這裏我用的是Lenet,Deeplearning4j的實現參考了官網(https://github.com/deeplearning4j/dl4j-examples)的例子。具體代碼以下:算法

MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .regularization(true).l2(0.0005)
                .learningRate(0.01)//.biasLearningRate(0.02)
                //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation("identity")
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        //Note that nIn need not be specified in later layers
                        .stride(1, 1)
                        .nOut(50)
                        .activation("identity")
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2,2)
                        .stride(2,2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation("relu")
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation("softmax")
                        .build())
                .backprop(true).pretrain(false)
                .cnnInputSize(28, 28, 1);
        // The builder needs the dimensions of the image along with the number of channels. these are 28x28 images in one channel
        //new ConvolutionLayerSetup(builder,28,28,1);

        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();        
        model.setListeners(new ScoreIterationListener(1));         // a listener which can print loss function score after each iteration

能夠發現,神經網絡須要定義不少的超參數,學習率、正則化係數、卷積核的大小、激勵函數等都是須要人爲設定的。不一樣的超參數,對結果的影響很大,其實後來發現,不少時間都花在數據處理和調參方面。畢竟本身設計網絡的能力有限,通常都是參考大牛的論文,而後本身照葫蘆畫瓢地實現。這裏實現的Lenet的結構是:卷積-->下采樣-->卷積-->下采樣-->全鏈接。和原論文的結構基本一致。卷積核的大小也是參考的原論文。具體細節可參考以前發的論文連接。這裏咱們設置了一個Score的監聽事件,主要是能夠在訓練的時候獲取每一次權重更新後損失函數的收斂狀況。後面一會有截圖。編程

定義完網絡結構以後,咱們就能夠對以前讀取的數據進行訓練和分類準確性評估。先看下代碼:bash

for( int i = 0; i < nEpochs; ++i ) {  
    model.fit(mnistTrain);  
    log.info("*** Completed epoch " + i + "***");  
  
    log.info("Evaluate model....");  
    Evaluation eval = new Evaluation(outputNum);  
    while(mnistTest.hasNext()){  
        DataSet ds = mnistTest.next();            
        INDArray output = model.output(ds.getFeatureMatrix(), false);  
        eval.eval(ds.getLabels(), output);  
    }  
    log.info(eval.stats());  
    mnistTest.reset();  
} 

    相信這部分是比較容易理解的。每訓練完一輪後,咱們會對測試集合進行評估,而後打印出相似下面的結果。圖中的上半部分是具體分類的統計,包括分對的和分錯的圖片數量均可以看獲得。而後,咱們耐心等待一段時間,能夠看到通過10輪訓練的Lenet對於Mnist數據集的分類準確率達到99%以下:網絡

Examples labeled as 0 classified by model as 0: 974 times
Examples labeled as 0 classified by model as 6: 2 times
Examples labeled as 0 classified by model as 7: 2 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 0 classified by model as 9: 1 times
Examples labeled as 1 classified by model as 0: 1 times
Examples labeled as 1 classified by model as 1: 1128 times
Examples labeled as 1 classified by model as 2: 1 times
Examples labeled as 1 classified by model as 3: 2 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 2 times
Examples labeled as 2 classified by model as 2: 1026 times
Examples labeled as 2 classified by model as 4: 1 times
Examples labeled as 2 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 7: 3 times
Examples labeled as 2 classified by model as 8: 1 times
Examples labeled as 3 classified by model as 0: 1 times
Examples labeled as 3 classified by model as 1: 1 times
Examples labeled as 3 classified by model as 2: 1 times
Examples labeled as 3 classified by model as 3: 998 times
Examples labeled as 3 classified by model as 5: 3 times
Examples labeled as 3 classified by model as 7: 1 times
Examples labeled as 3 classified by model as 8: 4 times
Examples labeled as 3 classified by model as 9: 1 times
Examples labeled as 4 classified by model as 2: 1 times
Examples labeled as 4 classified by model as 4: 973 times
Examples labeled as 4 classified by model as 6: 2 times
Examples labeled as 4 classified by model as 7: 1 times
Examples labeled as 4 classified by model as 9: 5 times
Examples labeled as 5 classified by model as 0: 2 times
Examples labeled as 5 classified by model as 3: 4 times
Examples labeled as 5 classified by model as 5: 882 times
Examples labeled as 5 classified by model as 6: 1 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 8: 2 times
Examples labeled as 6 classified by model as 0: 4 times
Examples labeled as 6 classified by model as 1: 2 times
Examples labeled as 6 classified by model as 4: 1 times
Examples labeled as 6 classified by model as 5: 4 times
Examples labeled as 6 classified by model as 6: 945 times
Examples labeled as 6 classified by model as 8: 2 times
Examples labeled as 7 classified by model as 1: 5 times
Examples labeled as 7 classified by model as 2: 3 times
Examples labeled as 7 classified by model as 3: 1 times
Examples labeled as 7 classified by model as 7: 1016 times
Examples labeled as 7 classified by model as 8: 1 times
Examples labeled as 7 classified by model as 9: 2 times
Examples labeled as 8 classified by model as 0: 1 times
Examples labeled as 8 classified by model as 3: 1 times
Examples labeled as 8 classified by model as 5: 2 times
Examples labeled as 8 classified by model as 7: 2 times
Examples labeled as 8 classified by model as 8: 966 times
Examples labeled as 8 classified by model as 9: 2 times
Examples labeled as 9 classified by model as 3: 1 times
Examples labeled as 9 classified by model as 4: 2 times
Examples labeled as 9 classified by model as 5: 4 times
Examples labeled as 9 classified by model as 6: 1 times
Examples labeled as 9 classified by model as 7: 5 times
Examples labeled as 9 classified by model as 8: 3 times
Examples labeled as 9 classified by model as 9: 993 times


==========================Scores========================================
 Accuracy:        0.9901
 Precision:       0.99
 Recall:          0.99
 F1 Score:        0.99
========================================================================
[main] INFO cv.LenetMnistExample - ****************Example finished********************

    由於圖傳不上去,我就直接粘帖告終果。從中咱們看到最終的一個準確率,還有就是哪些圖片是分類正確的,哪些是分類錯誤的。固然咱們能夠經過增長訓練的輪次還有調超參數來進一步優化,不過實際上這樣的結果已經能夠拿到生產上去用了。app

    總結一下。其實包括我本身在內的不少人都對深度學習不瞭解,記得當時看csdn上寫的有關深度學習的博客的時候,都以爲本身不可能達到那種水平。但其實,咱們都忽略了一點,深度學習自身再複雜,它也是一個算法模型,也是一種機器學習。雖然它比感知機、邏輯迴歸等模型複雜不少(其實邏輯迴歸可看做神經網絡中的一個神經元,充當的是激勵函數的做用,相似的激勵函數不少,如tanh,relu等),但終究用它的目的依然是完成迴歸、分類、壓縮數據等任務。因此第一步嘗試仍是挺重要的。固然,咱們不可能從複雜的模型開始,一開始就跟上當下最流行的模型,因此就從Mnist識別的例子開始,找找感受。之後會寫一些用Deeplearning4j在Spark的案例,也仍是從Mnist開始。分享的同時本身也複習一下。。。

相關文章
相關標籤/搜索