決策樹系列(五)——CART

CART,又名分類迴歸樹,是在ID3的基礎上進行優化的決策樹,學習CART記住如下幾個關鍵點:node

(1)CART既能是分類樹,又能是分類樹;算法

(2)當CART是分類樹時,採用GINI值做爲節點分裂的依據;當CART是迴歸樹時,採用樣本的最小方差做爲節點分裂的依據;數組

(3)CART是一棵二叉樹。ide

接下來將以一個實際的例子對CART進行介紹:函數

                                                                    表1 原始數據表學習

看電視時間優化

婚姻狀況this

職業spa

年齡3d

3

未婚

學生

12

4

未婚

學生

18

2

已婚

老師

26

5

已婚

上班族

47

2.5

已婚

上班族

36

3.5

未婚

老師

29

4

已婚

學生

21

從如下的思路理解CART

分類樹?迴歸樹?

      分類樹的做用是經過一個對象的特徵來預測該對象所屬的類別,而回歸樹的目的是根據一個對象的信息預測該對象的屬性,並以數值表示。

      CART既能是分類樹,又能是決策樹,如上表所示,若是咱們想預測一我的是否已婚,那麼構建的CART將是分類樹;若是想預測一我的的年齡,那麼構建的將是迴歸樹。

分類樹和迴歸樹是怎麼作決策的?假設咱們構建了兩棵決策樹分別預測用戶是否已婚和實際的年齡,如圖1和圖2所示:

                                      圖1 預測婚姻狀況決策樹                                               圖2 預測年齡的決策樹

       圖1表示一棵分類樹,其葉子節點的輸出結果爲一個實際的類別,在這個例子裏是婚姻的狀況(已婚或者未婚),選擇葉子節點中數量佔比最大的類別做爲輸出的類別;

       圖2是一棵迴歸樹,預測用戶的實際年齡,是一個具體的輸出值。怎樣獲得這個輸出值?通常狀況下選擇使用中值、平均值或者衆數進行表示,圖2使用節點年齡數據的平均值做爲輸出值。

CART如何選擇分裂的屬性?

      分裂的目的是爲了可以讓數據變純,使決策樹輸出的結果更接近真實值。那麼CART是如何評價節點的純度呢?若是是分類樹,CART採用GINI值衡量節點純度;若是是迴歸樹,採用樣本方差衡量節點純度。節點越不純,節點分類或者預測的效果就越差。

GINI值的計算公式:

                               

      節點越不純,GINI值越大。以二分類爲例,若是節點的全部數據只有一個類別,則 ,若是兩類數量相同,則

迴歸方差計算公式:

                                                                               

      方差越大,表示該節點的數據越分散,預測的效果就越差。若是一個節點的全部數據都相同,那麼方差就爲0,此時能夠很確定得認爲該節點的輸出值;若是節點的數據相差很大,那麼輸出的值有很大的可能與實際值相差較大。

      所以,不管是分類樹仍是迴歸樹,CART都要選擇使子節點的GINI值或者回歸方差最小的屬性做爲分裂的方案。即最小化(分類樹):

                               

或者(迴歸樹):

                                                                                                     

CART如何分裂成一棵二叉樹?

     節點的分裂分爲兩種狀況,連續型的數據和離散型的數據。

     CART對連續型屬性的處理與C4.5差很少,經過最小化分裂後的GINI值或者樣本方差尋找最優分割點,將節點一分爲二,在這裏再也不敘述,詳細請看C4.5

     對於離散型屬性,理論上有多少個離散值就應該分裂成多少個節點。但CART是一棵二叉樹,每一次分裂只會產生兩個節點,怎麼辦呢?很簡單,只要將其中一個離散值獨立做爲一個節點,其餘的離散值生成另一個節點便可。這種分裂方案有多少個離散值就有多少種劃分的方法,舉一個簡單的例子:若是某離散屬性一個有三個離散值X,Y,Z,則該屬性的分裂方法有{X}、{Y,Z},{Y}、{X,Z},{Z}、{X,Y},分別計算每種劃分方法的基尼值或者樣本方差肯定最優的方法。

     以屬性「職業」爲例,一共有三個離散值,「學生」、「老師」、「上班族」。該屬性有三種劃分的方案,分別爲{「學生」}、{「老師」、「上班族」},{「老師」}、{「學生」、「上班族」},{「上班族」}、{「學生」、「老師」},分別計算三種劃分方案的子節點GINI值或者樣本方差,選擇最優的劃分方法,以下圖所示:

第一種劃分方法:{「學生」}、{「老師」、「上班族」}

預測是否已婚(分類):

                    

預測年齡(迴歸):

            

 

第二種劃分方法:{「老師」}、{「學生」、「上班族」}

 

預測是否已婚(分類):

                    

預測年齡(迴歸):

            

第三種劃分方法:{「上班族」}、{「學生」、「老師」}

 預測是否已婚(分類):

                    

預測年齡(迴歸):

            

綜上,若是想預測是否已婚,則選擇{「上班族」}、{「學生」、「老師」}的劃分方法,若是想預測年齡,則選擇{「老師」}、{「學生」、「上班族」}的劃分方法。

 

如何剪枝?

      CART採用CCP(代價複雜度)剪枝方法。代價複雜度選擇節點表面偏差率增益值最小的非葉子節點,刪除該非葉子節點的左右子節點,如有多個非葉子節點的表面偏差率增益值相同小,則選擇非葉子節點中子節點數最多的非葉子節點進行剪枝。

可描述以下:

令決策樹的非葉子節點爲

a)計算全部非葉子節點的表面偏差率增益值 

b)選擇表面偏差率增益值最小的非葉子節點(若多個非葉子節點具備相同小的表面偏差率增益值,選擇節點數最多的非葉子節點)。

c)對進行剪枝

表面偏差率增益值的計算公式:

                               

其中:

表示葉子節點的偏差代價, 爲節點的錯誤率, 爲節點數據量的佔比;

表示子樹的偏差代價,爲子節點i的錯誤率, 表示節點i的數據節點佔比;

表示子樹節點個數。

算例:

下圖是其中一顆子樹,設決策樹的總數據量爲40。

該子樹的表面偏差率增益值能夠計算以下:

 

求出該子樹的表面錯誤覆蓋率爲 ,只要求出其餘子樹的表面偏差率增益值就能夠對決策樹進行剪枝。

 

程序實際以及源代碼

流程圖:

(1)數據處理

         對原始的數據進行數字化處理,並以二維數據的形式存儲,每一行表示一條記錄,前n-1列表示屬性,最後一列表示分類的標籤。

         如表1的數據能夠轉化爲表2:

                                                                           表2 初始化後的數據

看電視時間

婚姻狀況

職業

年齡

3

未婚

學生

12

4

未婚

學生

18

2

已婚

老師

26

5

已婚

上班族

47

2.5

已婚

上班族

36

3.5

未婚

老師

29

4

已婚

學生

21

        

      其中,對於「婚姻狀況」屬性,數字{1,2}分別表示{未婚,已婚 };對於「職業」屬性{1,2,3, }分別表示{學生、老師、上班族};

代碼以下所示:

         static double[][] allData;                              //存儲進行訓練的數據

    static List<String>[] featureValues;                    //離散屬性對應的離散值

featureValues是鏈表數組,數組的長度爲屬性的個數,數組的每一個元素爲該屬性的離散值鏈表。

(2)兩個類:節點類和分裂信息

a)節點類Node

      該類表示一個節點,屬性包括節點選擇的分裂屬性、節點的輸出類、孩子節點、深度等。注意,與ID3中相比,新增了兩個屬性:leafWrong和leafNode_Count分別表示葉子節點的總分類偏差和葉子節點的個數,主要是爲了方便剪枝。

 1 class Node
 2 {
 3     /// <summary>
 4     /// 每個節點的分裂值
 5     /// </summary>
 6     public List<String> features { get; set; }
 7     /// <summary>
 8     /// 分裂屬性的類型{離散、連續}
 9     /// </summary>
10     public String feature_Type { get; set; }
11     /// <summary>
12     /// 分裂屬性的下標
13     /// </summary>
14     public String SplitFeature { get; set; }
15     //List<int> nums = new List<int>();                       //行序號
16     /// <summary>
17     /// 每個類對應的數目
18     /// </summary>
19     public double[] ClassCount { get; set; }
20     //int[] isUsed = new int[0];                              //屬性的使用狀況 1:已用 2:未用
21     /// <summary>
22     /// 孩子節點
23     /// </summary>
24     public List<Node> childNodes { get; set; }
25     Node Parent = null;
26     /// <summary>
27     /// 該節點佔比最大的類別
28     /// </summary>
29     public String finalResult { get; set; }
30     /// <summary>
31     /// 樹的深度
32     /// </summary>
33     public int deep { get; set; }
34     /// <summary>
35     /// 最大的類下標
36     /// </summary>
37     public int result { get; set; }
38     /// <summary>
39     /// 子節點偏差
40     /// </summary>
41     public int leafWrong { get; set; }
42     /// <summary>
43     /// 子節點數目
44     /// </summary>
45     public int leafNode_Count { get; set; }
46     /// <summary>
47     /// 數據量
48     /// </summary>
49     public int rowCount { get; set; }
50 
51     public void setClassCount(double[] count)
52     {
53         this.ClassCount = count;
54         double max = ClassCount[0];
55         int result = 0;
56         for (int i = 1; i < ClassCount.Length; i++)
57         {
58             if (max < ClassCount[i])
59             {
60                 max = ClassCount[i];
61                 result = i;
62             }
63         }
64         this.result = result;
65     }
66     public double getErrorCount()
67     {
68         return rowCount - ClassCount[result];
69     }
70 }
樹的節點

b)分裂信息類,該類存儲節點進行分裂的信息,包括各個子節點的行座標、子節點各個類的數目、該節點分裂的屬性、屬性的類型等。

 1     class SplitInfo
 2     {
 3         /// <summary>
 4         /// 分裂的屬性下標
 5         /// </summary>
 6         public int splitIndex { get; set; }
 7         /// <summary>
 8         /// 數據類型
 9         /// </summary>
10         public int type { get; set; }
11         /// <summary>
12         /// 分裂屬性的取值
13         /// </summary>
14         public List<String> features { get; set; }
15         /// <summary>
16         /// 各個節點的行座標鏈表
17         /// </summary>
18         public List<int>[] temp { get; set; }
19         /// <summary>
20         /// 每一個節點各種的數目
21         /// </summary>
22         public double[][] class_Count { get; set; }
23     }
分裂信息

主方法findBestSplit(Node node,List<int> nums,int[] isUsed),該方法對節點進行分裂

其中:

node表示即將進行分裂的節點;

nums表示節點數據對一個的行座標列表;

isUsed表示到該節點位置全部屬性的使用狀況;

findBestSplit的這個方法主要有如下幾個組成部分:

1)節點分裂中止的斷定

