決策樹分類器-Java實現

決策樹模型,其基本思想相似於if else的結構,即知足什麼條件則將它斷定爲某一類,而這裏的決策樹的深度就相似於if else的深度。  node

決策樹的問題焦點在於,對於一個擁有多維數據特徵的數據點,如何選擇合適的分類依據。例如一隻雞(兩條腿,有翅膀,沒有腳蹼。。。),一隻鴨(兩條腿,有翅膀,有腳蹼。。),等等,如今來了一隻奇怪的生物(兩條腿,有翅膀,沒有腳蹼。。),若是先根據腿或翅膀來判斷的話,根本沒法判斷它屬於哪種生物,而若是根據腳蹼來判斷的話,馬上就能分辨出來。從這個例子中,想表達的就是決策樹若是去抉擇一種最合適的特徵來獲得不一樣的判決類。數組

本文是基於數據集信息熵最小的原則,來肯定這種樹的生長規則的。信息熵的背景,很少說,簡而言之,越有序的系統熵越小,越無序的系統熵越大。其計算公式以下:測試

H(x) = E[I(xi)] = E[ log(2,1/p(xi)) ] = -∑p(xi)log(2,p(xi)) (i=1,2,..n)  設計

其中p(xi)爲xi樣本在x整體中的取值機率(或統計學中的頻率)。  對象

在給出具體實現代碼以前,我先給出此處用到的樹結構。 遞歸

 /** 接口

 * Created by Song on 2017/1/4. element

 * 樹節點,可序列化存儲 get

 */ it

public class Node implements Serializable{ 

      public Object element; 

      public Map<Object,Node> child; 

}

之因此這樣設計,是基於此處具體的應用環境。e在此應用環境中,element爲String類型的特徵名稱,而Map中的每一個鍵值對,鍵名錶明着判決條件(鏈接兩個節點的線的標稱),值表明着下一個節點。 

下面再給出,Java中對象序列化存儲的部分代碼(在測試時,我註釋掉了),用於在經過訓練集獲得決策樹結構以後,將該樹保存在文件中,而不須要,每次都從新訓練獲得決策樹結構。 Node root = handler.createTree(dataSet,featurelabels,labelStr); 

//樹結構存儲 

ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File("E:\\dectree.txt"))); 

      oos.writeObject(root); 

      oos.flush(); 

      oos.close(); 

//樹結構讀取 

              ObjectInputStream ois = new ObjectInputStream(new FileInputStream(new File("E:\\dectree.txt"))); 

              Node tree = (Node) ois.readObject(); 

下面是決策樹分類器的具體實現代碼: 

