Decision Tree Series (4): C4.5

Prerequisite: Decision tree ID3

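As described in the previous article, ID3 has several main drawbacks. First, it splits data by information gain, which is less accurate than the information gain ratio. Second, it cannot handle continuous data directly; continuous data must be discretized first. Third, it has no pruning strategy, so the tree can become overly complex and overfit.

C4.5 improves on ID3 in these three respects:

a)  C4.5 uses the information gain ratio as the criterion when splitting a node;

b)  it can handle continuous data;

c)  it applies a pruning strategy to the fully grown tree, which reduces the impact of overfitting to some extent.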

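1. Using the information gain ratio as the splitting criterion

The information gain ratio is computed as:

$$\mathrm{GainRatio} = \frac{\mathrm{Gain}}{\mathrm{SplitInfo}}$$

where $\mathrm{Gain}$ is the information gain and $\mathrm{SplitInfo}$ is the information carried by the split itself, measured over the data volumes of the child nodes:

$$\mathrm{SplitInfo} = -\sum_{i=1}^{m} \frac{N_i}{N} \log_2 \frac{N_i}{N}$$

where $m$ is the number of child nodes, $N_i$ is the number of records in the $i$-th child node, and $N$ is the number of records in the parent node. Put plainly, it is the entropy of the split itself.

The larger the information gain ratio, the better the split.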
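The two formulas above can be sketched in a few lines of C#. This is only an illustration, not the author's implementation: the class name `GainRatioSketch`, its method names, and the per-child class counts in `Main` are all hypothetical and chosen purely to exercise the formulas.

```csharp
using System;
using System.Linq;

static class GainRatioSketch
{
    // Shannon entropy (base 2) of a vector of counts.
    public static double Entropy(double[] counts)
    {
        double total = counts.Sum();
        double h = 0;
        foreach (double c in counts)
        {
            if (c <= 0) continue;          // 0 * log(0) is treated as 0
            double p = c / total;
            h -= p * Math.Log(p, 2);
        }
        return h;
    }

    // children[i][k] = number of records of class k that fall into child i.
    public static double GainRatio(double[][] children)
    {
        int classes = children[0].Length;
        double[] parent = new double[classes];
        foreach (double[] child in children)
            for (int k = 0; k < classes; k++)
                parent[k] += child[k];
        double n = parent.Sum();

        // Weighted entropy of the children after the split.
        double conditional = 0;
        foreach (double[] child in children)
            conditional += child.Sum() / n * Entropy(child);

        double gain = Entropy(parent) - conditional;
        // SplitInfo is the entropy of the child sizes N_i / N.
        double splitInfo = Entropy(children.Select(c => c.Sum()).ToArray());
        return splitInfo == 0 ? 0 : gain / splitInfo;
    }

    static void Main()
    {
        // Hypothetical {yes, no} counts for three child nodes of four records each.
        double[][] children =
        {
            new double[] { 1, 3 },
            new double[] { 4, 0 },
            new double[] { 0, 4 },
        };
        Console.WriteLine(GainRatio(children));
    }
}
```

Returning 0 when the split information is 0 (all records in one child) is a design choice of this sketch; such a split has zero gain anyway.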

A concrete example shows how C4.5 chooses the splitting attribute by information gain ratio:

Table 1: Raw data

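Weather    Temperature  Humidity  Day      Go shopping
Sunny      25           50        Weekday
Overcast   21           48        Weekday
Overcast   18           70        Weekend
Sunny      28           41        Weekend
Rainy      8            65        Weekday
Sunny      18           43        Weekday
Overcast   24           56        Weekend
Rainy      18           76        Weekend
Rainy      31           61        Weekend
Overcast   6            43        Weekend
Sunny      15           55        Weekday
Rainy      4            58        Weekday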

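Take the weather attribute as an example:

It has three values (sunny, overcast and rainy), so the node splits into three child nodes.

Applying the formulas above to this split, the weather attribute yields an information gain ratio of 0.44.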

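2. Handling continuous attributes

C4.5 handles discrete attributes the same way as ID3 and adds support for continuous attributes. The records are first sorted by the continuous attribute, and the sorted data is then cut into two parts at a single point.

How is the cut point chosen? Simply compute the information gain for every candidate cut point and pick the one that gives the best split. Take temperature as an example:

As the figure shows, N records have N-1 candidate cut points in theory. Finding the best one means computing the information gain for every cut, which is fairly expensive. Can this be simplified? Yes: some cut points can obviously be ruled out. Take the 2nd and 3rd records on the right side of the figure: both have the class label (go shopping) "yes". Cutting between them would separate two records of the same class, which cannot be better than a cut that keeps them together. The computation can therefore be simplified by dropping every cut point whose two neighbouring records share the same class label, as the next figure shows:

As the figure shows, the number of cut points drops from the original 11 to 6, which reduces the computational cost.

Once the candidate cut points are fixed, the next step is to choose the best one. Note that for a continuous attribute this internal choice uses information gain, not the gain ratio: the gain ratio tends to favour the cut point whose two resulting nodes differ most in size, and information gain is used to avoid that. After the best cut point is chosen, its information gain ratio is computed and compared with the other attributes to decide the best splitting attribute.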
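The cut-point search described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the author's code: `ContinuousSplitSketch`, `BestThreshold` and the sample data are hypothetical, class labels are assumed to be 0-based integers, and the sketch additionally skips cuts between records with equal attribute values (a simplification not discussed in the text).

```csharp
using System;
using System.Linq;

static class ContinuousSplitSketch
{
    static double Entropy(int[] counts)
    {
        int total = counts.Sum();
        double h = 0;
        foreach (int c in counts)
        {
            if (c == 0) continue;
            double p = (double)c / total;
            h -= p * Math.Log(p, 2);
        }
        return h;
    }

    // Returns the threshold with the highest information gain.
    // candidates receives the number of cut points that were actually evaluated.
    public static double BestThreshold(double[] values, int[] labels, int classCount, out int candidates)
    {
        int n = values.Length;
        // Row indices sorted by attribute value.
        int[] order = Enumerable.Range(0, n).OrderBy(i => values[i]).ToArray();
        int[] left = new int[classCount];
        int[] right = new int[classCount];
        foreach (int y in labels) right[y]++;
        double parentEntropy = Entropy(right);

        double bestGain = -1;
        double bestThreshold = double.NaN;
        candidates = 0;
        for (int j = 1; j < n; j++)
        {
            int prev = order[j - 1], cur = order[j];
            // Move the previous record from the right part to the left part.
            left[labels[prev]]++;
            right[labels[prev]]--;
            // Skip cuts between neighbours with the same label or the same value.
            if (labels[prev] == labels[cur] || values[prev] == values[cur]) continue;
            candidates++;
            double gain = parentEntropy - (j * Entropy(left) + (n - j) * Entropy(right)) / n;
            if (gain > bestGain)
            {
                bestGain = gain;
                bestThreshold = (values[prev] + values[cur]) / 2;   // midpoint of the two neighbours
            }
        }
        return bestThreshold;
    }

    static void Main()
    {
        // Hypothetical attribute values and class labels.
        double[] values = { 10, 20, 30, 40, 50, 60 };
        int[] labels = { 0, 0, 1, 1, 1, 0 };
        double t = BestThreshold(values, labels, 2, out int candidates);
        Console.WriteLine($"threshold={t}, evaluated {candidates} of {values.Length - 1} cut points");
    }
}
```

With this sample only the two cuts where the label changes are evaluated, and the threshold is taken as the midpoint between the two neighbouring values, as the author's listing also does.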

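3. Pruning

As mentioned earlier, pruning starts from a fully grown decision tree and trims the subtrees that classify poorly after growth, which reduces the complexity of the tree and the impact of overfitting.

C4.5 uses pessimistic error pruning (PEP). PEP prunes a subtree when doing so does not hurt the accuracy of the tree. What counts as not hurting? If the error after pruning is smaller than the upper bound of the error before pruning, the pruned tree is considered at least as good, and the subtree should be pruned.

The condition that must hold for pruning:

$$E_{leaf} < E_{subtree} + S.E.(E_{subtree})$$

where:

$E_{subtree}$ is the error of the subtree;

$E_{leaf}$ is the error of the node when the subtree is replaced by a single leaf.

Assume the subtree's error follows a binomial distribution. By the properties of the binomial distribution, $E_{subtree} = N e$ and $S.E.(E_{subtree}) = \sqrt{N e (1 - e)}$, where $e = \left(\sum_{i=1}^{L} E_i + 0.5 L\right) / N$, $L$ is the number of leaves in the subtree, $E_i$ is the number of misclassified records at the $i$-th leaf, and $N$ is the number of records in the subtree. Likewise, the leaf error is $E_{leaf} = E + 0.5$, where $E$ is the number of records the single leaf would misclassify.

In the formulas above, 0.5 is a correction factor. Splitting a parent node always classifies the training data at least as well as the parent alone, so in theory the parent's error is never smaller than the total error of its children, and a correction is needed. With 0.5 added to every node, the children's total error includes their correction terms and is no longer guaranteed to be lower than the parent's.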
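The pruning test can be sketched as a small function. This is an illustration under assumptions rather than the author's method: `PepSketch`, `ShouldPrune` and the numbers in `Main` are hypothetical, and the strict inequality follows the wording "smaller than the upper bound" used above.

```csharp
using System;

static class PepSketch
{
    // n          : number of records reaching the node
    // nodeErrors : records misclassified if the node is collapsed into one leaf
    // leafErrors : total records misclassified by the leaves of the subtree
    // leafCount  : number of leaves in the subtree
    public static bool ShouldPrune(int n, int nodeErrors, int leafErrors, int leafCount)
    {
        double eLeaf = nodeErrors + 0.5;                  // one leaf, one 0.5 correction
        double eSubtree = leafErrors + 0.5 * leafCount;   // 0.5 correction per leaf
        // Standard deviation of the subtree error under the binomial assumption.
        double se = Math.Sqrt(eSubtree * (1 - eSubtree / n));
        return eLeaf < eSubtree + se;
    }

    static void Main()
    {
        // Hypothetical node: 20 records, 7 errors as a leaf, 5 errors over 3 leaves.
        Console.WriteLine(ShouldPrune(20, 7, 5, 3));
        // Hypothetical node where the subtree is clearly better than a single leaf.
        Console.WriteLine(ShouldPrune(100, 30, 5, 2));
    }
}
```

In the first call the corrected leaf error (7.5) is below the subtree's corrected error plus one standard deviation (about 8.6), so the sketch reports that pruning is warranted; in the second call it is not.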

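Worked example (the figure and its intermediate values did not survive extraction):

Since the pruning condition above is satisfied, the subtree should be pruned.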

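Program design and source code (C#)

Design process

(1) Data format

The raw data is converted to numeric form and stored as a two-dimensional array. Each row is one record; the first n-1 columns are attributes and the last column is the class label.

For example, the data in Table 1 can be converted into Table 2:

Table 2: Data after initialization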

Weather  Temperature  Humidity  Season  Tomorrow's weather
1        25           50        1       1
2        21           48        1       2
2        18           70        1       3
1        28           41        2       1
3        8            65        3       2
1        18           43        2       1
2        24           56        4       1
3        18           76        4       2
3        31           61        2       1
2        6            43        3       3
1        15           55        4       2
3        4            58        3       3

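Here, for the "Weather" attribute, the numbers {1,2,3} stand for {sunny, overcast, rainy}; for the "Season" attribute, {1,2,3,4} stand for {spring, summer, winter, autumn}; for the class label "Tomorrow's weather", {1,2,3} stand for {sunny, overcast, rainy}.

The code is as follows:

    static double[][] allData;                 // training data
    static List<String>[] featureValues;       // discrete values of each discrete attribute

featureValues is an array of lists. The array length is the number of attributes, and each element is the list of discrete values of that attribute.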
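As a rough illustration of this numeric encoding, the sketch below assigns 1-based codes to category strings in order of first appearance and builds rows of a `double[][]`. The helper `Encode`, the class `EncodingSketch` and the three raw records are hypothetical; the original program's loading code is not shown in the article.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;

static class EncodingSketch
{
    // Returns the 1-based code of value, registering it in known on first sight.
    public static int Encode(List<string> known, string value)
    {
        int index = known.IndexOf(value);
        if (index < 0)
        {
            known.Add(value);
            index = known.Count - 1;
        }
        return index + 1;
    }

    static void Main()
    {
        // Hypothetical raw records: weather, temperature, class label.
        string[][] raw =
        {
            new[] { "sunny", "25", "no" },
            new[] { "overcast", "21", "yes" },
            new[] { "sunny", "28", "yes" },
        };
        var weatherValues = new List<string>();   // plays the role of one featureValues entry
        var labelValues = new List<string>();
        double[][] data = new double[raw.Length][];
        for (int r = 0; r < raw.Length; r++)
        {
            data[r] = new double[]
            {
                Encode(weatherValues, raw[r][0]),
                double.Parse(raw[r][1], CultureInfo.InvariantCulture),
                Encode(labelValues, raw[r][2]),
            };
            Console.WriteLine(string.Join(" ", data[r]));
        }
    }
}
```

The 1-based codes match the way the listings in this article index per-value arrays with `index - 1`.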

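(2) Two classes: the node class and the split-information class

a) Node class Node

This class represents a node. Its properties include the splitting attribute chosen at the node, the node's output class, its child nodes, its depth and so on. Note that, compared with ID3, two properties are added: leafWrong and leafNode_Count, the total classification error of the leaves and the number of leaves. They mainly exist to make pruning easier.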

    class Node
    {
        /// <summary>
        /// Attribute values corresponding to each child node
        /// </summary>
        public List<String> features { get; set; }
        /// <summary>
        /// Node type: "離散" (discrete split), "連續" (continuous split) or "result" (leaf)
        /// </summary>
        public String feature_Type { get; set; }
        /// <summary>
        /// Column index of the splitting attribute
        /// </summary>
        public String SplitFeature { get; set; }
        /// <summary>
        /// Record count of each class
        /// </summary>
        public double[] ClassCount { get; set; }
        /// <summary>
        /// Number of records
        /// </summary>
        public int rowCount { get; set; }
        /// <summary>
        /// Child nodes
        /// </summary>
        public List<Node> childNodes { get; set; }
        /// <summary>
        /// Parent node
        /// </summary>
        public Node Parent { get; set; }
        /// <summary>
        /// Class with the largest share at this node
        /// </summary>
        public String finalResult { get; set; }
        /// <summary>
        /// Depth of the node in the tree
        /// </summary>
        public int deep { get; set; }
        /// <summary>
        /// Index of the class with the largest share at this node
        /// </summary>
        public int result { get; set; }
        /// <summary>
        /// Number of misclassified records over the leaves below this node
        /// </summary>
        public int leafWrong { get; set; }
        /// <summary>
        /// Number of leaves below this node
        /// </summary>
        public int leafNode_Count { get; set; }

        // Records that would be misclassified if this node predicted its majority class
        public double getErrorCount()
        {
            return rowCount - ClassCount[result];
        }
        #region
        // Stores the class counts and records the index of the majority class
        public void setClassCount(double[] count)
        {
            this.ClassCount = count;
            double max = ClassCount[0];
            int result = 0;
            for (int i = 1; i < ClassCount.Length; i++)
            {
                if (max < ClassCount[i])
                {
                    max = ClassCount[i];
                    result = i;
                }
            }
            this.result = result;
        }
        #endregion
    }

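b) The split-information class stores the information needed to split a node: the row indices of each child node, the per-class counts of each child node, the attribute used for the split, the type of that attribute and so on.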

    class SplitInfo
    {
        /// <summary>
        /// Index of the splitting attribute
        /// </summary>
        public int splitIndex { get; set; }
        /// <summary>
        /// Data type of the attribute (0: discrete, 1: continuous)
        /// </summary>
        public int type { get; set; }
        /// <summary>
        /// Values of the splitting attribute
        /// </summary>
        public List<String> features { get; set; }
        /// <summary>
        /// Row-index list of each child node
        /// </summary>
        public List<int>[] temp { get; set; }
        /// <summary>
        /// Per-class record counts of each child node
        /// </summary>
        public double[][] class_Count { get; set; }
    }

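The main method, findBestSplit(Node node, List<int> nums, int[] isUsed), splits a node.

Here:

node is the node about to be split;

nums is the list of row indices of the node's data;

isUsed records which attributes have already been used on the way down to this node.

findBestSplit consists of the following parts:

1) Deciding when to stop splitting

The stopping conditions are as described earlier. The source code is as follows: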

    public static bool ifEnd(Node node, double entropy, int[] isUsed)
    {
        try
        {
            double[] count = node.ClassCount;
            int rowCount = node.rowCount;
            int maxResult = 0;
            #region tree has reached the maximum depth
            int deep = node.deep;
            if (deep >= maxDeep)
            {
                maxResult = node.result + 1;
                node.feature_Type=("result");
                node.features=(new List<String>() { maxResult + "" });
                node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
                node.leafNode_Count = 1;
                return true;
            }
            #endregion
            #region purity (overlaps with the checks below; to be revised)
            //maxResult = 1;
            //for (int i = 1; i < count.Length; i++)
            //{
            //    if (count[i] / rowCount >= 0.95)
            //    {
            //        node.feature_Type=("result");
            //        node.features=(new List<String> { "" + (i + 1) });
            //        node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
            //        node.leafNode_Count = 1;
            //        return true;
            //    }
            //}
            #endregion
            #region entropy is 0
            if (entropy == 0)
            {
                maxResult = node.result+1;
                node.feature_Type=("result");
                node.features=(new List<String> { maxResult + "" });
                node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
                node.leafNode_Count = 1;
                return true;
            }
            #endregion
            #region all attributes have been used
            bool flag = true;
            for (int i = 0; i < isUsed.Length - 1; i++)
            {
                if (isUsed[i] == 0)
                {
                    flag = false;
                    break;
                }
            }
            if (flag)
            {
                maxResult = node.result+1;
                node.feature_Type=("result");
                node.features=(new List<String> { "" + (maxResult) });
                node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
                node.leafNode_Count = 1;
                return true;
            }
            #endregion
            #region fewer records than Limit_Node
            if (rowCount < Limit_Node)
            {
                maxResult = node.result+1;
                node.feature_Type=("result");
                node.features=(new List<String> { "" + (maxResult) });
                node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
                node.leafNode_Count = 1;
                return true;
            }
            #endregion
            return false;
        }
        catch (Exception e)
        {
            return false;
        }
    }

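2) Finding the best splitting attribute

Finding the best splitting attribute requires the information gain ratio of every candidate attribute after the split. The formula is given above; the entropy is computed by the following code: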

    // Returns the sum of p * log2(p) over the classes, i.e. the entropy with its sign flipped.
    // Callers negate the result where the actual entropy is needed.
    public static double CalEntropy(double[] counts, int countAll)
    {
        try
        {
            double allShang = 0;
            for (int i = 0; i < counts.Length; i++)
            {
                if (counts[i] == 0)
                {
                    continue;                   // 0 * log(0) is treated as 0
                }
                double rate = counts[i] / countAll;
                allShang = allShang + rate * Math.Log(rate, 2);
            }
            return allShang;
        }
        catch (Exception e)
        {
            return 0;
        }
    }
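One detail is worth spelling out: as written, CalEntropy accumulates $p \log_2 p$ without a leading minus sign, so it returns the negated entropy. The short derivation below shows how the information gain can still be formed from such values. The symbols $S$, $H$, $p_k$ and $D_i$ are notation introduced here for illustration and do not appear in the original.

$$S(D) = \sum_{k} p_k \log_2 p_k = -H(D)$$

$$\mathrm{Gain} = H(D) - \sum_{i=1}^{m} \frac{N_i}{N} H(D_i) = -S(D) + \sum_{i=1}^{m} \frac{N_i}{N} S(D_i)$$

So adding the weighted child values to the negated parent value gives the usual non-negative gain, and a node is pure exactly when $S(D) = 0$.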

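3) Performing the split and processing the child nodes recursively

This is simply a recursive process: findBestSplit is called on every child node to split it in turn.

Source code of findBestSplit: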

    // type, allNum, classCount, lieshu (column count), maxDeep and Limit_Node are static fields
    // of the full program that are not shown in this listing.
    public static Node findBestSplit(Node node, List<int> nums, int[] isUsed)
    {
        try
        {
            // Decide whether to keep splitting
            double totalShang = CalEntropy(node.ClassCount, node.rowCount);
            if (ifEnd(node, totalShang, isUsed))
            {
                return node;
            }
            #region variable declarations
            SplitInfo info = new SplitInfo();
            info.splitIndex = -1;                       // -1 marks "no usable split found yet"
            int RowCount = nums.Count;                  // total number of samples
            double jubuMax = 0;                         // best gain ratio found so far
            #endregion
            for (int i = 0; i < isUsed.Length - 1; i++)
            {
                if (isUsed[i] == 1)
                {
                    continue;
                }
                #region discrete attribute
                if (type[i] == 0)
                {
                    int[] allFeatureCount = new int[0];         // record count per attribute value
                    double[][] allCount = new double[allNum[i]][];
                    for (int j = 0; j < allCount.Length; j++)
                    {
                        allCount[j] = new double[classCount];
                    }
                    int[] countAllFeature = new int[allNum[i]];
                    List<int>[] temp = new List<int>[allNum[i]];
                    for (int j = 0; j < temp.Length; j++)
                    {
                        temp[j] = new List<int>();
                    }
                    for (int j = 0; j < nums.Count; j++)
                    {
                        int index = Convert.ToInt32(allData[nums[j]][i]);
                        temp[index - 1].Add(nums[j]);
                        countAllFeature[index - 1]++;
                        allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
                    }
                    double allShang = 0;
                    double chushu = 0;                          // negated split information (the divisor)
                    for (int j = 0; j < allCount.Length; j++)
                    {
                        allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount;
                        if (countAllFeature[j] > 0)
                        {
                            double rate = countAllFeature[j] / Convert.ToDouble(RowCount);
                            chushu = chushu + rate * Math.Log(rate, 2);
                        }
                    }
                    // CalEntropy returns negated entropy, so this is the information gain
                    allShang = (-totalShang + allShang);
                    // Divide by the split information to obtain the gain ratio
                    if (chushu != 0)
                    {
                        allShang = allShang / -chushu;
                    }
                    if (allShang > jubuMax)
                    {
                        info.features = new List<string>();
                        info.type = 0;
                        info.temp = temp;
                        info.splitIndex = i;
                        info.class_Count = allCount;
                        jubuMax = allShang;
                        allFeatureCount = countAllFeature;
                    }
                }
                #endregion
                #region continuous attribute
                else
                {
                    double[] leftCount = new double[classCount];            // per-class counts of the left node
                    double[] rightCount = new double[classCount];           // per-class counts of the right node
                    double[] count1 = new double[classCount];               // running counts of subset 1
                    double[] count2 = new double[node.ClassCount.Length];   // running counts of subset 2
                    for (int j = 0; j < node.ClassCount.Length; j++)
                    {
                        count2[j] = node.ClassCount[j];
                    }
                    int all1 = 0;                                           // size of subset 1
                    int all2 = nums.Count;                                  // size of subset 2
                    double lastValue = 0;                                   // class of the previous record
                    double currentValue = 0;                                // class of the current record
                    double lastPoint = 0;                                   // attribute value of the previous record
                    double currentPoint = 0;                                // attribute value of the current record
                    int splitPoint = 0;
                    double splitValue = 0;
                    double[] values = new double[nums.Count];
                    for (int j = 0; j < values.Length; j++)
                    {
                        values[j] = allData[nums[j]][i];
                    }
                    // Sort the row indices in nums by the value of attribute i
                    QSort(values, nums, 0, nums.Count - 1);
                    double chushu = 0;
                    double lianxuMax = 0;                       // best information gain for this attribute
                    // Visit every record so that the cut before the last record is also considered
                    for (int j = 0; j < nums.Count; j++)
                    {
                        currentValue = allData[nums[j]][lieshu - 1];
                        currentPoint = allData[nums[j]][i];
                        if (j == 0)
                        {
                            lastValue = currentValue;
                            lastPoint = currentPoint;
                        }
                        // Only evaluate cut points where the class label changes
                        if (currentValue != lastValue)
                        {
                            double shang1 = CalEntropy(count1, all1);
                            double shang2 = CalEntropy(count2, all2);
                            double allShang = shang1 * all1 / (all1 + all2) + shang2 * all2 / (all1 + all2);
                            allShang = (-totalShang + allShang);
                            if (lianxuMax < allShang)
                            {
                                lianxuMax = allShang;
                                for (int k = 0; k < count1.Length; k++)
                                {
                                    leftCount[k] = count1[k];
                                    rightCount[k] = count2[k];
                                }
                                splitPoint = j;
                                splitValue = (currentPoint + lastPoint) / 2;
                            }
                        }
                        all1++;
                        count1[Convert.ToInt32(currentValue) - 1]++;
                        count2[Convert.ToInt32(currentValue) - 1]--;
                        all2--;
                        lastValue = currentValue;
                        lastPoint = currentPoint;
                    }
                    // Split information of the best cut: splitPoint records go left, the rest go right
                    double rate1 = splitPoint / Convert.ToDouble(nums.Count);
                    chushu = 0;
                    if (rate1 > 0)
                    {
                        chushu = chushu + rate1 * Math.Log(rate1, 2);
                    }
                    double rate2 = 1 - rate1;
                    if (rate2 > 0)
                    {
                        chushu = chushu + rate2 * Math.Log(rate2, 2);
                    }
                    // The cut was chosen by information gain; convert it to a gain ratio
                    // before comparing with the other attributes
                    if (chushu != 0)
                    {
                        lianxuMax = lianxuMax / -chushu;
                    }
                    if (lianxuMax > jubuMax)
                    {
                        info.splitIndex=(i);
                        info.features = (new List<String> { splitValue + "" });
                        info.type=(1);
                        jubuMax = lianxuMax;
                        List<int>[] allInt = new List<int>[2];
                        allInt[0] = new List<int>();
                        allInt[1] = new List<int>();
                        for (int k = 0; k < splitPoint; k++)
                        {
                            allInt[0].Add(nums[k]);
                        }
                        for (int k = splitPoint; k < nums.Count; k++)
                        {
                            allInt[1].Add(nums[k]);
                        }
                        info.temp=(allInt);
                        double[][] alls = new double[2][];
                        alls[0] = new double[leftCount.Length];
                        alls[1] = new double[leftCount.Length];
                        for (int k = 0; k < leftCount.Length; k++)
                        {
                            alls[0][k] = leftCount[k];
                            alls[1][k] = rightCount[k];
                        }
                        info.class_Count=(alls);
                    }
                }
                #endregion
            }
            #region if no usable split was found, make the node a leaf
            if (info.splitIndex == -1)
            {
                double[] finalCount = node.ClassCount;
                double max = finalCount[0];
                int result = 1;
                for (int i = 1; i < finalCount.Length; i++)
                {
                    if (finalCount[i] > max)
                    {
                        max = finalCount[i];
                        result = (i + 1);
                    }
                }
                node.feature_Type=("result");
                node.features=(new List<String> { "" + result });
                // Keep the leaf statistics consistent with the leaves created in ifEnd
                node.leafWrong = node.rowCount - Convert.ToInt32(finalCount[result - 1]);
                node.leafNode_Count = 1;
                return node;
            }
            #endregion
            #region split
            int deep = node.deep;
            node.SplitFeature=("" + info.splitIndex);

            List<Node> childNode = new List<Node>();
            int[] used = new int[isUsed.Length];
            for (int i = 0; i < used.Length; i++)
            {
                used[i] = isUsed[i];
            }
            if (info.type == 0)
            {
                used[info.splitIndex] = 1;          // a discrete attribute is used only once per path
                node.feature_Type=("離散");
            }
            else
            {
                used[info.splitIndex] = 0;          // a continuous attribute may be split again
                node.feature_Type=("連續");
            }
            int sumLeaf = 0;
            int sumWrong = 0;
            List<int>[] rowIndex = info.temp;
            List<String> features = info.features;
            for (int j = 0; j < rowIndex.Length; j++)
            {
                if (rowIndex[j].Count == 0)
                {
                    continue;
                }
                if (info.type == 0)
                    features.Add("" + (j + 1));
                Node node1 = new Node();
                node1.setClassCount(info.class_Count[j]);
                node1.deep=(deep + 1);
                node1.rowCount = info.temp[j].Count;
                node1 = findBestSplit(node1, info.temp[j], used);
                sumLeaf += node1.leafNode_Count;
                sumWrong += node1.leafWrong;
                childNode.Add(node1);
            }
            node.leafNode_Count = (sumLeaf);
            node.leafWrong = (sumWrong);
            node.features=(features);
            node.childNodes=(childNode);
            #endregion
            return node;
        }
        catch (Exception e)
        {
            Console.WriteLine(e.StackTrace);
            return node;
        }
    }

(4) Pruning

Pessimistic error pruning (PEP):

    public static void prune(Node node)
    {
        if (node.feature_Type == "result")
            return;
        // Error if this node were collapsed into a single leaf, plus the 0.5 correction
        double treeWrong = node.getErrorCount() + 0.5;
        // Total error over the leaves of the subtree, with 0.5 added per leaf
        double leafError = node.leafWrong + 0.5 * node.leafNode_Count;
        // Standard deviation of the subtree error under the binomial assumption
        double var = Math.Sqrt(leafError * (1 - leafError / node.rowCount));
        double panbie = leafError + var - treeWrong;
        if (panbie > 0)
        {
            // Replace the subtree with a leaf that predicts the majority class
            node.feature_Type=("result");
            node.childNodes=(null);
            int result = (node.result + 1);
            node.features=(new List<String>() { "" + result });
        }
        else
        {
            List<Node> childNodes = node.childNodes;
            for (int i = 0; i < childNodes.Count; i++)
            {
                prune(childNodes[i]);
            }
        }
    }

Full source code of the C4.5 core algorithm:

    #region C4.5 core algorithm
    /// <summary>
    /// Classifies one record by walking down the tree
    /// </summary>
    /// <param name="node"></param>
    /// <param name="data"></param>
    public static String findResult(Node node, String[] data)
    {
        List<String> featrues = node.features;
        String type = node.feature_Type;
        if (type == "result")
        {
            return featrues[0];
        }
        int split = Convert.ToInt32(node.SplitFeature);
        List<Node> childNodes = node.childNodes;
        double[] resultCount = node.ClassCount;
        if (type == "連續")
        {
            // Continuous split: features[0] holds the threshold
            double value = Convert.ToDouble(featrues[0]);
            if (Convert.ToDouble(data[split]) <= value)
            {
                return findResult(childNodes[0], data);
            }
            else
            {
                return findResult(childNodes[1], data);
            }
        }
        else
        {
            // Discrete split: follow the child whose value matches the record
            for (int i = 0; i < featrues.Count; i++)
            {
                if (data[split] == featrues[i])
                {
                    return findResult(childNodes[i], data);
                }
                if (i == featrues.Count - 1)
                {
                    // No child matches the value: fall back to the first child
                    double count = resultCount[0];
                    int maxInt = 0;
                    for (int j = 1; j < resultCount.Length; j++)
                    {
                        if (count < resultCount[j])
                        {
                            count = resultCount[j];
                            maxInt = j;
                        }
                    }
                    return findResult(childNodes[0], data);
                }
            }
        }
        return null;
    }
    /// <summary>
    /// Decides whether further splitting is needed
    /// </summary>
    /// <param name="node"></param>
    /// <returns></returns>
    public static bool ifEnd(Node node, double entropy, int[] isUsed)
    {
        try
        {
            double[] count = node.ClassCount;
            int rowCount = node.rowCount;
            int maxResult = 0;
            #region tree has reached the maximum depth
            int deep = node.deep;
            if (deep >= maxDeep)
            {
                maxResult = node.result + 1;
                node.feature_Type=("result");
                node.features=(new List<String>() { maxResult + "" });
                node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
                node.leafNode_Count = 1;
                return true;
            }
            #endregion
            #region purity (overlaps with the checks below; to be revised)
            //maxResult = 1;
            //for (int i = 1; i < count.Length; i++)
            //{
            //    if (count[i] / rowCount >= 0.95)
            //    {
            //        node.feature_Type=("result");
            //        node.features=(new List<String> { "" + (i + 1) });
            //        node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
            //        node.leafNode_Count = 1;
            //        return true;
            //    }
            //}
            #endregion
            #region entropy is 0
            if (entropy == 0)
            {
                maxResult = node.result+1;
                node.feature_Type=("result");
                node.features=(new List<String> { maxResult + "" });
                node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
                node.leafNode_Count = 1;
                return true;
            }
            #endregion
            #region all attributes have been used
            bool flag = true;
            for (int i = 0; i < isUsed.Length - 1; i++)
            {
                if (isUsed[i] == 0)
                {
                    flag = false;
                    break;
                }
            }
            if (flag)
            {
                maxResult = node.result+1;
                node.feature_Type=("result");
                node.features=(new List<String> { "" + (maxResult) });
                node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
                node.leafNode_Count = 1;
                return true;
            }
            #endregion
            #region fewer records than Limit_Node
            if (rowCount < Limit_Node)
            {
                maxResult = node.result+1;
                node.feature_Type=("result");
                node.features=(new List<String> { "" + (maxResult) });
                node.leafWrong=(rowCount - Convert.ToInt32(count[maxResult - 1]));
                node.leafNode_Count = 1;
                return true;
            }
            #endregion
            return false;
        }
        catch (Exception e)
        {
            return false;
        }
    }
    #region sorting
    // Insertion sort of values[StartIndex..endIndex]; arr (row indices) is permuted alongside
    public static void InsertSort(double[] values, List<int> arr, int StartIndex, int endIndex)
    {
        for (int i = StartIndex + 1; i <= endIndex; i++)
        {
            int key = arr[i];
            double init = values[i];
            int j = i - 1;
            while (j >= StartIndex && values[j] > init)
            {
                arr[j + 1] = arr[j];
                values[j + 1] = values[j];
                j--;
            }
            arr[j + 1] = key;
            values[j + 1] = init;
        }
    }
    static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
    {
        int mid = low + ((high - low) >> 1);    // index of the middle element

        // Choose the pivot by median-of-three
        if (values[mid] > values[high])         // goal: values[mid] <= values[high]
        {
            swap(values, arr, mid, high);
        }
        if (values[low] > values[high])         // goal: values[low] <= values[high]
        {
            swap(values, arr, low, high);
        }
        if (values[mid] > values[low])          // goal: values[low] >= values[mid]
        {
            swap(values, arr, mid, low);
        }
        // Now values[mid] <= values[low] <= values[high]
        return low;
        // Position low holds the median of the three,
        // so partitioning can use the element at low as the pivot without changing the partition code
    }
    // Swaps positions t1 and t2 in both values and arr
    static void swap(double[] values, List<int> arr, int t1, int t2)
    {
        double temp = values[t1];
        values[t1] = values[t2];
        values[t2] = temp;
        int key = arr[t1];
        arr[t1] = arr[t2];
        arr[t2] = key;
    }
        static void QSort(double[] values, List<int> arr, int low, int high)
        {
            int first = low;
            int last = high;

            int left = low;
            int right = high;

            int leftLen = 0;
            int rightLen = 0;

            // Small ranges are handled by insertion sort
            if (high - low + 1 < 10)
            {
                InsertSort(values, arr, low, high);
                return;
            }

            // One partition pass
            int key = SelectPivotMedianOfThree(values, arr, low, high); // median-of-three pivot
            double inti = values[key];
            int currentKey = arr[key];

            while (low < high)
            {
                while (high > low && values[high] >= inti)
                {
                    if (values[high] == inti) // park elements equal to the pivot at the right end
                    {
                        swap(values, arr, right, high);
                        right--;
                        rightLen++;
                    }
                    high--;
                }
                arr[low] = arr[high];
                values[low] = values[high];
                while (high > low && values[low] <= inti)
                {
                    if (values[low] == inti) // park elements equal to the pivot at the left end
                    {
                        swap(values, arr, left, low);
                        left++;
                        leftLen++;
                    }
                    low++;
                }
                arr[high] = arr[low];
                values[high] = values[low];
            }
            arr[low] = currentKey;
            // Write back the saved pivot value; values[key] may have been overwritten during the pass
            values[low] = inti;
            // Partition pass finished.
            // Move the elements equal to the pivot next to the pivot's final position.
            int i = low - 1;
            int j = first;
            while (j < left && values[i] != inti)
            {
                swap(values, arr, i, j);
                i--;
                j++;
            }
            i = low + 1;
            j = last;
            while (j > right && values[i] != inti)
            {
                swap(values, arr, i, j);
                i++;
                j--;
            }
            QSort(values, arr, first, low - 1 - leftLen);
            QSort(values, arr, low + 1 + rightLen, last);
        }
        #endregion
        /// <summary>
        /// Finds the best split for a node and grows the subtree below it recursively.
        /// </summary>
        /// <param name="node">The node to split.</param>
        /// <param name="nums">Row indices (into allData) of the samples that reach this node.</param>
        /// <param name="isUsed">Per-attribute flags; 1 marks an attribute that has already been used.</param>
        public static Node findBestSplit(Node node, List<int> nums, int[] isUsed)
        {
            try
            {
                // Decide whether to keep splitting
                double totalShang = CalEntropy(node.ClassCount, node.rowCount);
                if (ifEnd(node, totalShang, isUsed))
                {
                    return node;
                }
                #region Variable declarations
                SplitInfo info = new SplitInfo();
                int RowCount = nums.Count;                  // number of samples at this node
                double jubuMax = 0;                         // best information gain found so far
                #endregion
                for (int i = 0; i < isUsed.Length - 1; i++)
                {
                    if (isUsed[i] == 1)
                    {
                        continue;
                    }
                    #region Discrete attribute
                    if (type[i] == 0)
                    {
                        int[] allFeatureCount = new int[0];         // sample count per attribute value for the best split
                        double[][] allCount = new double[allNum[i]][];
                        for (int j = 0; j < allCount.Length; j++)
                        {
                            allCount[j] = new double[classCount];
                        }
                        int[] countAllFeature = new int[allNum[i]];
                        List<int>[] temp = new List<int>[allNum[i]];
                        for (int j = 0; j < temp.Length; j++)
                        {
                            temp[j] = new List<int>();
                        }
                        // Attribute values and class labels are 1-based codes, hence the "- 1"
                        for (int j = 0; j < nums.Count; j++)
                        {
                            int index = Convert.ToInt32(allData[nums[j]][i]);
                            temp[index - 1].Add(nums[j]);
                            countAllFeature[index - 1]++;
                            allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
                        }
                        double allShang = 0;
                        double chushu = 0;
                        for (int j = 0; j < allCount.Length; j++)
                        {
                            allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount;
                            if (countAllFeature[j] > 0)
                            {
                                double rate = countAllFeature[j] / Convert.ToDouble(RowCount);
                                chushu = chushu + rate * Math.Log(rate, 2);
                            }
                        }
                        // CalEntropy returns sum(p * log2 p), so this difference is the information gain.
                        // Note: chushu (the negated split information) is accumulated above but not applied;
                        // candidates are ranked by information gain as written.
                        allShang = (-totalShang + allShang);
                        if (allShang > jubuMax)
                        {
                            info.features = new List<string>();
                            info.type = 0;
                            info.temp = temp;
                            info.splitIndex = i;
                            info.class_Count = allCount;
                            jubuMax = allShang;
                            allFeatureCount = countAllFeature;
                        }
                    }
                    #endregion
                    #region Continuous attribute
                    else
                    {
                        double[] leftCount = new double[classCount];            // per-class counts of the left child at the best cut
                        double[] rightCount = new double[classCount];           // per-class counts of the right child at the best cut
                        double[] count1 = new double[classCount];               // running per-class counts of subset 1 (rows before the cut)
                        double[] count2 = new double[node.ClassCount.Length];   // running per-class counts of subset 2 (rows from the cut on)
                        for (int j = 0; j < node.ClassCount.Length; j++)
                        {
                            count2[j] = node.ClassCount[j];
                        }
                        int all1 = 0;                                           // size of subset 1
                        int all2 = nums.Count;                                  // size of subset 2
                        double lastValue = 0;                                   // class label of the previous record
                        double currentValue = 0;                                // class label of the current record
                        double lastPoint = 0;                                   // attribute value of the previous record
                        double currentPoint = 0;                                // attribute value of the current record
                        int splitPoint = 0;
                        double splitValue = 0;
                        double[] values = new double[nums.Count];
                        for (int j = 0; j < values.Length; j++)
                        {
                            values[j] = allData[nums[j]][i];
                        }
                        // Sort the rows of this node by the attribute value (nums is reordered in place)
                        QSort(values, nums, 0, nums.Count - 1);
                        double chushu = 0;
                        double lianxuMax = 0;                       // best information gain among this attribute's cut points
                        for (int j = 0; j < nums.Count - 1; j++)
                        {
                            currentValue = allData[nums[j]][lieshu - 1];
                            currentPoint = allData[nums[j]][i];
                            if (j == 0)
                            {
                                lastValue = currentValue;
                                lastPoint = currentPoint;
                            }
                            // Only boundaries where the class label changes are evaluated as cut points
                            if (currentValue != lastValue)
                            {
                                double shang1 = CalEntropy(count1, all1);
                                double shang2 = CalEntropy(count2, all2);
                                double allShang = shang1 * all1 / (all1 + all2) + shang2 * all2 / (all1 + all2);
                                allShang = (-totalShang + allShang);
                                if (lianxuMax < allShang)
                                {
                                    lianxuMax = allShang;
                                    for (int k = 0; k < count1.Length; k++)
                                    {
                                        leftCount[k] = count1[k];
                                        rightCount[k] = count2[k];
                                    }
                                    splitPoint = j;
                                    splitValue = (currentPoint + lastPoint) / 2;
                                }
                            }
                            all1++;
                            count1[Convert.ToInt32(currentValue) - 1]++;
                            count2[Convert.ToInt32(currentValue) - 1]--;
                            all2--;
                            lastValue = currentValue;
                            lastPoint = currentPoint;
                        }
                        // Child sizes summed over all classes (the original indexed only classes 0 and 1)
                        double leftTotal = 0;
                        double rightTotal = 0;
                        for (int k = 0; k < leftCount.Length; k++)
                        {
                            leftTotal += leftCount[k];
                            rightTotal += rightCount[k];
                        }
                        double rate1 = leftTotal / (leftTotal + rightTotal);
                        chushu = 0;
                        if (rate1 > 0)
                        {
                            chushu = chushu + rate1 * Math.Log(rate1, 2);
                        }
                        double rate2 = rightTotal / (leftTotal + rightTotal);
                        if (rate2 > 0)
                        {
                            chushu = chushu + rate2 * Math.Log(rate2, 2);
                        }
                        // Note: chushu (the negated split information) is computed but not applied;
                        // this attribute is compared with the others by information gain as written.
                        if (lianxuMax > jubuMax)
                        {
                            info.splitIndex = i;
                            info.features = new List<String> { splitValue + "" };
                            info.type = 1;
                            jubuMax = lianxuMax;
                            List<int>[] allInt = new List<int>[2];
                            allInt[0] = new List<int>();
                            allInt[1] = new List<int>();
                            for (int k = 0; k < splitPoint; k++)
                            {
                                allInt[0].Add(nums[k]);
                            }
                            for (int k = splitPoint; k < nums.Count; k++)
                            {
                                allInt[1].Add(nums[k]);
                            }
                            info.temp = allInt;
                            double[][] alls = new double[2][];
                            alls[0] = new double[leftCount.Length];
                            alls[1] = new double[leftCount.Length];
                            for (int k = 0; k < leftCount.Length; k++)
                            {
                                alls[0][k] = leftCount[k];
                                alls[1][k] = rightCount[k];
                            }
                            info.class_Count = alls;
                        }
                    }
                    #endregion
                }
                #region If no useful split attribute was found, make this node a leaf
                if (info.splitIndex == -1)
                {
                    // Predict the majority class (class codes are 1-based)
                    double[] finalCount = node.ClassCount;
                    double max = finalCount[0];
                    int result = 1;
                    for (int i = 1; i < finalCount.Length; i++)
                    {
                        if (finalCount[i] > max)
                        {
                            max = finalCount[i];
                            result = (i + 1);
                        }
                    }
                    node.feature_Type = "result";
                    node.features = new List<String> { "" + result };
                    return node;
                }
                #endregion
                #region Split
                int deep = node.deep;
                node.SplitFeature = "" + info.splitIndex;

                List<Node> childNode = new List<Node>();
                int[] used = new int[isUsed.Length];
                for (int i = 0; i < used.Length; i++)
                {
                    used[i] = isUsed[i];
                }
                if (info.type == 0)
                {
                    // A discrete attribute is used at most once along a path
                    used[info.splitIndex] = 1;
                    node.feature_Type = "離散"; // "discrete"
                }
                else
                {
                    // A continuous attribute stays available for further cuts
                    used[info.splitIndex] = 0;
                    node.feature_Type = "連續"; // "continuous"
                }
                int sumLeaf = 0;
                int sumWrong = 0;
                List<int>[] rowIndex = info.temp;
                List<String> features = info.features;
                for (int j = 0; j < rowIndex.Length; j++)
                {
                    // Skip attribute values that no row takes
                    if (rowIndex[j].Count == 0)
                    {
                        continue;
                    }
                    if (info.type == 0)
                        features.Add("" + (j + 1));
                    Node node1 = new Node();
                    node1.setClassCount(info.class_Count[j]);
                    node1.deep = deep + 1;
                    node1.rowCount = info.temp[j].Count;
                    node1 = findBestSplit(node1, info.temp[j], used);
                    sumLeaf += node1.leafNode_Count;
                    sumWrong += node1.leafWrong;
                    childNode.Add(node1);
                }
                node.leafNode_Count = sumLeaf;
                node.leafWrong = sumWrong;
                node.features = features;
                node.childNodes = childNode;
                #endregion
                return node;
            }
            catch (Exception e)
            {
                Console.WriteLine(e.StackTrace);
                return node;
            }
        }
        /// <summary>
        /// Computes sum(p * log2 p) over the class proportions, i.e. the negative of the entropy.
        /// Callers rely on this sign convention.
        /// </summary>
        /// <param name="counts">Number of samples in each class.</param>
        /// <param name="countAll">Total number of samples.</param>
        /// <returns>The negated entropy (0 for a pure or empty set).</returns>
        public static double CalEntropy(double[] counts, int countAll)
        {
            try
            {
                double allShang = 0;
                for (int i = 0; i < counts.Length; i++)
                {
                    if (counts[i] == 0)
                    {
                        continue;
                    }
                    double rate = counts[i] / countAll;
                    allShang = allShang + rate * Math.Log(rate, 2);
                }
                return allShang;
            }
            catch (Exception)
            {
                return 0;
            }
        }

        #region Pessimistic pruning
        public static void prune(Node node)
        {
            if (node.feature_Type == "result")
                return;
            // Corrected error if this node became a single leaf
            // (getErrorCount() is defined outside this excerpt)
            double treeWrong = node.getErrorCount() + 0.5;
            // Corrected error of the subtree: errors at its leaves plus 0.5 per leaf
            double leafError = node.leafWrong + 0.5 * node.leafNode_Count;
            // Standard deviation of the subtree error under the binomial model
            double var = Math.Sqrt(leafError * (1 - Convert.ToDouble(leafError) / node.rowCount));
            double panbie = leafError + var - treeWrong;
            // Prune when the single-leaf error is below the subtree error plus one standard deviation
            if (panbie > 0)
            {
                node.feature_Type = "result";
                node.childNodes = null;
                int result = node.result + 1;
                node.features = new List<String>() { "" + result };
            }
            else
            {
                List<Node> childNodes = node.childNodes;
                for (int i = 0; i < childNodes.Count; i++)
                {
                    prune(childNodes[i]);
                }
            }
        }
        #endregion
        #endregion
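
The listing's findBestSplit accumulates chushu (the sum of rate * log2(rate) over the child nodes) but ranks candidate splits by information gain alone. As a minimal sketch of how the gain ratio described in section 1 could be formed from the same quantities, the following stand-alone class keeps CalEntropy's sign convention. The class GainRatioSketch and its method names are hypothetical and not part of the original program; the input layout (one array of class counts per child) is an assumption made for illustration.

```csharp
using System;

public static class GainRatioSketch
{
    // Same sign convention as CalEntropy in the listing: returns sum(p * log2(p)),
    // which is the NEGATIVE of the entropy.
    public static double NegEntropy(double[] counts)
    {
        double total = 0;
        foreach (double c in counts) total += c;
        if (total <= 0) return 0;
        double sum = 0;
        foreach (double c in counts)
        {
            if (c <= 0) continue;
            double p = c / total;
            sum += p * Math.Log(p, 2);
        }
        return sum;
    }

    // childCounts[k][c] = number of rows of class c that fall into child k (hypothetical layout).
    public static double GainRatio(double[][] childCounts)
    {
        int classes = childCounts[0].Length;
        double[] parent = new double[classes];
        double[] sizes = new double[childCounts.Length];
        double n = 0;
        for (int k = 0; k < childCounts.Length; k++)
        {
            for (int c = 0; c < classes; c++)
            {
                parent[c] += childCounts[k][c];
                sizes[k] += childCounts[k][c];
                n += childCounts[k][c];
            }
        }
        if (n <= 0) return 0;

        double weighted = 0;   // sum over children of (N_k / N) * NegEntropy(child k)
        for (int k = 0; k < childCounts.Length; k++)
        {
            weighted += NegEntropy(childCounts[k]) * sizes[k] / n;
        }

        double gain = weighted - NegEntropy(parent);   // H(parent) minus weighted child entropy
        double splitInfo = -NegEntropy(sizes);         // entropy of the child sizes
        return splitInfo > 0 ? gain / splitInfo : 0;   // a one-child split carries no information
    }
}
```

For a toy two-class node split into children with counts {2, 0} and {0, 2}, both the gain and the split information are 1 bit, so GainRatio returns 1; a split into {1, 1} and {1, 1} has zero gain and returns 0. In terms of the listing's variables, the equivalent step would be dividing allShang by -chushu before comparing against jubuMax, guarded against chushu being 0.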

 

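Summary:

      C4.5 is among the most important classification-tree algorithms. The idea behind it is simple, yet it classifies accurately. It can be seen as an upgraded and strengthened ID3 that addresses the problems ID3 left open. The key points to remember are:

      1. C4.5 selects the split attribute by information gain ratio, which corrects the bias in ID3's attribute selection (information gain favors attributes with many distinct values);

      2. C4.5 handles continuous data by cutting the sorted values into two parts at a single threshold, and uses information gain to choose the cut point;

      3. C4.5 uses pessimistic pruning, which reduces the effect of overfitting to some extent.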
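
The pruning test from the third point can be written as a few lines of arithmetic. The sketch below follows the condition used by prune in the listing (corrected single-leaf error compared with corrected subtree error plus one binomial standard deviation). PepSketch, ShouldPrune, and the parameter names are hypothetical, and the numbers in the usage note are made up for illustration, not taken from the article's worked example.

```csharp
using System;

public static class PepSketch
{
    // subtreeErrors: misclassified rows summed over the leaves of the subtree
    // leafCount:     number of leaves in the subtree
    // nodeErrors:    misclassified rows if the node is collapsed into one leaf
    // n:             number of rows that reach the node
    public static bool ShouldPrune(int subtreeErrors, int leafCount, int nodeErrors, int n)
    {
        double eSubtree = subtreeErrors + 0.5 * leafCount;      // 0.5 correction per leaf
        double eLeaf = nodeErrors + 0.5;                        // 0.5 correction for the single leaf
        double sd = Math.Sqrt(eSubtree * (1 - eSubtree / n));   // binomial standard deviation
        return eLeaf < eSubtree + sd;                           // same test as "panbie > 0" in prune
    }
}
```

With 20 rows, 3 leaves, and 2 subtree errors, the corrected subtree error is 3.5 and its standard deviation is about 1.70, so the bound is about 5.20. A node that would make 4 errors as a leaf (corrected 4.5) is pruned, while one that would make 8 errors (corrected 8.5) is kept.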