節點分裂條件如上文所述,源代碼以下:

  1         public static bool ifEnd(Node node, double shang,int[] isUsed)
  2         {
  3             try
  4             {
  5                 double[] count = node.ClassCount;
  6                 int rowCount = node.rowCount;
  7                 int maxResult = 0;
  8                 double maxRate = 0;
  9                 #region 數達到某一深度
 10                 int deep = node.deep;
 11                 if (deep >= 10)
 12                 {
 13                     maxResult = node.result + 1;
 14                     node.feature_Type="result";
 15                     node.features=new List<String>() { maxResult + "" 
 16 
 17 };
 18                     node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
 19                     node.leafNode_Count=1;
 20                     return true;
 21                 }
 22                 #endregion
 23                 #region 純度(其實跟後面的有點重了,記得要修改)
 24                 //maxResult = 1;
 25                 //for (int i = 1; i < count.Length; i++)
 26                 //{
 27                 //    if (count[i] / rowCount >= 0.95)
 28                 //    {
 29                 //        node.feature_Type="result";
 30                 //        node.features=new List<String> { "" + (i + 
 31 
 32 1) };
 33                 //        node.leafNode_Count=1;
 34                 //        node.leafWrong=rowCount - Convert.ToInt32
 35 
 36 (count[i]);
 37                 //        return true;
 38                 //    }
 39                 //}
 40                 #endregion
 41                 #region 熵爲0
 42                 if (shang == 0)
 43                 {
 44                     maxRate = count[0] / rowCount;
 45                     maxResult = 1;
 46                     for (int i = 1; i < count.Length; i++)
 47                     {
 48                         if (count[i] / rowCount >= maxRate)
 49                         {
 50                             maxRate = count[i] / rowCount;
 51                             maxResult = i + 1;
 52                         }
 53                     }
 54                     node.feature_Type="result";
 55                     node.features=new List<String> { maxResult + "" 
 56 
 57 };
 58                     node.leafWrong=rowCount - Convert.ToInt32(count
 59 
 60 [maxResult - 1]);
 61                     node.leafNode_Count=1;
 62                     return true;
 63                 }
 64                 #endregion
 65                 #region 屬性已經分完
 66                 //int[] isUsed = node.getUsed();
 67                 bool flag = true;
 68                 for (int i = 0; i < isUsed.Length - 1; i++)
 69                 {
 70                     if (isUsed[i] == 0)
 71                     {
 72                         flag = false;
 73                         break;
 74                     }
 75                 }
 76                 if (flag)
 77                 {
 78                     maxRate = count[0] / rowCount;
 79                     maxResult = 1;
 80                     for (int i = 1; i < count.Length; i++)
 81                     {
 82                         if (count[i] / rowCount >= maxRate)
 83                         {
 84                             maxRate = count[i] / rowCount;
 85                             maxResult = i + 1;
 86                         }
 87                     }
 88                     node.feature_Type=("result");
 89                     node.features=(new List<String> { "" + 
 90 
 91 (maxResult) });
 92                     node.leafWrong=(rowCount - Convert.ToInt32(count
 93 
 94 [maxResult - 1]));
 95                     node.leafNode_Count=(1);
 96                     return true;
 97                 }
 98                 #endregion
 99                 #region 幾點數少於100
100                 if (rowCount < Limit_Node)
101                 {
102                     maxRate = count[0] / rowCount;
103                     maxResult = 1;
104                     for (int i = 1; i < count.Length; i++)
105                     {
106                         if (count[i] / rowCount >= maxRate)
107                         {
108                             maxRate = count[i] / rowCount;
109                             maxResult = i + 1;
110                         }
111                     }
112                     node.feature_Type="result";
113                     node.features=new List<String> { "" + (maxResult) 
114 
115 };
116                     node.leafWrong=rowCount - Convert.ToInt32(count
117 
118 [maxResult - 1]);
119                     node.leafNode_Count=1;
120                     return true;
121                 }
122                 #endregion
123                 return false;
124             }
125             catch (Exception e)
126             {
127                 return false;
128             }
129         }
中止分裂的條件

2)尋找最優的分裂屬性

尋找最優的分裂屬性須要計算每個分裂屬性分裂後的GINI值或者樣本方差,計算公式上文已給出,其中GINI值的計算代碼以下:

1         public static double getGini(double[] counts, int countAll)
2         {
3             double Gini = 1;
4             for (int i = 0; i < counts.Length; i++)
5             {
6                 Gini = Gini - Math.Pow(counts[i] / countAll, 2);
7             }
8             return Gini;
9         }
GINI值計算

3)進行分裂,同時對子節點進行迭代處理

其實就是遞歸的過程,對每個子節點執行findBestSplit方法進行分裂。

