LSTM java 實現

因爲實驗室事情緣故,須要將Python寫的神經網絡轉成Java版本的,可是python中的numpy等啥包也不知道在Java裏面對應的是什麼工具,因此索性直接尋找一個現成可用的Java神經網絡框架,因而就找到了JOONE,JOONE是一個神經網絡的開源框架,使用的是BP算法進行迭代計算參數,使用起來比較方便也比較實用,下面介紹一下JOONE的一些使用方法。java

 

JOONE須要使用一些外部的依賴包,這在官方網站上有,也能夠在這裏下載。將所需的包引入工程以後,就能夠進行編碼實現了。python

 

首先看下完整的程序,這個是上面那個超連接給出的程序,應該是官方給出的一個示例吧,由於好多文章都用這個,這實際上是神經網絡訓練一個異或計算器:算法

 

[java]  view plain  copy
 
  1. import org.joone.engine.*;  
  2. import org.joone.engine.learning.*;  
  3. import org.joone.io.*;  
  4. import org.joone.net.*;  
  5.   
  6.   
  7. /* 
  8.  *  
  9.  * JOONE實現 
  10.  *  
  11.  * */  
  12. public class XOR_using_NeuralNet implements NeuralNetListener  
  13. {  
  14.     private NeuralNet nnet = null;  
  15.     private MemoryInputSynapse inputSynapse, desiredOutputSynapse;  
  16.     LinearLayer input;  
  17.     SigmoidLayer hidden, output;  
  18.     boolean singleThreadMode = true;  
  19.   
  20.     // XOR input  
  21.     private double[][] inputArray = new double[][]  
  22.     {  
  23.     { 0.0, 0.0 },  
  24.     { 0.0, 1.0 },  
  25.     { 1.0, 0.0 },  
  26.     { 1.0, 1.0 } };  
  27.   
  28.     // XOR desired output  
  29.     private double[][] desiredOutputArray = new double[][]  
  30.     {  
  31.     { 0.0 },  
  32.     { 1.0 },  
  33.     { 1.0 },  
  34.     { 0.0 } };  
  35.   
  36.     /** 
  37.      * @param args 
  38.      *            the command line arguments 
  39.      */  
  40.     public static void main(String args[])  
  41.     {  
  42.         XOR_using_NeuralNet xor = new XOR_using_NeuralNet();  
  43.   
  44.         xor.initNeuralNet();  
  45.         xor.train();  
  46.         xor.interrogate();  
  47.     }  
  48.   
  49.     /** 
  50.      * Method declaration 
  51.      */  
  52.     public void train()  
  53.     {  
  54.   
  55.         // set the inputs  
  56.         inputSynapse.setInputArray(inputArray);  
  57.         inputSynapse.setAdvancedColumnSelector(" 1,2 ");  
  58.         // set the desired outputs  
  59.         desiredOutputSynapse.setInputArray(desiredOutputArray);  
  60.         desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");  
  61.   
  62.         // get the monitor object to train or feed forward  
  63.         Monitor monitor = nnet.getMonitor();  
  64.   
  65.         // set the monitor parameters  
  66.         monitor.setLearningRate(0.8);  
  67.         monitor.setMomentum(0.3);  
  68.         monitor.setTrainingPatterns(inputArray.length);  
  69.         monitor.setTotCicles(5000);  
  70.         monitor.setLearning(true);  
  71.   
  72.         long initms = System.currentTimeMillis();  
  73.         // Run the network in single-thread, synchronized mode  
  74.         nnet.getMonitor().setSingleThreadMode(singleThreadMode);  
  75.         nnet.go(true);  
  76.         System.out.println(" Total time=  "  
  77.                 + (System.currentTimeMillis() - initms) + "  ms ");  
  78.     }  
  79.   
  80.     private void interrogate()  
  81.     {  
  82.   
  83.         double[][] inputArray = new double[][]  
  84.         {  
  85.         { 1.0, 1.0 } };  
  86.         // set the inputs  
  87.         inputSynapse.setInputArray(inputArray);  
  88.         inputSynapse.setAdvancedColumnSelector(" 1,2 ");  
  89.         Monitor monitor = nnet.getMonitor();  
  90.         monitor.setTrainingPatterns(4);  
  91.         monitor.setTotCicles(1);  
  92.         monitor.setLearning(false);  
  93.         MemoryOutputSynapse memOut = new MemoryOutputSynapse();  
  94.         // set the output synapse to write the output of the net  
  95.   
  96.         if (nnet != null)  
  97.         {  
  98.             nnet.addOutputSynapse(memOut);  
  99.             System.out.println(nnet.check());  
  100.             nnet.getMonitor().setSingleThreadMode(singleThreadMode);  
  101.             nnet.go();  
  102.   
  103.             for (int i = 0; i < 4; i++)  
  104.             {  
  105.                 double[] pattern = memOut.getNextPattern();  
  106.                 System.out.println(" Output pattern # " + (i + 1) + " = "  
  107.                         + pattern[0]);  
  108.             }  
  109.             System.out.println(" Interrogating Finished ");  
  110.         }  
  111.     }  
  112.   
  113.     /** 
  114.      * Method declaration 
  115.      */  
  116.     protected void initNeuralNet()  
  117.     {  
  118.   
  119.         // First create the three layers  
  120.         input = new LinearLayer();  
  121.         hidden = new SigmoidLayer();  
  122.         output = new SigmoidLayer();  
  123.   
  124.         // set the dimensions of the layers  
  125.         input.setRows(2);  
  126.         hidden.setRows(3);  
  127.         output.setRows(1);  
  128.   
  129.         input.setLayerName(" L.input ");  
  130.         hidden.setLayerName(" L.hidden ");  
  131.         output.setLayerName(" L.output ");  
  132.   
  133.         // Now create the two Synapses  
  134.         FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */  
  135.         FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */  
  136.   
  137.         // Connect the input layer whit the hidden layer  
  138.         input.addOutputSynapse(synapse_IH);  
  139.         hidden.addInputSynapse(synapse_IH);  
  140.   
  141.         // Connect the hidden layer whit the output layer  
  142.         hidden.addOutputSynapse(synapse_HO);  
  143.         output.addInputSynapse(synapse_HO);  
  144.   
  145.         // the input to the neural net  
  146.         inputSynapse = new MemoryInputSynapse();  
  147.   
  148.         input.addInputSynapse(inputSynapse);  
  149.   
  150.         // The Trainer and its desired output  
  151.         desiredOutputSynapse = new MemoryInputSynapse();  
  152.   
  153.         TeachingSynapse trainer = new TeachingSynapse();  
  154.   
  155.         trainer.setDesired(desiredOutputSynapse);  
  156.   
  157.         // Now we add this structure to a NeuralNet object  
  158.         nnet = new NeuralNet();  
  159.   
  160.         nnet.addLayer(input, NeuralNet.INPUT_LAYER);  
  161.         nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);  
  162.         nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);  
  163.         nnet.setTeacher(trainer);  
  164.         output.addOutputSynapse(trainer);  
  165.         nnet.addNeuralNetListener(this);  
  166.     }  
  167.   
  168.     public void cicleTerminated(NeuralNetEvent e)  
  169.     {  
  170.     }  
  171.   
  172.     public void errorChanged(NeuralNetEvent e)  
  173.     {  
  174.         Monitor mon = (Monitor) e.getSource();  
  175.         if (mon.getCurrentCicle() % 100 == 0)  
  176.             System.out.println(" Epoch:  "  
  177.                     + (mon.getTotCicles() - mon.getCurrentCicle()) + "  RMSE: "  
  178.                     + mon.getGlobalError());  
  179.     }  
  180.   
  181.     public void netStarted(NeuralNetEvent e)  
  182.     {  
  183.         Monitor mon = (Monitor) e.getSource();  
  184.         System.out.print(" Network started for  ");  
  185.         if (mon.isLearning())  
  186.             System.out.println(" training. ");  
  187.         else  
  188.             System.out.println(" interrogation. ");  
  189.     }  
  190.   
  191.     public void netStopped(NeuralNetEvent e)  
  192.     {  
  193.         Monitor mon = (Monitor) e.getSource();  
  194.         System.out.println(" Network stopped. Last RMSE= "  
  195.                 + mon.getGlobalError());  
  196.     }  
  197.   
  198.     public void netStoppedError(NeuralNetEvent e, String error)  
  199.     {  
  200.         System.out.println(" Network stopped due the following error:  "  
  201.                 + error);  
  202.     }  
  203.   
  204. }  

 

 