/** 

 * Created by Song on 2017/1/3. 

 * 決策樹 

 */

 public class DectreeHandler { 

        /** 

 * 計算數據集的香農熵 

 * @param dataSet 數據集(最後一列爲分類信息) 

 * @return 香農熵 

 */ 

private static double calcShannonEnt(Matrix dataSet){ 

       int m = dataSet.getRowDimension(); 

       int n = dataSet.getColumnDimension(); 

      double currentLabel = 0; 

      double shannonEnt = 0; 

      double rate = 0; 

      HashMap<Double,Integer> labelCounts = new HashMap<Double, Integer>(); 

      //統計各種出現次數 

     for(int i=0;i<m;i++){ 

           currentLabel = dataSet.get(i,n-1); 

           if(!labelCounts.containsKey(currentLabel)) 

                  labelCounts.put(currentLabel,0);           labelCounts.put(currentLabel,labelCounts.get(currentLabel)+1); 

 } 

 //計算總體香農熵 

 for(double key:labelCounts.keySet()){ 

       rate =labelCounts.get(key)/(float)m; 

      shannonEnt -= rate*Math.log(rate)/Math.log(2); 

    } 

 return shannonEnt; 

 } 


 /** 

 * 劃分數據集(當第axis維數據等於value時,提取出該行數據()去掉第axis維) 

 * @param dataSet 數據集(最後一列爲分類信息) 

 * @param axis 待匹配列(從0開始) 

 * @param value 待匹配列值 

 * @return 

 */ 

    private Matrix splitDataSet(Matrix dataSet,int axis,double value){ 

           Matrix retDataSet = new Matrix(0,dataSet.getColumnDimension()-1); 

           Matrix temp = new Matrix(1,dataSet.getColumnDimension()-1); 

           for(int i=0;i<dataSet.getRowDimension();i++){ 

                 if(dataSet.get(i,axis)==value){ 

                 int k = 0; 

                for(int j=0;j<dataSet.getColumnDimension();j++){ 

                     if(j!=axis) 

                          temp.set(0,k++,dataSet.get(i,j)); 

                } 

               retDataSet = retDataSet.expand(temp,false); 

         } 

     } 

     return retDataSet; 

 } 

 /** 

 * 選擇最好的數據集劃分方式 

 * @param dataSet 數據集(最後一列爲分類信息) 

 * @return 香農熵最小時(增益最大)的特徵值序號 

 */ 

private int chooseBestFeatureToSplit(Matrix dataSet){ 

      //特徵數 

     int featureNums= dataSet.getColumnDimension()-1; 

     //數據集的香農熵 

    double baseEntropy = calcShannonEnt(dataSet); 

    double bestInfoGain = 0.0; 

    int bestFeature = -1; 

   double newEntropy = 0.0; 

   Set<Double> tempFeatureSet = new HashSet<Double>(); 

   for(int j=0;j<featureNums;j++){ 

          //取數據集中的第i列Set 

         for(int i=0;i<dataSet.getRowDimension();i++){       tempFeatureSet.add(dataSet.get(i,j)); 

Matrix subMatrix; 

double prob=0; 

double infoGain=0; 

newEntropy = 0.0; 

for(double val:tempFeatureSet){ 

     subMatrix = splitDataSet(dataSet,j,val); 

     prob = subMatrix.getRowDimension()/(float)dataSet.getRowDimension();            newEntropy += prob*calcShannonEnt(subMatrix); 

 } 

 infoGain = baseEntropy-newEntropy; 

if(infoGain>bestInfoGain){ 

       bestInfoGain = infoGain; 

       bestFeature = j; 

  } 

 return bestFeature; 

 } 


 /** 

 * 返回出現次數最多的類 

 * @param labels 每一個樣本所屬的類矩陣 

 * @return 出現次數最多的類 

 */ 

   private double majorityCnt(Matrix labels){ 

          Map<Double,Integer> classCount = new HashMap<Double, Integer>(); 

          for(int i=0;i<labels.getRowDimension();i++){                    if(!classCount.containsKey(labels.get(i,0))) 

      classCount.put(labels.get(i,0),0); classCount.put(labels.get(i,0),classCount.get(labels.get(i,0))+1); 

int count =0; 

double label = -1; 

for(double key:classCount.keySet()){ 

     if(classCount.get(key)>count){ 

         count = classCount.get(key); 

         label = key; 

     }

   } 

   return label; 


 /** 

 * 遞歸建立決策樹 

 * @param dataSet 數據集(最後一列爲類) 

 * @param featurelabels 各列特徵名 

 * @param labelStr 類名 

 * @return 決策樹 

 */ public Node createTree(Matrix dataSet,String [] featurelabels,String [] labelStr) {          double[] classList = new double[dataSet.getRowDimension()]; 

 for (int i = 0; i < dataSet.getRowDimension(); i++) { 

       classList[i] = dataSet.get(i, dataSet.getColumnDimension() - 1); 

 } 

 int num = 0; 

 for (double cla : classList) { 

        if (cla == classList[0]) num++; 

 } 

 if (num == classList.length) { 

     Node node = new Node(); 

     node.element=labelStr[(int)classList[0]]; 

     return node; 

 } //若爲同一類,則直接返回該類 

 if(dataSet.getColumnDimension()==1) { 

          Node node = new Node(); 

          node.element=majorityCnt(new Matrix(classList,1).transpose()); 

          return node; 

 } 

 double bestFeature = chooseBestFeatureToSplit(dataSet); 

 String bestFeatureLabel = featurelabels[(int)bestFeature]; 


 Node root = new Node(); 

 root.element = bestFeatureLabel; 

 String [] subLabels = del(featurelabels,bestFeatureLabel); 

 Set<Double> uniqFeatureVals = new HashSet<Double>(); 

 for(int i=0;i<dataSet.getRowDimension();i++){ 

       uniqFeatureVals.add(dataSet.get(i,(int)bestFeature)); 

 } 


 Map<Object,Node> child = new HashMap<Object, Node>(); 

 for(double val:uniqFeatureVals){ 

          child.put(val,createTree(splitDataSet(dataSet(int)bestFeature,val),subLabels,labelStr)); 

 } 

 root.child=child; 

 return root; 

 } 


 /** 

 * 從labels數組中刪除元素val 

 * @param labels 

 * @param val 

 * @return 新的數組 

 */ 

 private String[] del(String [] labels,String val){ 

          tring [] newLabels = new String[labels.length-1]; 

          int k=0; 

          for(int i=0;i<labels.length && k<labels.length-1;i++){ 

          if(!labels[i].equals(val)) 

                 newLabels[k++]=labels[i]; 

 } 

 return newLabels; 

 } 

 /** 

 * 決策樹分類調用接口 

 * @param tree 調用createTree獲得的決策樹根節點 

 * @param featureLabels 特徵集名稱 

 * @param sample 待分類樣本 

 * @return 

 */ 

 public String classify(Node tree,String [] featureLabels,Matrix sample){ 

        while ((null != tree) && (null != tree.child)){ 

               try { 

                     System.out.println(tree.element); 

                      tree = tree.child.get(sample.get(0,getIndex(featureLabels,(String) tree.element))); 

              }catch (Exception e){ 

                   e.printStackTrace(); 

                   return "Class Not Find"; 

               } 

        } 

        if(null == tree) return "Class Not Find"; 

        return (String) tree.element; 

 } 

 /** 

 * 從String數組中獲取對應值的下標 

 * @param labels 

 * @param val 

 * @return 

 */ 

 private int getIndex(String [] labels,String val){ 

         for(int i=0;i<labels.length;i++){ 

              if(val.equals(labels[i])) 

                     return i; 

           } 

            return -1; 

   }  

public static void main(String [] args) throws Exception{ 

       DectreeHandler handler = new DectreeHandler(); 

       double [][] data = { 

                   {1,1,1,1,1,1,1}, 

                   {2,2,2,2,2,2,2},  

                   {3,3,3,3,3,3,3}, 

                   {1,1,4,2,3,3,1}, 

                   {4,1,5,4,2,1,2}, 

                   {1,2,6,2,1,2,6}, 

                   {4,2,7,4,3,5,4}, 

                   {1,2,8,3,3,3,4}, 

                   {2,12,9,5,2,4,5}, 

                   {1,2,3,10,8,6,5} 

 }; 

 Matrix dataSet = new Matrix(data); 

double [] labels = {1,1,1,2,2,3,3,3,3,0}; 

dataSet = dataSet.expand(new Matrix(labels,1).transpose(),true); 

int bestFeature = handler.chooseBestFeatureToSplit(dataSet); System.out.println(bestFeature); 

dataSet.print(dataSet.getColumnDimension(),3); 

String [] featurelabels = {"特徵A","特徵B","‘特徵C","特徵D","特徵E","特徵F","特徵G"}; String [] labelStr = {"類A","類B","類C","類D"}; 

Node root = handler.createTree(dataSet,featurelabels,labelStr); 

 //序列化存儲 

/* ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(new File("E:\\dectree.txt"))); 

oos.writeObject(root); 

oos.flush(); 

oos.close(); 


ObjectInputStream ois = new ObjectInputStream(new FileInputStream(new File("E:\\dectree.txt"))); 

Node tree = (Node) ois.readObject();*/ 

      double [] sample = new double[]{1,1,3,10,3,3,4}; 

      String className = handler.classify(root,featurelabels,new Matrix(sample,1));         System.out.println(className); 

       } 

}

說明,提供外部調用權限的僅兩個方法,一是createTree()用於根據訓練集數據,遞歸建立決策樹,二是classify()根據決策樹結構以及樣本點數據獲得樣本點的具體分類。至於其餘部分:計算數據集的香農熵,遍歷特徵選擇使剩餘數據熵最小的特徵做爲分支斷定依據等部分邏輯直接看代碼及註釋,此處也不細說了。 其中,建立決策樹的過程以下: 

(1)選擇當前數據集A中最佳的特徵做爲節點判決依據 

(2)得到訓練集中(1)中特徵的全部取值 

(3)將當前數據集去掉該列特徵數據,獲得新的數據集B 

(4)遍歷(2)中該特徵的全部取值,獲得全部子節點,其中子節點斷定條件對應該特徵值的一個取值,子節點爲將(3)中獲得的數據集B迭代回(1)獲得。 

其中葉子節點的判斷條件爲,當前數據集僅有一個分類。 

其中(1)中選擇最佳特徵的過程爲,遍歷當前數據的全部特徵,根據特徵的取值域及每一個取值對應的頻率,根據信息熵計算公式獲得該特徵值對應的熵值,取全部特徵中熵最小的特徵做爲最佳特徵。 

因爲此處是須要計算數據集的香農熵,因此此處決策樹僅適用於數值型數據。

相關文章
相關標籤/搜索