CART分類迴歸樹算法

CART算法也是一種決策樹分類算法。CART分類迴歸樹算法的本質也是對數據進行分類的,最終數據的表現形式也是以樹形的模式展示的,與ID3,C4.5算法不一樣的是,他的分類標準所採用的算法不一樣了。下面列出了其中的一些不一樣之處:html

一、CART最後造成的樹是一個二叉樹,每一個節點會分紅2個節點,左孩子節點和右孩子節點,而在ID3和C4.5中是按照分類屬性的值類型進行劃分,因而這就要求CART算法在所選定的屬性中又要劃分出最佳的屬性劃分值,節點若是選定了劃分屬性名稱還要肯定裏面按照那個值作一個二元的劃分。java

二、CART算法對於屬性的值採用的是基於Gini係數值的方式作比較,gini某個屬性的某次值的劃分的gini指數的值爲:node

,pk就是分別爲正負實例的機率,gini係數越小說明分類純度越高,能夠想象成與熵的定義同樣。所以在最後計算的時候咱們只取其中值最小的作出劃分。最後作比較的時候用的是gini的增益作比較,要對分類號的數據作出一個帶權重的gini指數的計算。舉一個網上的一個例子:android

 

好比體溫爲恆溫時包含哺乳類5個、鳥類2個,則:算法

體溫爲非恆溫時包含爬行類3個、魚類3個、兩棲類2個,則數組

因此若是按照「體溫爲恆溫和非恆溫」進行劃分的話,咱們獲得GINI的增益(類比信息增益):app

最好的劃分就是使得GINI_Gain最小的劃分。工具

經過比較每一個屬性的最小的gini指數值,做爲最後的結果。優化

三、CART算法在把數據進行分類以後,會對樹進行一個剪枝,經常使用的用前剪枝和後剪枝法,而常見的後剪枝發包括代價複雜度剪枝,悲觀偏差剪枝等等,我寫的這次算法採用的是代價複雜度剪枝法。代價複雜度剪枝的算法公式爲:ui

α表示的是每一個非葉子節點的偏差增益率,能夠理解爲偏差代價,最後選出偏差代價最小的一個節點進行剪枝。

裏面變量的意思爲:

 

是子樹中包含的葉子節點個數;

是節點t的偏差代價,若是該節點被剪枝;

r(t)是節點t的偏差率;

p(t)是節點t上的數據佔全部數據的比例。

是子樹Tt的偏差代價,若是該節點不被剪枝。它等於子樹Tt上全部葉子節點的偏差代價之和。下面說說我對於這個公式的理解:其實這個公式的本質是對於剪枝前和剪枝後的樣本誤差率作一個差值比較,一個好的分類固然是分類後的樣本誤差率相較於沒分類(就是剪枝掉的時候)的誤差率小,因此這時的值就會大,若是分類先後基本變化不大,則意味着分類不起什麼效果,α值的分子位置就小,因此偏差代價就小,能夠被剪枝。可是通常分類後的誤差率會小於分類前的,由於誤差數在高層節點的時候確定比子節點的多,子節點誤差數最多與父親節點同樣。

CART算法實現

首先是程序的備用數據,我是把他存在了一個文字中,經過程序進行逐行的讀取:

 

[java] view plain copy

 print?

  1. Rid Age Income Student CreditRating BuysComputer  
  2. 1 Youth High No Fair No  
  3. 2 Youth High No Excellent No  
  4. 3 MiddleAged High No Fair Yes  
  5. 4 Senior Medium No Fair Yes  
  6. 5 Senior Low Yes Fair Yes  
  7. 6 Senior Low Yes Excellent No  
  8. 7 MiddleAged Low Yes Excellent Yes  
  9. 8 Youth Medium No Fair No  
  10. 9 Youth Low Yes Fair Yes  
  11. 10 Senior Medium Yes Fair Yes  
  12. 11 Youth Medium Yes Excellent Yes  
  13. 12 MiddleAged Medium No Excellent Yes  
  14. 13 MiddleAged High Yes Fair Yes  
  15. 14 Senior Medium No Excellent No  

下面是主程序,裏面有具體的註釋:

 

 