如今我會逐步解釋上面的程序。數組

 【1】 從main方法開始提及,首先第一步新建一個對象:網絡

[java]  view plain  copy
 
  1. XOR_using_NeuralNet xor = new XOR_using_NeuralNet();  

【2】而後初始化神經網絡:數據結構

 

[java]  view plain  copy
 
  1. xor.initNeuralNet();  

初始化神經網絡的方法中:多線程

[java]  view plain  copy
 
  1. // First create the three layers  
  2.         input = new LinearLayer();  
  3.         hidden = new SigmoidLayer();  
  4.         output = new SigmoidLayer();  
  5.   
  6.         // set the dimensions of the layers  
  7.         input.setRows(2);  
  8.         hidden.setRows(3);  
  9.         output.setRows(1);  
  10.   
  11.         input.setLayerName(" L.input ");  
  12.         hidden.setLayerName(" L.hidden ");  
  13.         output.setLayerName(" L.output ");  

 

 

上面代碼解釋:框架

input=new LinearLayer()是新建一個輸入層,由於神經網絡的輸入層並無訓練參數,因此使用的是線性層;函數

hidden = new SigmoidLayer();這裏是新建一個隱含層,使用sigmoid函數做爲激勵函數,固然你也能夠選擇其餘的激勵函數,如softmax激勵函數工具