findBestSplit源代碼:

  1         public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
  2         {
  3             try
  4             {
  5                 //判斷是否繼續分裂
  6                 double totalShang = getGini(node.ClassCount, node.rowCount);
  7                 if (ifEnd(node, totalShang, isUsed))
  8                 {
  9                     return node;
 10                 }
 11                 #region 變量聲明
 12                 SplitInfo info = new SplitInfo();
 13                 info.initial();
 14                 int RowCount = nums.Count;                  //樣本總數
 15                 double jubuMax = 1;                         //局部最大熵
 16                 int splitPoint = 0;                         //分裂的點
 17                 double splitValue = 0;                      //分裂的值
 18                 #endregion
 19                 for (int i = 0; i < isUsed.Length - 1; i++)
 20                 {
 21                     if (isUsed[i] == 1)
 22                     {
 23                         continue;
 24                     }
 25                     #region 離散變量
 26                     if (type[i] == 0)
 27                     {
 28                         double[][] allCount = new double[allNum[i]][];
 29                         for (int j = 0; j < allCount.Length; j++)
 30                         {
 31                             allCount[j] = new double[classCount];
 32                         }
 33                         int[] countAllFeature = new int[allNum[i]];
 34                         List<int>[] temp = new List<int>[allNum[i]];
 35                         double[] allClassCount = node.ClassCount;     //全部類別的數量
 36                         for (int j = 0; j < temp.Length; j++)
 37                         {
 38                             temp[j] = new List<int>();
 39                         }
 40                         for (int j = 0; j < nums.Count; j++)
 41                         {
 42                             int index = Convert.ToInt32(allData[nums[j]][i]);
 43                             temp[index - 1].Add(nums[j]);
 44                             countAllFeature[index - 1]++;
 45                             allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
 46                         }
 47                         double allShang = 1;
 48                         int choose = 0;
 49 
 50                         double[][] jubuCount = new double[2][];
 51                         for (int k = 0; k < allCount.Length; k++)
 52                         {
 53                             if (temp[k].Count == 0)
 54                                 continue;
 55                             double JubuShang = 0;
 56                             double[][] tempCount = new double[2][];
 57                             tempCount[0] = allCount[k];
 58                             tempCount[1] = new double[allCount[0].Length];
 59                             for (int j = 0; j < tempCount[1].Length; j++)
 60                             {
 61                                 tempCount[1][j] = allClassCount[j] - allCount[k][j];
 62                             }
 63                             JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
 64                             int nodecount = RowCount - countAllFeature[k];
 65                             JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
 66                             if (JubuShang < allShang)
 67                             {
 68                                 allShang = JubuShang;
 69                                 jubuCount = tempCount;
 70                                 choose = k;
 71                             }
 72                         }                        
 73                         if (allShang < jubuMax)
 74                         {
 75                             info.type = 0;
 76                             jubuMax = allShang;
 77                             info.class_Count = jubuCount;
 78                             info.temp[0] = temp[choose];
 79                             info.temp[1] = new List<int>();
 80                             info.features = new List<string>();
 81                             info.features.Add((choose + 1) + "");
 82                             info.features.Add("");
 83                             for (int j = 0; j < temp.Length; j++)
 84                             {
 85                                 if (j == choose)
 86                                     continue;
 87                                 for (int k = 0; k < temp[j].Count; k++)
 88                                 {
 89                                     info.temp[1].Add(temp[j][k]);
 90                                 }
 91                                 if (temp[j].Count != 0)
 92                                 {
 93                                     info.features[1] = info.features[1] + (j + 1) + ",";
 94                                 }
 95                             }
 96                             info.splitIndex = i;
 97                         }
 98                     }
 99                     #endregion
100                     #region 連續變量
101                     else
102                     {
103                         double[] leftCunt = new double[classCount];   
104 
105           //作節點各個類別的數量
106                         double[] rightCount = new double[classCount]; 
107 
108           //右節點各個類別的數量
109                         double[] count1 = new double[classCount];     
110 
111           //子集1的統計量
112                         double[] count2 = new double
113 
114 [node.ClassCount.Length];   //子集2的統計量
115                         for (int j = 0; j < node.ClassCount.Length; 
116 
117 j++)
118                         {
119                             count2[j] = node.ClassCount[j];
120                         }
121                         int all1 = 0;                                 
122 
123           //子集1的樣本量
124                         int all2 = nums.Count;                        
125 
126           //子集2的樣本量
127                         double lastValue = 0;                         
128 
129          //上一個記錄的類別
130                         double currentValue = 0;                      
131 
132          //當前類別
133                         double lastPoint = 0;                         
134 
135           //上一個點的值
136                         double currentPoint = 0;                      
137 
138           //當前點的值
139                         double[] values = new double[nums.Count];
140                         for (int j = 0; j < values.Length; j++)
141                         {
142                             values[j] = allData[nums[j]][i];
143                         }
144                         QSort(values, nums, 0, nums.Count - 1);
145                         double lianxuMax = 1;                         
146 
147           //連續型屬性的最大熵
148                         #region 尋找最佳的分割點
149                         for (int j = 0; j < nums.Count - 1; j++)
150                         {
151                             currentValue = allData[nums[j]][lieshu - 
152 
153 1];
154                             currentPoint = (allData[nums[j]][i]);
155                             if (j == 0)
156                             {
157                                 lastValue = currentValue;
158                                 lastPoint = currentPoint;
159                             }
160                             if (currentValue != lastValue && 
161 
162 currentPoint != lastPoint)
163                             {
164                                 double shang1 = getGini(count1, 
165 
166 all1);
167                                 double shang2 = getGini(count2, 
168 
169 all2);
170                                 double allShang = shang1 * all1 / 
171 
172 (all1 + all2) + shang2 * all2 / (all1 + all2);
173                                 //allShang = (totalShang - allShang);
174                                 if (lianxuMax > allShang)
175                                 {
176                                     lianxuMax = allShang;
177                                     for (int k = 0; k < 
178 
179 count1.Length; k++)
180                                     {
181                                         leftCunt[k] = count1[k];
182                                         rightCount[k] = count2[k];
183                                     }
184                                     splitPoint = j;
185                                     splitValue = (currentPoint + 
186 
187 lastPoint) / 2;
188                                 }
189                             }
190                             all1++;
191                             count1[Convert.ToInt32(currentValue) - 
192 
193 1]++;
194                             count2[Convert.ToInt32(currentValue) - 
195 
196 1]--;
197                             all2--;
198                             lastValue = currentValue;
199                             lastPoint = currentPoint;
200                         }
201                         #endregion
202                         #region 若是超過了局部值,重設
203                         if (lianxuMax < jubuMax)
204                         {
205                             info.type = 1;
206                             info.splitIndex = i;
207                             info.features=new List<string>()
208 
209 {splitValue+""};
210                             //finalPoint = splitPoint;
211                             jubuMax = lianxuMax;
212                             info.temp[0] = new List<int>();
213                             info.temp[1] = new List<int>();
214                             for (int k = 0; k < splitPoint; k++)
215                             {
216                                 info.temp[0].Add(nums[k]);
217                             }
218                             for (int k = splitPoint; k < nums.Count; 
219 
220 k++)
221                             {
222                                 info.temp[1].Add(nums[k]);
223                             }
224                             info.class_Count[0] = new double
225 
226 [leftCunt.Length];
227                             info.class_Count[1] = new double
228 
229 [leftCunt.Length];
230                             for (int k = 0; k < leftCunt.Length; k++)
231                             {
232                                 info.class_Count[0][k] = leftCunt[k];
233                                 info.class_Count[1][k] = rightCount
234 
235 [k];
236                             }
237                         }
238                         #endregion
239                     }
240                     #endregion
241                 }
242                 #region 沒有尋找到最佳的分裂點,則設置爲葉節點
243                 if (info.splitIndex == -1)
244                 {
245                     double[] finalCount = node.ClassCount;
246                     double max = finalCount[0];
247                     int result = 1;
248                     for (int i = 1; i < finalCount.Length; i++)
249                     {
250                         if (finalCount[i] > max)
251                         {
252                             max = finalCount[i];
253                             result = (i + 1);
254                         }
255                     }
256                     node.feature_Type="result";
257                     node.features=new List<String> { "" + result };
258                     return node;
259                 }
260                 #endregion
261                 #region 分裂
262                 int deep = node.deep;
263                 node.SplitFeature = ("" + info.splitIndex);
264                 List<Node> childNode = new List<Node>();
265                 int[][] used = new int[2][];
266                 used[0] = new int[isUsed.Length];
267                 used[1] = new int[isUsed.Length];
268                 for (int i = 0; i < isUsed.Length; i++)
269                 {
270                     used[0][i] = isUsed[i];
271                     used[1][i] = isUsed[i];
272                 }
273                 if (info.type == 0)
274                 {
275                     used[0][info.splitIndex] = 1;
276                     node.feature_Type = ("離散");
277                 }
278                 else
279                 {
280                     //used[info.splitIndex] = 0;
281                     node.feature_Type = ("連續");
282                 }
283                 List<int>[] rowIndex = info.temp;
284                 List<String> features = info.features;
285                 Node node1 = new Node();
286                 Node node2 = new Node();
287                 node1.setClassCount(info.class_Count[0]);
288                 node2.setClassCount(info.class_Count[1]);
289                 node1.rowCount = info.temp[0].Count;
290                 node2.rowCount = info.temp[1].Count;
291                 node1.deep = deep + 1;
292                 node2.deep = deep + 1;
293                 node1 = findBestSplit(node1, info.temp[0],used[0]);
294                 node2 = findBestSplit(node2, info.temp[1], used[1]);
295                 node.leafNode_Count = (node1.leafNode_Count
296 
297 +node2.leafNode_Count);
298                 node.leafWrong = (node1.leafWrong+node2.leafWrong);
299                 node.features = (features);
300                 childNode.Add(node1);
301                 childNode.Add(node2);
302                 node.childNodes = childNode;
303                 #endregion
304                 return node;
305             }
306             catch (Exception e)
307             {
308                 Console.WriteLine(e.StackTrace);
309                 return node;
310             }
311         }
節點選擇屬性和分裂