[java] view plain copy

 print?

  1. package DataMing_CART;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.util.ArrayList;  
  8. import java.util.HashMap;  
  9. import java.util.LinkedList;  
  10. import java.util.Map;  
  11. import java.util.Queue;  
  12.   
  13. import javax.lang.model.element.NestingKind;  
  14. import javax.swing.text.DefaultEditorKit.CutAction;  
  15. import javax.swing.text.html.MinimalHTMLWriter;  
  16.   
  17. /** 
  18.  * CART分類迴歸樹算法工具類 
  19.  *  
  20.  * @author lyq 
  21.  *  
  22.  */  
  23. public class CARTTool {  
  24.     // 類標號的值類型  
  25.     private final String YES = "Yes";  
  26.     private final String NO = "No";  
  27.   
  28.     // 全部屬性的類型總數,在這裏就是data源數據的列數  
  29.     private int attrNum;  
  30.     private String filePath;  
  31.     // 初始源數據,用一個二維字符數組存放模仿表格數據  
  32.     private String[][] data;  
  33.     // 數據的屬性行的名字  
  34.     private String[] attrNames;  
  35.     // 每一個屬性的值全部類型  
  36.     private HashMap<String, ArrayList<String>> attrValue;  
  37.   
  38.     public CARTTool(String filePath) {  
  39.         this.filePath = filePath;  
  40.         attrValue = new HashMap<>();  
  41.     }  
  42.   
  43.     /** 
  44.      * 從文件中讀取數據 
  45.      */  
  46.     public void readDataFile() {  
  47.         File file = new File(filePath);  
  48.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  49.   
  50.         try {  
  51.             BufferedReader in = new BufferedReader(new FileReader(file));  
  52.             String str;  
  53.             String[] tempArray;  
  54.             while ((str = in.readLine()) != null) {  
  55.                 tempArray = str.split(" ");  
  56.                 dataArray.add(tempArray);  
  57.             }  
  58.             in.close();  
  59.         } catch (IOException e) {  
  60.             e.getStackTrace();  
  61.         }  
  62.   
  63.         data = new String[dataArray.size()][];  
  64.         dataArray.toArray(data);  
  65.         attrNum = data[0].length;  
  66.         attrNames = data[0];  
  67.   
  68.         /* 
  69.          * for (int i = 0; i < data.length; i++) { for (int j = 0; j < 
  70.          * data[0].length; j++) { System.out.print(" " + data[i][j]); } 
  71.          * System.out.print("\n"); } 
  72.          */  
  73.   
  74.     }  
  75.   
  76.     /** 
  77.      * 首先初始化每種屬性的值的全部類型,用於後面的子類熵的計算時用 
  78.      */  
  79.     public void initAttrValue() {  
  80.         ArrayList<String> tempValues;  
  81.   
  82.         // 按照列的方式,從左往右找  
  83.         for (int j = 1; j < attrNum; j++) {  
  84.             // 從一列中的上往下開始尋找值  
  85.             tempValues = new ArrayList<>();  
  86.             for (int i = 1; i < data.length; i++) {  
  87.                 if (!tempValues.contains(data[i][j])) {  
  88.                     // 若是這個屬性的值沒有添加過,則添加  
  89.                     tempValues.add(data[i][j]);  
  90.                 }  
  91.             }  
  92.   
  93.             // 一列屬性的值已經遍歷完畢,複製到map屬性表中  
  94.             attrValue.put(data[0][j], tempValues);  
  95.         }  
  96.   
  97.         /* 
  98.          * for (Map.Entry entry : attrValue.entrySet()) { 
  99.          * System.out.println("key:value " + entry.getKey() + ":" + 
  100.          * entry.getValue()); } 
  101.          */  
  102.     }  
  103.   
  104.     /** 
  105.      * 計算機基尼指數 
  106.      *  
  107.      * @param remainData 
  108.      *            剩餘數據 
  109.      * @param attrName 
  110.      *            屬性名稱 
  111.      * @param value 
  112.      *            屬性值 
  113.      * @param beLongValue 
  114.      *            分類是否屬於此屬性值 
  115.      * @return 
  116.      */  
  117.     public double computeGini(String[][] remainData, String attrName,  
  118.             String value, boolean beLongValue) {  
  119.         // 實例總數  
  120.         int total = 0;  
  121.         // 正實例數  
  122.         int posNum = 0;  
  123.         // 負實例數  
  124.         int negNum = 0;  
  125.         // 基尼指數  
  126.         double gini = 0;  
  127.   
  128.         // 仍是按列從左往右遍歷屬性  
  129.         for (int j = 1; j < attrNames.length; j++) {  
  130.             // 找到了指定的屬性  
  131.             if (attrName.equals(attrNames[j])) {  
  132.                 for (int i = 1; i < remainData.length; i++) {  
  133.                     // 統計正負實例按照屬於和不屬於值類型進行劃分  
  134.                     if ((beLongValue && remainData[i][j].equals(value))  
  135.                             || (!beLongValue && !remainData[i][j].equals(value))) {  
  136.                         if (remainData[i][attrNames.length - 1].equals(YES)) {  
  137.                             // 判斷此行數據是否爲正實例  
  138.                             posNum++;  
  139.                         } else {  
  140.                             negNum++;  
  141.                         }  
  142.                     }  
  143.                 }  
  144.             }  
  145.         }  
  146.   
  147.         total = posNum + negNum;  
  148.         double posProbobly = (double) posNum / total;  
  149.         double negProbobly = (double) negNum / total;  
  150.         gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;  
  151.   
  152.         // 返回計算基尼指數  
  153.         return gini;  
  154.     }  
  155.   
  156.     /** 
  157.      * 計算屬性劃分的最小基尼指數,返回最小的屬性值劃分和最小的基尼指數,保存在一個數組中 
  158.      *  
  159.      * @param remainData 
  160.      *            剩餘誰 
  161.      * @param attrName 
  162.      *            屬性名稱 
  163.      * @return 
  164.      */  
  165.     public String[] computeAttrGini(String[][] remainData, String attrName) {  
  166.         String[] str = new String[2];  
  167.         // 最終該屬性的劃分類型值  
  168.         String spiltValue = "";  
  169.         // 臨時變量  
  170.         int tempNum = 0;  
  171.         // 保存屬性的值劃分時的最小的基尼指數  
  172.         double minGini = Integer.MAX_VALUE;  
  173.         ArrayList<String> valueTypes = attrValue.get(attrName);  
  174.         // 屬於此屬性值的實例數  
  175.         HashMap<String, Integer> belongNum = new HashMap<>();  
  176.   
  177.         for (String string : valueTypes) {  
  178.             // 從新計數的時候,數字歸0  
  179.             tempNum = 0;  
  180.             // 按列從左往右遍歷屬性  
  181.             for (int j = 1; j < attrNames.length; j++) {  
  182.                 // 找到了指定的屬性  
  183.                 if (attrName.equals(attrNames[j])) {  
  184.                     for (int i = 1; i < remainData.length; i++) {  
  185.                         // 統計正負實例按照屬於和不屬於值類型進行劃分  
  186.                         if (remainData[i][j].equals(string)) {  
  187.                             tempNum++;  
  188.                         }  
  189.                     }  
  190.                 }  
  191.             }  
  192.   
  193.             belongNum.put(string, tempNum);  
  194.         }  
  195.   
  196.         double tempGini = 0;  
  197.         double posProbably = 1.0;  
  198.         double negProbably = 1.0;  
  199.         for (String string : valueTypes) {  
  200.             tempGini = 0;  
  201.   
  202.             posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);  
  203.             negProbably = 1 - posProbably;  
  204.   
  205.             tempGini += posProbably  
  206.                     * computeGini(remainData, attrName, string, true);  
  207.             tempGini += negProbably  
  208.                     * computeGini(remainData, attrName, string, false);  
  209.   
  210.             if (tempGini < minGini) {  
  211.                 minGini = tempGini;  
  212.                 spiltValue = string;  
  213.             }  
  214.         }  
  215.   
  216.         str[0] = spiltValue;  
  217.         str[1] = minGini + "";  
  218.   
  219.         return str;  
  220.     }  
  221.   
  222.     public void buildDecisionTree(AttrNode node, String parentAttrValue,  
  223.             String[][] remainData, ArrayList<String> remainAttr,  
  224.             boolean beLongParentValue) {  
  225.         // 屬性劃分值  
  226.         String valueType = "";  
  227.         // 劃分屬性名稱  
  228.         String spiltAttrName = "";  
  229.         double minGini = Integer.MAX_VALUE;  
  230.         double tempGini = 0;  
  231.         // 基尼指數數組,保存了基尼指數和此基尼指數的劃分屬性值  
  232.         String[] giniArray;  
  233.   
  234.         if (beLongParentValue) {  
  235.             node.setParentAttrValue(parentAttrValue);  
  236.         } else {  
  237.             node.setParentAttrValue("!" + parentAttrValue);  
  238.         }  
  239.   
  240.         if (remainAttr.size() == 0) {  
  241.             if (remainData.length > 1) {  
  242.                 ArrayList<String> indexArray = new ArrayList<>();  
  243.                 for (int i = 1; i < remainData.length; i++) {  
  244.                     indexArray.add(remainData[i][0]);  
  245.                 }  
  246.                 node.setDataIndex(indexArray);  
  247.             }  
  248.             System.out.println("attr remain null");  
  249.             return;  
  250.         }  
  251.   
  252.         for (String str : remainAttr) {  
  253.             giniArray = computeAttrGini(remainData, str);  
  254.             tempGini = Double.parseDouble(giniArray[1]);  
  255.   
  256.             if (tempGini < minGini) {  
  257.                 spiltAttrName = str;  
  258.                 minGini = tempGini;  
  259.                 valueType = giniArray[0];  
  260.             }  
  261.         }  
  262.         // 移除劃分屬性  
  263.         remainAttr.remove(spiltAttrName);  
  264.         node.setAttrName(spiltAttrName);  
  265.   
  266.         // 孩子節點,分類迴歸樹中,每次二元劃分,分出2個孩子節點  
  267.         AttrNode[] childNode = new AttrNode[2];  
  268.         String[][] rData;  
  269.   
  270.         boolean[] bArray = new boolean[] { true, false };  
  271.         for (int i = 0; i < bArray.length; i++) {  
  272.             // 二元劃分屬於屬性值的劃分  
  273.             rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);  
  274.   
  275.             boolean sameClass = true;  
  276.             ArrayList<String> indexArray = new ArrayList<>();  
  277.             for (int k = 1; k < rData.length; k++) {  
  278.                 indexArray.add(rData[k][0]);  
  279.                 // 判斷是否爲同一類的  
  280.                 if (!rData[k][attrNames.length - 1]  
  281.                         .equals(rData[1][attrNames.length - 1])) {  
  282.                     // 只要有1個不相等,就不是同類型的  
  283.                     sameClass = false;  
  284.                     break;  
  285.                 }  
  286.             }  
  287.   
  288.             childNode[i] = new AttrNode();  
  289.             if (!sameClass) {  
  290.                 // 建立新的對象屬性,對象的同個引用會出錯  
  291.                 ArrayList<String> rAttr = new ArrayList<>();  
  292.                 for (String str : remainAttr) {  
  293.                     rAttr.add(str);  
  294.                 }  
  295.                 buildDecisionTree(childNode[i], valueType, rData, rAttr,  
  296.                         bArray[i]);  
  297.             } else {  
  298.                 String pAtr = (bArray[i] ? valueType : "!" + valueType);  
  299.                 childNode[i].setParentAttrValue(pAtr);  
  300.                 childNode[i].setDataIndex(indexArray);  
  301.             }  
  302.         }  
  303.   
  304.         node.setChildAttrNode(childNode);  
  305.     }  
  306.   
  307.     /** 
  308.      * 屬性劃分完畢,進行數據的移除 
  309.      *  
  310.      * @param srcData 
  311.      *            源數據 
  312.      * @param attrName 
  313.      *            劃分的屬性名稱 
  314.      * @param valueType 
  315.      *            屬性的值類型 
  316.      * @parame beLongValue 分類是否屬於此值類型 
  317.      */  
  318.     private String[][] removeData(String[][] srcData, String attrName,  
  319.             String valueType, boolean beLongValue) {  
  320.         String[][] desDataArray;  
  321.         ArrayList<String[]> desData = new ArrayList<>();  
  322.         // 待刪除數據  
  323.         ArrayList<String[]> selectData = new ArrayList<>();  
  324.         selectData.add(attrNames);  
  325.   
  326.         // 數組數據轉化到列表中,方便移除  
  327.         for (int i = 0; i < srcData.length; i++) {  
  328.             desData.add(srcData[i]);  
  329.         }  
  330.   
  331.         // 仍是從左往右一列列的查找  
  332.         for (int j = 1; j < attrNames.length; j++) {  
  333.             if (attrNames[j].equals(attrName)) {  
  334.                 for (int i = 1; i < desData.size(); i++) {  
  335.                     if (desData.get(i)[j].equals(valueType)) {  
  336.                         // 若是匹配這個數據,則移除其餘的數據  
  337.                         selectData.add(desData.get(i));  
  338.                     }  
  339.                 }  
  340.             }  
  341.         }  
  342.   
  343.         if (beLongValue) {  
  344.             desDataArray = new String[selectData.size()][];  
  345.             selectData.toArray(desDataArray);  
  346.         } else {  
  347.             // 屬性名稱行不移除  
  348.             selectData.remove(attrNames);  
  349.             // 若是是劃分不屬於此類型的數據時,進行移除  
  350.             desData.removeAll(selectData);  
  351.             desDataArray = new String[desData.size()][];  
  352.             desData.toArray(desDataArray);  
  353.         }  
  354.   
  355.         return desDataArray;  
  356.     }  
  357.   
  358.     public void startBuildingTree() {  
  359.         readDataFile();  
  360.         initAttrValue();  
  361.   
  362.         ArrayList<String> remainAttr = new ArrayList<>();  
  363.         // 添加屬性,除了最後一個類標號屬性  
  364.         for (int i = 1; i < attrNames.length - 1; i++) {  
  365.             remainAttr.add(attrNames[i]);  
  366.         }  
  367.   
  368.         AttrNode rootNode = new AttrNode();  
  369.         buildDecisionTree(rootNode, "", data, remainAttr, false);  
  370.         setIndexAndAlpah(rootNode, 0, false);  
  371.         System.out.println("剪枝前:");  
  372.         showDecisionTree(rootNode, 1);  
  373.         setIndexAndAlpah(rootNode, 0, true);  
  374.         System.out.println("\n剪枝後:");  
  375.         showDecisionTree(rootNode, 1);  
  376.     }  
  377.   
  378.     /** 
  379.      * 顯示決策樹 
  380.      *  
  381.      * @param node 
  382.      *            待顯示的節點 
  383.      * @param blankNum 
  384.      *            行空格符,用於顯示樹型結構 
  385.      */  
  386.     private void showDecisionTree(AttrNode node, int blankNum) {  
  387.         System.out.println();  
  388.         for (int i = 0; i < blankNum; i++) {  
  389.             System.out.print("    ");  
  390.         }  
  391.         System.out.print("--");  
  392.         // 顯示分類的屬性值  
  393.         if (node.getParentAttrValue() != null  
  394.                 && node.getParentAttrValue().length() > 0) {  
  395.             System.out.print(node.getParentAttrValue());  
  396.         } else {  
  397.             System.out.print("--");  
  398.         }  
  399.         System.out.print("--");  
  400.   
  401.         if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {  
  402.             String i = node.getDataIndex().get(0);  
  403.             System.out.print("【" + node.getNodeIndex() + "】類別:"  
  404.                     + data[Integer.parseInt(i)][attrNames.length - 1]);  
  405.             System.out.print("[");  
  406.             for (String index : node.getDataIndex()) {  
  407.                 System.out.print(index + ", ");  
  408.             }  
  409.             System.out.print("]");  
  410.         } else {  
  411.             // 遞歸顯示子節點  
  412.             System.out.print("【" + node.getNodeIndex() + ":"  
  413.                     + node.getAttrName() + "】");  
  414.             if (node.getChildAttrNode() != null) {  
  415.                 for (AttrNode childNode : node.getChildAttrNode()) {  
  416.                     showDecisionTree(childNode, 2 * blankNum);  
  417.                 }  
  418.             } else {  
  419.                 System.out.print("【  Child Null】");  
  420.             }  
  421.         }  
  422.     }  
  423.   
  424.     /** 
  425.      * 爲節點設置序列號,並計算每一個節點的偏差率,用於後面剪枝 
  426.      *  
  427.      * @param node 
  428.      *            開始的時候傳入的是根節點 
  429.      * @param index 
  430.      *            開始的索引號,從1開始 
  431.      * @param ifCutNode 
  432.      *            是否須要剪枝 
  433.      */  
  434.     private void setIndexAndAlpah(AttrNode node, int index, boolean ifCutNode) {  
  435.         AttrNode tempNode;  
  436.         // 最小偏差代價節點,即將被剪枝的節點  
  437.         AttrNode minAlphaNode = null;  
  438.         double minAlpah = Integer.MAX_VALUE;  
  439.         Queue<AttrNode> nodeQueue = new LinkedList<AttrNode>();  
  440.   
  441.         nodeQueue.add(node);  
  442.         while (nodeQueue.size() > 0) {  
  443.             index++;  
  444.             // 從隊列頭部獲取首個節點  
  445.             tempNode = nodeQueue.poll();  
  446.             tempNode.setNodeIndex(index);  
  447.             if (tempNode.getChildAttrNode() != null) {  
  448.                 for (AttrNode childNode : tempNode.getChildAttrNode()) {  
  449.                     nodeQueue.add(childNode);  
  450.                 }  
  451.                 computeAlpha(tempNode);  
  452.                 if (tempNode.getAlpha() < minAlpah) {  
  453.                     minAlphaNode = tempNode;  
  454.                     minAlpah = tempNode.getAlpha();  
  455.                 } else if (tempNode.getAlpha() == minAlpah) {  
  456.                     // 若是偏差代價值同樣,比較包含的葉子節點個數,剪枝有多葉子節點數的節點  
  457.                     if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) {  
  458.                         minAlphaNode = tempNode;  
  459.                     }  
  460.                 }  
  461.             }  
  462.         }  
  463.   
  464.         if (ifCutNode) {  
  465.             // 進行樹的剪枝,讓其左右孩子節點爲null  
  466.             minAlphaNode.setChildAttrNode(null);  
  467.         }  
  468.     }  
  469.   
  470.     /** 
  471.      * 爲非葉子節點計算偏差代價,這裏的後剪枝法用的是CCP代價複雜度剪枝 
  472.      *  
  473.      * @param node 
  474.      *            待計算的非葉子節點 
  475.      */  
  476.     private void computeAlpha(AttrNode node) {  
  477.         double rt = 0;  
  478.         double Rt = 0;  
  479.         double alpha = 0;  
  480.         // 當前節點的數據總數  
  481.         int sumNum = 0;  
  482.         // 最少的誤差數  
  483.         int minNum = 0;  
  484.   
  485.         ArrayList<String> dataIndex;  
  486.         ArrayList<AttrNode> leafNodes = new ArrayList<>();  
  487.   
  488.         addLeafNode(node, leafNodes);  
  489.         node.setLeafNum(leafNodes.size());  
  490.         for (AttrNode attrNode : leafNodes) {  
  491.             dataIndex = attrNode.getDataIndex();  
  492.   
  493.             int num = 0;  
  494.             sumNum += dataIndex.size();  
  495.             for (String s : dataIndex) {  
  496.                 // 統計分類數據中的正負實例數  
  497.                 if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {  
  498.                     num++;  
  499.                 }  
  500.             }  
  501.             minNum += num;  
  502.   
  503.             // 取小數量的值部分  
  504.             if (1.0 * num / dataIndex.size() > 0.5) {  
  505.                 num = dataIndex.size() - num;  
  506.             }  
  507.   
  508.             rt += (1.0 * num / (data.length - 1));  
  509.         }  
  510.           
  511.         //一樣取出少誤差的那部分  
  512.         if (1.0 * minNum / sumNum > 0.5) {  
  513.             minNum = sumNum - minNum;  
  514.         }  
  515.   
  516.         Rt = 1.0 * minNum / (data.length - 1);  
  517.         alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);  
  518.         node.setAlpha(alpha);  
  519.     }  
  520.   
  521.     /** 
  522.      * 篩選出節點所包含的葉子節點數 
  523.      *  
  524.      * @param node 
  525.      *            待篩選節點 
  526.      * @param leafNode 
  527.      *            葉子節點列表容器 
  528.      */  
  529.     private void addLeafNode(AttrNode node, ArrayList<AttrNode> leafNode) {  
  530.         ArrayList<String> dataIndex;  
  531.   
  532.         if (node.getChildAttrNode() != null) {  
  533.             for (AttrNode childNode : node.getChildAttrNode()) {  
  534.                 dataIndex = childNode.getDataIndex();  
  535.                 if (dataIndex != null && dataIndex.size() > 0) {  
  536.                     // 說明此節點爲葉子節點  
  537.                     leafNode.add(childNode);  
  538.                 } else {  
  539.                     // 若是仍是非葉子節點則繼續遞歸調用  
  540.                     addLeafNode(childNode, leafNode);  
  541.                 }  
  542.             }  
  543.         }  
  544.     }  
  545.   
  546. }  