output則是新建一個輸出層

以後的三行代碼是創建輸入層、隱含層、輸出層的神經元個數,這裏表示輸入層爲2個神經元,隱含層是3個神經元,輸出層是1個神經元

最後的三行代碼是給每一個輸出層取一個名字。

[java]  view plain  copy
 
  1. // Now create the two Synapses  
  2.         FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */  
  3.         FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */  
  4.   
  5.         // Connect the input layer whit the hidden layer  
  6.         input.addOutputSynapse(synapse_IH);  
  7.         hidden.addInputSynapse(synapse_IH);  
  8.   
  9.         // Connect the hidden layer whit the output layer  
  10.         hidden.addOutputSynapse(synapse_HO);  
  11.         output.addInputSynapse(synapse_HO);  

 

上面代碼解釋:

 

上面代碼的主要做用是將三個層鏈接起來,synapse_IH用來鏈接輸入層和隱含層,synapse_HO用來鏈接隱含層和輸出層

[java]  view plain  copy
 
  1. // the input to the neural net  
  2.         inputSynapse = new MemoryInputSynapse();  
  3.   
  4.         input.addInputSynapse(inputSynapse);  
  5.   
  6.         // The Trainer and its desired output  
  7.         desiredOutputSynapse = new MemoryInputSynapse();  
  8.   
  9.         TeachingSynapse trainer = new TeachingSynapse();  
  10.   
  11.         trainer.setDesired(desiredOutputSynapse);  

 

上面代碼解釋: 

 

上面的代碼是在訓練的時候指定輸入層的數據和目的輸出的數據,

 inputSynapse = new MemoryInputSynapse();這裏指的是使用了從內存中輸入數據的方法,指的是輸入層輸入數據,固然還有從文件輸入的方法,這點在文章後面再談。同理,desiredOutputSynapse = new MemoryInputSynapse();也是從內存中輸入數據,指的是從輸入層應該輸出的數據

[java]  view plain  copy
 
  1. // Now we add this structure to a NeuralNet object  
  2.         nnet = new NeuralNet();  
  3.   
  4.         nnet.addLayer(input, NeuralNet.INPUT_LAYER);  
  5.         nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);  
  6.         nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);  
  7.         nnet.setTeacher(trainer);  
  8.         output.addOutputSynapse(trainer);  
  9.         nnet.addNeuralNetListener(this);  