(4)剪枝

代價複雜度剪枝方法(CCP):

 1         public static void getSeries(Node node)
 2         {
 3             Stack<Node> nodeStack = new Stack<Node>();
 4             if (node != null)
 5             {
 6                 nodeStack.Push(node);
 7             }
 8             if (node.feature_Type == "result")
 9                 return;
10             List<Node> childs = node.childNodes;
11             for (int i = 0; i < childs.Count; i++)
12             {
13                 getSeries(node);
14             }
15         }
CCP代價複雜度剪枝

CART所有核心代碼:

  1         /// <summary>
  2         /// 判斷是否還須要分裂
  3         /// </summary>
  4         /// <param name="node"></param>
  5         /// <returns></returns>
  6         public static bool ifEnd(Node node, double shang,int[] isUsed)
  7         {
  8             try
  9             {
 10                 double[] count = node.ClassCount;
 11                 int rowCount = node.rowCount;
 12                 int maxResult = 0;
 13                 double maxRate = 0;
 14                 #region 數達到某一深度
 15                 int deep = node.deep;
 16                 if (deep >= 10)
 17                 {
 18                     maxResult = node.result + 1;
 19                     node.feature_Type="result";
 20                     node.features=new List<String>() { maxResult + "" 
 21 
 22 };
 23                     node.leafWrong=rowCount - Convert.ToInt32(count[maxResult-1]);
 24                     node.leafNode_Count=1;
 25                     return true;
 26                 }
 27                 #endregion
 28                 #region 純度(其實跟後面的有點重了,記得要修改)
 29                 //maxResult = 1;
 30                 //for (int i = 1; i < count.Length; i++)
 31                 //{
 32                 //    if (count[i] / rowCount >= 0.95)
 33                 //    {
 34                 //        node.feature_Type="result";
 35                 //        node.features=new List<String> { "" + (i + 
 36 
 37 1) };
 38                 //        node.leafNode_Count=1;
 39                 //        node.leafWrong=rowCount - Convert.ToInt32
 40 
 41 (count[i]);
 42                 //        return true;
 43                 //    }
 44                 //}
 45                 #endregion
 46                 #region 熵爲0
 47                 if (shang == 0)
 48                 {
 49                     maxRate = count[0] / rowCount;
 50                     maxResult = 1;
 51                     for (int i = 1; i < count.Length; i++)
 52                     {
 53                         if (count[i] / rowCount >= maxRate)
 54                         {
 55                             maxRate = count[i] / rowCount;
 56                             maxResult = i + 1;
 57                         }
 58                     }
 59                     node.feature_Type="result";
 60                     node.features=new List<String> { maxResult + "" 
 61 
 62 };
 63                     node.leafWrong=rowCount - Convert.ToInt32(count
 64 
 65 [maxResult - 1]);
 66                     node.leafNode_Count=1;
 67                     return true;
 68                 }
 69                 #endregion
 70                 #region 屬性已經分完
 71                 //int[] isUsed = node.getUsed();
 72                 bool flag = true;
 73                 for (int i = 0; i < isUsed.Length - 1; i++)
 74                 {
 75                     if (isUsed[i] == 0)
 76                     {
 77                         flag = false;
 78                         break;
 79                     }
 80                 }
 81                 if (flag)
 82                 {
 83                     maxRate = count[0] / rowCount;
 84                     maxResult = 1;
 85                     for (int i = 1; i < count.Length; i++)
 86                     {
 87                         if (count[i] / rowCount >= maxRate)
 88                         {
 89                             maxRate = count[i] / rowCount;
 90                             maxResult = i + 1;
 91                         }
 92                     }
 93                     node.feature_Type=("result");
 94                     node.features=(new List<String> { "" + 
 95 
 96 (maxResult) });
 97                     node.leafWrong=(rowCount - Convert.ToInt32(count
 98 
 99 [maxResult - 1]));
100                     node.leafNode_Count=(1);
101                     return true;
102                 }
103                 #endregion
104                 #region 幾點數少於100
105                 if (rowCount < Limit_Node)
106                 {
107                     maxRate = count[0] / rowCount;
108                     maxResult = 1;
109                     for (int i = 1; i < count.Length; i++)
110                     {
111                         if (count[i] / rowCount >= maxRate)
112                         {
113                             maxRate = count[i] / rowCount;
114                             maxResult = i + 1;
115                         }
116                     }
117                     node.feature_Type="result";
118                     node.features=new List<String> { "" + (maxResult) 
119 
120 };
121                     node.leafWrong=rowCount - Convert.ToInt32(count
122 
123 [maxResult - 1]);
124                     node.leafNode_Count=1;
125                     return true;
126                 }
127                 #endregion
128                 return false;
129             }
130             catch (Exception e)
131             {
132                 return false;
133             }
134         }
135         #region 排序算法
136         public static void InsertSort(double[] values, List<int> arr, 
137 
138 int StartIndex, int endIndex)
139         {
140             for (int i = StartIndex + 1; i <= endIndex; i++)
141             {
142                 int key = arr[i];
143                 double init = values[i];
144                 int j = i - 1;
145                 while (j >= StartIndex && values[j] > init)
146                 {
147                     arr[j + 1] = arr[j];
148                     values[j + 1] = values[j];
149                     j--;
150                 }
151                 arr[j + 1] = key;
152                 values[j + 1] = init;
153             }
154         }
155         static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
156         {
157             int mid = low + ((high - low) >> 1);//計算數組中間的元素的下標  
158 
159             //使用三數取中法選擇樞軸  
160             if (values[mid] > values[high])//目標: arr[mid] <= arr[high]  
161             {
162                 swap(values, arr, mid, high);
163             }
164             if (values[low] > values[high])//目標: arr[low] <= arr[high]  
165             {
166                 swap(values, arr, low, high);
167             }
168             if (values[mid] > values[low]) //目標: arr[low] >= arr[mid]  
169             {
170                 swap(values, arr, mid, low);
171             }
172             //此時,arr[mid] <= arr[low] <= arr[high]  
173             return low;
174             //low的位置上保存這三個位置中間的值  
175             //分割時能夠直接使用low位置的元素做爲樞軸,而不用改變分割函數了  
176         }
177         static void swap(double[] values, List<int> arr, int t1, int t2)
178         {
179             double temp = values[t1];
180             values[t1] = values[t2];
181             values[t2] = temp;
182             int key = arr[t1];
183             arr[t1] = arr[t2];
184             arr[t2] = key;
185         }
186         static void QSort(double[] values, List<int> arr, int low, int high)
187         {
188             int first = low;
189             int last = high;
190 
191             int left = low;
192             int right = high;
193 
194             int leftLen = 0;
195             int rightLen = 0;
196 
197             if (high - low + 1 < 10)
198             {
199                 InsertSort(values, arr, low, high);
200                 return;
201             }
202 
203             //一次分割 
204             int key = SelectPivotMedianOfThree(values, arr, low, 
205 
206 high);//使用三數取中法選擇樞軸 
207             double inti = values[key];
208             int currentKey = arr[key];
209 
210             while (low < high)
211             {
212                 while (high > low && values[high] >= inti)
213                 {
214                     if (values[high] == inti)//處理相等元素  
215                     {
216                         swap(values, arr, right, high);
217                         right--;
218                         rightLen++;
219                     }
220                     high--;
221                 }
222                 arr[low] = arr[high];
223                 values[low] = values[high];
224                 while (high > low && values[low] <= inti)
225                 {
226                     if (values[low] == inti)
227                     {
228                         swap(values, arr, left, low);
229                         left++;
230                         leftLen++;
231                     }
232                     low++;
233                 }
234                 arr[high] = arr[low];
235                 values[high] = values[low];
236             }
237             arr[low] = currentKey;
238             values[low] = values[key];
239             //一次快排結束  
240             //把與樞軸key相同的元素移到樞軸最終位置周圍  
241             int i = low - 1;
242             int j = first;
243             while (j < left && values[i] != inti)
244             {
245                 swap(values, arr, i, j);
246                 i--;
247                 j++;
248             }
249             i = low + 1;
250             j = last;
251             while (j > right && values[i] != inti)
252             {
253                 swap(values, arr, i, j);
254                 i++;
255                 j--;
256             }
257             QSort(values, arr, first, low - 1 - leftLen);
258             QSort(values, arr, low + 1 + rightLen, last);
259         }
260         #endregion
261         /// <summary>
262         /// 尋找最佳的分裂點
263         /// </summary>
264         /// <param name="num"></param>
265         /// <param name="node"></param>
266         public static Node findBestSplit(Node node,List<int> nums,int[] isUsed)
267         {
268             try
269             {
270                 //判斷是否繼續分裂
271                 double totalShang = getGini(node.ClassCount, node.rowCount);
272                 if (ifEnd(node, totalShang, isUsed))
273                 {
274                     return node;
275                 }
276                 #region 變量聲明
277                 SplitInfo info = new SplitInfo();
278                 info.initial();
279                 int RowCount = nums.Count;                  //樣本總數
280                 double jubuMax = 1;                         //局部最大熵
281                 int splitPoint = 0;                         //分裂的點
282                 double splitValue = 0;                      //分裂的值
283                 #endregion
284                 for (int i = 0; i < isUsed.Length - 1; i++)
285                 {
286                     if (isUsed[i] == 1)
287                     {
288                         continue;
289                     }
290                     #region 離散變量
291                     if (type[i] == 0)
292                     {
293                         double[][] allCount = new double[allNum[i]][];
294                         for (int j = 0; j < allCount.Length; j++)
295                         {
296                             allCount[j] = new double[classCount];
297                         }
298                         int[] countAllFeature = new int[allNum[i]];
299                         List<int>[] temp = new List<int>[allNum[i]];
300                         double[] allClassCount = node.ClassCount;     //全部類別的數量
301                         for (int j = 0; j < temp.Length; j++)
302                         {
303                             temp[j] = new List<int>();
304                         }
305                         for (int j = 0; j < nums.Count; j++)
306                         {
307                             int index = Convert.ToInt32(allData[nums[j]][i]);
308                             temp[index - 1].Add(nums[j]);
309                             countAllFeature[index - 1]++;
310                             allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
311                         }
312                         double allShang = 1;
313                         int choose = 0;
314 
315                         double[][] jubuCount = new double[2][];
316                         for (int k = 0; k < allCount.Length; k++)
317                         {
318                             if (temp[k].Count == 0)
319                                 continue;
320                             double JubuShang = 0;
321                             double[][] tempCount = new double[2][];
322                             tempCount[0] = allCount[k];
323                             tempCount[1] = new double[allCount[0].Length];
324                             for (int j = 0; j < tempCount[1].Length; j++)
325                             {
326                                 tempCount[1][j] = allClassCount[j] - allCount[k][j];
327                             }
328                             JubuShang = JubuShang + getGini(tempCount[0], countAllFeature[k]) * countAllFeature[k] / RowCount;
329                             int nodecount = RowCount - countAllFeature[k];
330                             JubuShang = JubuShang + getGini(tempCount[1], nodecount) * nodecount / RowCount;
331                             if (JubuShang < allShang)
332                             {
333                                 allShang = JubuShang;
334                                 jubuCount = tempCount;
335                                 choose = k;
336                             }
337                         }                        
338                         if (allShang < jubuMax)
339                         {
340                             info.type = 0;
341                             jubuMax = allShang;
342                             info.class_Count = jubuCount;
343                             info.temp[0] = temp[choose];
344                             info.temp[1] = new List<int>();
345                             info.features = new List<string>();
346                             info.features.Add((choose + 1) + "");
347                             info.features.Add("");
348                             for (int j = 0; j < temp.Length; j++)
349                             {
350                                 if (j == choose)
351                                     continue;
352                                 for (int k = 0; k < temp[j].Count; k++)
353                                 {
354                                     info.temp[1].Add(temp[j][k]);
355                                 }
356                                 if (temp[j].Count != 0)
357                                 {
358                                     info.features[1] = info.features[1] + (j + 1) + ",";
359                                 }
360                             }
361                             info.splitIndex = i;
362                         }
363                     }
364                     #endregion
365                     #region 連續變量
366                     else
367                     {
368                         double[] leftCunt = new double[classCount];   
369 
370           //作節點各個類別的數量
371                         double[] rightCount = new double[classCount]; 
372 
373           //右節點各個類別的數量
374                         double[] count1 = new double[classCount];     
375 
376           //子集1的統計量
377                         double[] count2 = new double
378 
379 [node.ClassCount.Length];   //子集2的統計量
380                         for (int j = 0; j < node.ClassCount.Length; 
381 
382 j++)
383                         {
384                             count2[j] = node.ClassCount[j];
385                         }
386                         int all1 = 0;                                 
387 
388           //子集1的樣本量
389                         int all2 = nums.Count;                        
390 
391           //子集2的樣本量
392                         double lastValue = 0;                         
393 
394          //上一個記錄的類別
395                         double currentValue = 0;                      
396 
397          //當前類別
398                         double lastPoint = 0;                         
399 
400           //上一個點的值
401                         double currentPoint = 0;                      
402 
403           //當前點的值
404                         double[] values = new double[nums.Count];
405                         for (int j = 0; j < values.Length; j++)
406                         {
407                             values[j] = allData[nums[j]][i];
408                         }
409                         QSort(values, nums, 0, nums.Count - 1);
410                         double lianxuMax = 1;                         
411 
412           //連續型屬性的最大熵
413                         #region 尋找最佳的分割點
414                         for (int j = 0; j < nums.Count - 1; j++)
415                         {
416                             currentValue = allData[nums[j]][lieshu - 
417 
418 1];
419                             currentPoint = (allData[nums[j]][i]);
420                             if (j == 0)
421                             {
422                                 lastValue = currentValue;
423                                 lastPoint = currentPoint;
424                             }
425                             if (currentValue != lastValue && 
426 
427 currentPoint != lastPoint)
428                             {
429                                 double shang1 = getGini(count1, 
430 
431 all1);
432                                 double shang2 = getGini(count2, 
433 
434 all2);
435                                 double allShang = shang1 * all1 / 
436 
437 (all1 + all2) + shang2 * all2 / (all1 + all2);
438                                 //allShang = (totalShang - allShang);
439                                 if (lianxuMax > allShang)
440                                 {
441                                     lianxuMax = allShang;
442                                     for (int k = 0; k < 
443 
444 count1.Length; k++)
445                                     {
446                                         leftCunt[k] = count1[k];
447                                         rightCount[k] = count2[k];
448                                     }
449                                     splitPoint = j;
450                                     splitValue = (currentPoint + 
451 
452 lastPoint) / 2;
453                                 }
454                             }
455                             all1++;
456                             count1[Convert.ToInt32(currentValue) - 
457 
458 1]++;
459                             count2[Convert.ToInt32(currentValue) - 
460 
461 1]--;
462                             all2--;
463                             lastValue = currentValue;
464                             lastPoint = currentPoint;
465                         }
466                         #endregion
467                         #region 若是超過了局部值,重設
468                         if (lianxuMax < jubuMax)
469                         {
470                             info.type = 1;
471                             info.splitIndex = i;
472                             info.features=new List<string>()
473 
474 {splitValue+""};
475                             //finalPoint = splitPoint;
476                             jubuMax = lianxuMax;
477                             info.temp[0] = new List<int>();
478                             info.temp[1] = new List<int>();
479                             for (int k = 0; k < splitPoint; k++)
480                             {
481                                 info.temp[0].Add(nums[k]);
482                             }
483                             for (int k = splitPoint; k < nums.Count; 
484 
485 k++)
486                             {
487                                 info.temp[1].Add(nums[k]);
488                             }
489                             info.class_Count[0] = new double
490 
491 [leftCunt.Length];
492                             info.class_Count[1] = new double
493 
494 [leftCunt.Length];
495                             for (int k = 0; k < leftCunt.Length; k++)
496                             {
497                                 info.class_Count[0][k] = leftCunt[k];
498                                 info.class_Count[1][k] = rightCount
499 
500 [k];
501                             }
502                         }
503                         #endregion
504                     }
505                     #endregion
506                 }
507                 #region 沒有尋找到最佳的分裂點,則設置爲葉節點
508                 if (info.splitIndex == -1)
509                 {
510                     double[] finalCount = node.ClassCount;
511                     double max = finalCount[0];
512                     int result = 1;
513                     for (int i = 1; i < finalCount.Length; i++)
514                     {
515                         if (finalCount[i] > max)
516                         {
517                             max = finalCount[i];
518                             result = (i + 1);
519                         }
520                     }
521                     node.feature_Type="result";
522                     node.features=new List<String> { "" + result };
523                     return node;
524                 }
525                 #endregion
526                 #region 分裂
527                 int deep = node.deep;
528                 node.SplitFeature = ("" + info.splitIndex);
529                 List<Node> childNode = new List<Node>();
530                 int[][] used = new int[2][];
531                 used[0] = new int[isUsed.Length];
532                 used[1] = new int[isUsed.Length];
533                 for (int i = 0; i < isUsed.Length; i++)
534                 {
535                     used[0][i] = isUsed[i];
536                     used[1][i] = isUsed[i];
537                 }
538                 if (info.type == 0)
539                 {
540                     used[0][info.splitIndex] = 1;
541                     node.feature_Type = ("離散");
542                 }
543                 else
544                 {
545                     //used[info.splitIndex] = 0;
546                     node.feature_Type = ("連續");
547                 }
548                 List<int>[] rowIndex = info.temp;
549                 List<String> features = info.features;
550                 Node node1 = new Node();
551                 Node node2 = new Node();
552                 node1.setClassCount(info.class_Count[0]);
553                 node2.setClassCount(info.class_Count[1]);
554                 node1.rowCount = info.temp[0].Count;
555                 node2.rowCount = info.temp[1].Count;
556                 node1.deep = deep + 1;
557                 node2.deep = deep + 1;
558                 node1 = findBestSplit(node1, info.temp[0],used[0]);
559                 node2 = findBestSplit(node2, info.temp[1], used[1]);
560                 node.leafNode_Count = (node1.leafNode_Count
561 
562 +node2.leafNode_Count);
563                 node.leafWrong = (node1.leafWrong+node2.leafWrong);
564                 node.features = (features);
565                 childNode.Add(node1);
566                 childNode.Add(node2);
567                 node.childNodes = childNode;
568                 #endregion
569                 return node;
570             }
571             catch (Exception e)
572             {
573                 Console.WriteLine(e.StackTrace);
574                 return node;
575             }
576         }
577         /// <summary>
578         /// GINI值
579         /// </summary>
580         /// <param name="counts"></param>
581         /// <param name="countAll"></param>
582         /// <returns></returns>
583         public static double getGini(double[] counts, int countAll)
584         {
585             double Gini = 1;
586             for (int i = 0; i < counts.Length; i++)
587             {
588                 Gini = Gini - Math.Pow(counts[i] / countAll, 2);
589             }
590             return Gini;
591         }
592         #region CCP剪枝
593         public static void getSeries(Node node)
594         {
595             Stack<Node> nodeStack = new Stack<Node>();
596             if (node != null)
597             {
598                 nodeStack.Push(node);
599             }
600             if (node.feature_Type == "result")
601                 return;
602             List<Node> childs = node.childNodes;
603             for (int i = 0; i < childs.Count; i++)
604             {
605                 getSeries(node);
606             }
607         }
608         /// <summary>
609         /// 遍歷剪枝
610         /// </summary>
611         /// <param name="node"></param>
612         public static Node getNode1(Node node, Node nodeCut)
613         {
614             
615             //List<Node> childNodes = node.getChild();
616             //double min = 100000;
617             ////Node nodeCut = new Node();
618             //double temp = 0;
619             //for (int i = 0; i < childNodes.Count; i++)
620             //{
621             //    if (childNodes[i].getType() != "result")
622             //    {
623             //        //if (!cutTree(childNodes[i]))
624             //        temp = min;
625             //        min = cutTree(childNodes[i], min);
626             //        if (min < temp)
627             //            nodeCut = childNodes[i];
628             //        getNode1(childNodes[i], nodeCut);
629             //    }
630             //}
631             //node.setChildNode(childNodes);
632             return null;
633         }
634         /// <summary>
635         /// 對每個節點剪枝
636         /// </summary>
637         public static double cutTree(Node node, double minA)
638         {
639             int rowCount = node.rowCount;
640             double leaf = node.getErrorCount();
641             double[] values = getError1(node, 0, 0);
642             double treeWrong = values[0];
643             double son = values[1];
644             double rate = (leaf - treeWrong) / (son - 1);
645             if (minA > rate)
646                 minA = rate;
647             //double var = Math.Sqrt(treeWrong * (1 - treeWrong / 
648 
649 rowCount));
650             //double panbie = treeWrong + var - leaf;
651             //if (panbie > 0)
652             //{
653             //    node.setFeatureType("result");
654             //    node.setChildNode(null);
655             //    int result = (node.getResult() + 1);
656             //    node.setFeatures(new List<String>() { "" + result 
657 
658 });
659             //    //return true;
660             //}
661             return minA;
662         }
663         /// <summary>
664         /// 得到子樹的錯誤個數
665         /// </summary>
666         /// <param name="node"></param>
667         /// <returns></returns>
668         public static double[] getError1(Node node, double treeError, 
669 
670 double son)
671         {
672             if (node.feature_Type == "result")
673             {
674 
675                 double error = node.getErrorCount();
676                 son++;
677                 return new double[] { treeError + error, son };
678             }
679             List<Node> childNode = node.childNodes;
680             for (int i = 0; i < childNode.Count; i++)
681             {
682                 double[] values = getError1(childNode[i], treeError, 
683 
684 son);
685                 treeError = values[0];
686                 son = values[1];
687             }
688             return new double[] { treeError, son };
689         }
690         #endregion
CART核心代碼

總結:

(1)CART是一棵二叉樹,每一次分裂會產生兩個子節點,對於連續性的數據,直接採用與C4.5類似的處理方法,對於離散型數據,選擇最優的兩種離散值組合方法。

(2)CART既能是分類數,又能是二叉樹。若是是分類樹,將選擇可以最小化分裂後節點GINI值的分裂屬性;若是是迴歸樹,選擇可以最小化兩個節點樣本方差的分裂屬性。

(3)CART跟C4.5同樣,須要進行剪枝,採用CCP(代價複雜度的剪枝方法)。

相關文章
相關標籤/搜索