AttrNode節點的設計和屬性:

 

 

[java] view plain copy

 print?

  1. /** 
  2.  * 迴歸分類樹節點 
  3.  *  
  4.  * @author lyq 
  5.  *  
  6.  */  
  7. public class AttrNode {  
  8.     // 節點屬性名字  
  9.     private String attrName;  
  10.     // 節點索引標號  
  11.     private int nodeIndex;  
  12.     //包含的葉子節點數  
  13.     private int leafNum;  
  14.     // 節點偏差率  
  15.     private double alpha;  
  16.     // 父親分類屬性值  
  17.     private String parentAttrValue;  
  18.     // 孩子節點  
  19.     private AttrNode[] childAttrNode;  
  20.     // 數據記錄索引  
  21.     private ArrayList<String> dataIndex;  
  22.     .....  

get,set方法自行補上。客戶端的場景調用:

 

 

[java] view plain copy

 print?

  1. package DataMing_CART;  
  2.   
  3. public class Client {  
  4.     public static void main(String[] args){  
  5.         String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";  
  6.           
  7.         CARTTool tool = new CARTTool(filePath);  
  8.           
  9.         tool.startBuildingTree();  
  10.     }  
  11. }  

數據文件路徑自行修改,不然會報錯(特殊狀況懶得處理了.....)。最後程序的輸出結果,請自行從左往右看,從上往下,左邊的是父親節點,上面的是考前的子節點:

 

 

[java] view plain copy

 print?

  1. 剪枝前:  
  2.   
  3.     --!--【1:Age】  
  4.         --MiddleAged--【2】類別:Yes[3, 7, 12, 13, ]  
  5.         --!MiddleAged--【3:Student】  
  6.                 --No--【4:Income】  
  7.                                 --High--【6】類別:No[1, 2, ]  
  8.                                 --!High--【7:CreditRating】  
  9.                                                                 --Fair--【10】類別:Yes[4, 8, ]  
  10.                                                                 --!Fair--【11】類別:No[14, ]  
  11.                 --!No--【5:CreditRating】  
  12.                                 --Fair--【8】類別:Yes[5, 9, 10, ]  
  13.                                 --!Fair--【9:Income】  
  14.                                                                 --Medium--【12】類別:Yes[11, ]  
  15.                                                                 --!Medium--【13】類別:No[6, ]  
  16. 剪枝後:  
  17.   
  18.     --!--【1:Age】  
  19.         --MiddleAged--【2】類別:Yes[3, 7, 12, 13, ]  
  20.         --!MiddleAged--【3:Student】  
  21.                 --No--【4:Income】【  Child Null】  
  22.                 --!No--【5:CreditRating】  
  23.                                 --Fair--【8】類別:Yes[5, 9, 10, ]  
  24.                                 --!Fair--【9:Income】  
  25.                                                                 --Medium--【12】類別:Yes[11, ]  
  26.                                                                 --!Medium--【13】類別:No[6, ]  

 

結果分析:

我在一開始的時候根據的是最後分類的數據是否爲同一個類標識的,若是都爲YES或者都爲NO的,分類終止,通常狀況下都說的通,可是若是最後屬性劃分完畢了,剩餘的數據還有存在類標識不同的狀況就會誤差,好比說這裏的7號CredaRating節點,下面Fair分支中的[4,8]就不是同類的。因此在後面的剪枝算法就被剪枝了。由於後面的4和7號節點的偏差代價率爲0,說明分類先後沒有類誤差變化,這也見證了後剪枝算法的威力所在了。

在coding遇到的困難和改進的地方:

一、先說說在編碼時遇到的困難,在對節點進行賦索引標號值的時候出了問題,由於是以前生成樹的時候採用了DFS的思想,若是編號時也採用此方法就不對了,因而就用到了把節點取出放入隊列這樣的遍歷方式,就是BFS的方式爲節點標號。

二、程序的一個改進的地方在於算一個非葉子節點的時候須要計算他所包含的葉子節點數,採用了從當前節點開始從上往下遞歸計算,並且每一個非葉子節點都計算一遍,顯然這樣作的效率是不高,後來想到了一種從葉子節點開始計算,從下往上直到根節點,對父親節點的非葉子節點列表作更新操做,就只要計算一次,這有點dp的思想在裏面了,因爲時間關係,沒有來得及實現。

三、第二個優化點就是後剪枝算法的多樣化,我這裏採用的是CCP代價複雜度算法,你們能夠試着實現其餘的諸如悲觀偏差算法進行剪枝,看看能不能把程序中4和7號節點識別出來,而且剪枝掉。

相關文章
相關標籤/搜索