上面代碼解釋:

 

這段代碼指的是將以前初始化的構件鏈接成一個神經網絡,NeuralNet是JOONE提供的類,主要是鏈接各個神經層,最後一個nnet.addNeuralNetListener(this);這個做用是對神經網絡的訓練過程進行監聽,由於這個類實現了NeuralNetListener這個接口,這個接口有一些方法,能夠實現觀察神經網絡訓練過程,有助於參數調整。

【3】而後咱們來看一下train這個方法:

[java]  view plain  copy
 
  1. inputSynapse.setInputArray(inputArray);  
  2.         inputSynapse.setAdvancedColumnSelector(" 1,2 ");  
  3.         // set the desired outputs  
  4.         desiredOutputSynapse.setInputArray(desiredOutputArray);  
  5.         desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");  

 

上面代碼解釋:

 

inputSynapse.setInputArray(inputArray);這個方法是初始化輸入層數據,也就是指定輸入層數據的內容,inputArray是程序中給定的二維數組,這也就是爲何以前初始化神經網絡的時候使用的是MemoryInputSynapse,表示從內存中讀取數據

inputSynapse.setAdvancedColumnSelector(" 1,2 ");這個表示的是輸入層數據使用的是inputArray的前兩列數據。

desiredOutputSynapse這個也同理

[java]  view plain  copy
 
  1. Monitor monitor = nnet.getMonitor();  
  2.   
  3.         // set the monitor parameters  
  4.         monitor.setLearningRate(0.8);  
  5.         monitor.setMomentum(0.3);  
  6.         monitor.setTrainingPatterns(inputArray.length);  
  7.         monitor.setTotCicles(5000);  
  8.         <span style="line-height: 1.5;">monitor.setLearning(true);  

 上面代碼解釋:

這個monitor類也是JOONE框架提供的,主要是用來調節神經網絡的參數,monitor.setLearningRate(0.8);是用來設置神經網絡訓練的步長參數,步長越大,神經網絡梯度降低的速度越快,monitor.setTrainingPatterns(inputArray.length);這個是設置神經網絡的輸入層的訓練數據大小size,這裏使用的是數組的長度;monitor.setTotCicles(5000);這個指的是設置迭代數目;monitor.setLearning(true);這個true表示是在訓練過程。

[java]  view plain  copy
 
  1. nnet.getMonitor().setSingleThreadMode(singleThreadMode);  
  2.         nnet.go(true);  

上面代碼解釋:

 

nnet.getMonitor().setSingleThreadMode(singleThreadMode);這個指的是是否是使用多線程,可是我不太清楚這裏的多線程指的是什麼意思

nnet.go(true)表示的是開始訓練。

【4】最後來看一下interrogate方法

[java]  view plain  copy
 
  1. double[][] inputArray = new double[][]  
  2.         {  
  3.         { 1.0, 1.0 } };  
  4.         // set the inputs  
  5.         inputSynapse.setInputArray(inputArray);  
  6.         inputSynapse.setAdvancedColumnSelector(" 1,2 ");  
  7.         Monitor monitor = nnet.getMonitor();  
  8.         monitor.setTrainingPatterns(4);  
  9.         monitor.setTotCicles(1);  
  10.         monitor.setLearning(false);  
  11.         MemoryOutputSynapse memOut = new MemoryOutputSynapse();  
  12.         // set the output synapse to write the output of the net  
  13.   
  14.         if (nnet != null)  
  15.         {  
  16.             nnet.addOutputSynapse(memOut);  
  17.             System.out.println(nnet.check());  
  18.             nnet.getMonitor().setSingleThreadMode(singleThreadMode);  
  19.             nnet.go();  
  20.   
  21.             for (int i = 0; i < 4; i++)  
  22.             {  
  23.                 double[] pattern = memOut.getNextPattern();  
  24.                 System.out.println(" Output pattern # " + (i + 1) + " = "  
  25.                         + pattern[0]);  
  26.             }  
  27.             System.out.println(" Interrogating Finished ");  
  28.         }  

 

這個方法至關於測試方法,這裏的inputArray是測試數據, 注意這裏須要設置monitor.setLearning(false);,由於這不是訓練過程,並不須要學習,monitor.setTrainingPatterns(4);這個是指測試的數量,4表示有4個測試數據(雖然這裏只有一個)。這裏還給nnet添加了一個輸出層數據對象,這個對象mmOut是初始測試結果,注意到以前咱們初始化神經網絡的時候並無給輸出層指定數據對象,由於那個時候咱們在訓練,並且指定了trainer做爲目的輸出。

 

 

接下來就是輸出結果數據了,pattern的個數和輸出層的神經元個數同樣大,這裏輸出層神經元的個數是1,因此pattern大小爲1.

 

【5】咱們看一下測試結果:

 

[java]  view plain  copy
 
  1. Output pattern # 1 = 0.018303527517809233  

 

 

表示輸出結果爲0.01,根據sigmoid函數特性,咱們獲得的輸出是0,和預期結果一致。若是輸出層神經元個數大於1,那麼輸出值將會有多個,由於輸出層結果是0|1離散值,因此咱們取輸出最大的那個神經元的輸出值取爲1,其餘爲0

 

 

 

【6】最後咱們來看一下神經網絡訓練過程當中的一些監聽函數:

cicleTerminated:每一個循環結束後輸出的信息

errorChanged:神經網絡錯誤率變化時候輸出的信息

netStarted:神經網絡開始運行的時候輸出的信息

netStopped:神經網絡中止的時候輸出的信息

 

【7】好了,JOONE基本上內容就是這些。還有一些額外東西須要說明:

 

1,從文件中讀取數據構建神經網絡

2.如何保存訓練好的神經網絡到文件夾中,只要測試的時候直接load到內存中就行,而不用每次都須要訓練。

 

 

【8】先看第一個問題:

從文件中讀取數據:

文件的格式:

0;0;0

1;0;1

1;1;0

0;1;1

 

中間使用分號隔開,使用方法以下,也就是把上文的MemoryInputSynapse換成FileInputSynapse便可。

[java]  view plain  copy
 
  1. fileInputSynapse = new FileInputSynapse();  
  2. input.addInputSynapse(fileInputSynapse);  
  3. fileDisireOutputSynapse = new FileInputSynapse();  
  4. TeachingSynapse trainer = new TeachingSynapse();  
  5. trainer.setDesired(fileDisireOutputSynapse);  

 咱們看下文件是如何輸出數據的:

[java]  view plain  copy
 
  1. private File inputFile = new File(Constants.TRAIN_WORD_VEC_PATH);  
  2. fileInputSynapse.setInputFile(inputFile);  
  3. fileInputSynapse.setFirstCol(2);//使用文件的第2列到第3列做爲輸出層輸入  
  4. fileInputSynapse.setLastCol(3);  

 

[java]  view plain  copy
 
  1. fileDisireOutputSynapse.setInputFile(inputFile);  
  2. fileDisireOutputSynapse.setFirstCol(1);//使用文件的第1列做爲輸出數據  
  3. fileDisireOutputSynapse.setLastCol(1);  

 

 其他的代碼和上文的是同樣的。

 

 

【9】而後看第二個問題:

如何保存神經網絡

其實很簡單,直接序列化nnet對象就好了,而後讀取該對象就是java的反序列化,這個就很少作介紹了,比較簡單。可是須要說明的是,保存神經網絡的時機必定是在神經網絡訓練完畢後,可使用下面代碼:

[java]  view plain  copy
 
    1. public void netStopped(NeuralNetEvent e) {  
    2.         Monitor mon = (Monitor) e.getSource();  
    3.         try {  
    4.             if (mon.isLearning()) {  
    5.                 saveModel(nnet); //序列化對象  
    6.             }  
    7.         } catch (IOException ee) {  
    8.             // TODO Auto-generated catch block  
    9.             ee.printStackTrace();  
    10.         }  
相關文章
相關標籤/搜索