決策樹系列(三)——ID3

預備知識:決策樹 html

初識ID3node

      回顧決策樹的基本知識,其構建過程主要有下述三個重要的問題:算法

     (1)數據是怎麼分裂的數組

     (2)如何選擇分類的屬性ide

     (3)何時中止分裂函數

     從上述三個問題出發,以實際的例子對ID3算法進行闡述。測試

例:經過當天的天氣、溫度、溼度和季節預測明天的天氣this

                                  表1 原始數據spa

當每天氣設計

溫度

溼度

季節

明每天氣

25

50

春天

21

48

春天

18

70

春天

28

41

夏天

8

65

冬天

18

43

夏天

24

56

秋天

18

76

秋天

31

61

夏天

6

43

冬天

15

55

秋天

4

58

冬天

 1.數據分割

      對於離散型數據,直接按照離散數據的取值進行分裂,每個取值對應一個子節點,以「當前天氣」爲例對數據進行分割,如圖1所示。

 

      對於連續型數據,ID3本來是沒有處理能力的,只有經過離散化將連續性數據轉化成離散型數據再進行處理。

      連續數據離散化是另一個課題,本文不深刻闡述,這裏直接採用等距離數據劃分的李算話方法。該方法先對數據進行排序,而後將連續型數據劃分爲多個區間,並使每個區間的數據量基本相同,以溫度爲例對數據進行分割,如圖2所示。

 

 2. 選擇最優分裂屬性

      ID3採用信息增益做爲選擇最優的分裂屬性的方法,選擇熵做爲衡量節點純度的標準,信息增益的計算公式以下:

                                               

      其中, 表示父節點的熵; 表示節點i的熵,熵越大,節點的信息量越多,越不純; 表示子節點i的數據量與父節點數據量之比。 越大,表示分裂後的熵越小,子節點變得越純,分類的效果越好,所以選擇 最大的屬性做爲分裂屬性。

      對上述的例子的跟節點進行分裂,分別計算每個屬性的信息增益,選擇信息增益最大的屬性進行分裂。

      天氣屬性:(數據分割如上圖1所示) 

  

      溫度:(數據分割如上圖2所示)

     

      溼度:

 

      

      季節:

 

      

     因爲最大,因此選擇屬性「季節」做爲根節點的分裂屬性。

 

3.中止分裂的條件

     中止分裂的條件已經在決策樹中闡述,這裏再也不進行闡述。

     (1)最小節點數

  當節點的數據量小於一個指定的數量時,不繼續分裂。兩個緣由:一是數據量較少時,再作分裂容易強化噪聲數據的做用;二是下降樹生長的複雜性。提早結束分裂必定程度上有利於下降過擬合的影響。

  (2)熵或者基尼值小於閥值。

     由上述可知,熵和基尼值的大小表示數據的複雜程度,當熵或者基尼值太小時,表示數據的純度比較大,若是熵或者基尼值小於必定程度時,節點中止分裂。

  (3)決策樹的深度達到指定的條件

   節點的深度能夠理解爲節點與決策樹跟節點的距離,如根節點的子節點的深度爲1,由於這些節點與跟節點的距離爲1,子節點的深度要比父節點的深度大1。決策樹的深度是全部葉子節點的最大深度,當深度到達指定的上限大小時,中止分裂。

  (4)全部特徵已經使用完畢,不能繼續進行分裂。

     被動式中止分裂的條件,當已經沒有可分的屬性時,直接將當前節點設置爲葉子節點。

 

程序設計及源代碼(C#版本)

(1)數據處理

       用二維數組存儲原始的數據,每一行表示一條記錄,前n-1列表示數據的屬性,第n列表示分類的標籤。

   static double[][] allData;

   爲了方便後面的處理,對離散屬性進行數字化處理,將離散值表示成數字,並用一個鏈表數組進行存儲,數組的第一個元素表示屬性1的離散值。

   static List<String>[] featureValues;

       那麼通過處理後的表1數據能夠轉化爲如表2所示的數據:

                                                                                表2 初始化後的數據

當每天氣

溫度

溼度

季節

明每天氣

1

25

50

1

1

2

21

48

1

2

2

18

70

1

3

1

28

41

2

1

3

8

65

3

2

1

18

43

2

1

2

24

56

4

1

3

18

76

4

2

3

31

61

2

1

2

6

43

3

3

1

15

55

4

2

3

4

58

3

3

      其中,對於當每天氣屬性,數字{1,2,3}分別表示{晴,陰,雨};對於季節屬性{1,2,3,4}分別表示{春天、夏天、冬天、秋天};對於明每天氣{1,2,3}分別表示{晴、陰、雨}。

(2)兩個類:節點類和分裂信息

  a)節點類Node

      該類存儲了節點的信息,包括節點的數據量、節點選擇的分裂屬性、節點輸出類、子節點的個數、子節點的分類偏差等。

 1     class Node
 2     {
 3         /// <summary>
 4         /// 各個子節點的取值
 5         /// </summary>
 6         public List<String> features { get; set; }
 7         /// <summary>
 8         /// 分裂屬性的類型
 9         /// </summary>
10         public String feature_Type { get; set; }
11         /// <summary>
12         /// 分裂的屬性
13         /// </summary>
14         public String SplitFeature { get; set; }
15         /// <summary>
16         /// 節點對應各個分類的數目
17         /// </summary>
18         public double[] ClassCount { get; set; }
19         /// <summary>
20         /// 各個孩子節點
21         /// </summary>
22         public List<Node> childNodes { get; set; }
23         /// <summary>
24         /// 父親節點(未用到)
25         /// </summary>
26         public Node Parent { get; set; }
27         /// <summary>
28         /// 佔比最大的類別
29         /// </summary>
30         public String finalResult { get; set; }
31         /// <summary>
32         /// 數的深度
33         /// </summary>
34         public int deep { get; set; }
35         /// <summary>
36         /// 該節點佔比最大的類標號
37         /// </summary>
38         public int result { get; set; }
39         /// <summary>
40         /// 節點的數量
41         /// </summary>
42         public int rowCount{ get; set; }
43 
44         
45         public void setClassCount(double[] count)
46         {
47             this.ClassCount = count;
48             double max = ClassCount[0];
49             int result = 0;
50             for (int i = 1; i < ClassCount.Length; i++)
51             {
52                 if (max < ClassCount[i])
53                 {
54                     max = ClassCount[i];
55                     result = i;
56                 }
57             }
58             //wrong = Convert.ToInt32(nums.Count - ClassCount[result]);
59             this.result = result;
60         }
61     }
View Code

  b)分裂信息類SplitInfo

      該類存儲節點進行分裂的信息,包括各個子節點的行座標、子節點各個類的數目、該節點分裂的屬性、屬性的類型等。

 1     class SplitInfo
 2     {
 3         /// <summary>
 4         /// 分裂的列下標
 5         /// </summary>
 6         public int splitIndex { get; set; }
 7         /// <summary>
 8         /// 數據類型
 9         /// </summary>
10         public int type { get; set; }
11         /// <summary>
12         /// 分裂屬性的取值
13         /// </summary>
14         public List<String> features { get; set; }
15         /// <summary>
16         /// 各個節點的行座標鏈表
17         /// </summary>
18         public List<int>[] temp { get; set; }
19         /// <summary>
20         /// 每一個節點各種的數目
21         /// </summary>
22         public double[][] class_Count { get; set; }
23     }
View Code

(3)節點分裂方法findBestSplit(Node node,List<int> nums,int[] isUsed),該方法對節點進行分裂,返回值Node

其中:

    node表示即將進行分裂的節點;

    nums表示節點數據對應的行座標列表;

    isUsed表示到該節點位置全部屬性的使用狀況(1:表示該屬性不能再次使用,0:表示該屬性可使用);

findBestSplit主要有如下幾個組成部分:

1)節點分裂中止的斷定

判斷節點是否須要繼續分裂,分裂判斷條件如上文所述。源代碼以下:

  1         public static Object[] ifEnd(Node node, double entropy,int[] isUsed)
  2         {
  3             try
  4             {
  5                 double[] count = node.ClassCount;
  6                 int rowCount = node.rowCount;
  7                 int maxResult = 0;
  8                 double maxRate = 0;
  9                 #region 數達到某一深度
 10                 int deep = node.deep;
 11                 if (deep >= maxDeep)
 12                 {
 13                     maxResult = node.result + 1;
 14                     node.feature_Type=("result");
 15                     node.features=(new List<String>() { maxResult + "" });
 16                     return new Object[] { true, node };
 17                 }
 18                 #endregion
 19                 #region 純度(其實跟後面的有點重了,記得要修改)
 20                 //maxResult = 1;
 21                 //for (int i = 1; i < count.Length; i++)
 22                 //{
 23                 //    if (count[i] / rowCount >= 0.95)
 24                 //    {
 25                 //        node.setFeatureType("result");
 26                 //        node.setFeatures(new List<String> { "" + (i + 1) });
 27                 //        return new Object[] { true, node };
 28                 //    }
 29                 //}
 30                 //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1]));
 31                 #endregion
 32                 #region 熵爲0
 33                 if (entropy == 0)
 34                 {
 35                     maxRate = count[0] / rowCount;
 36                     maxResult = 1;
 37                     for (int i = 1; i < count.Length; i++)
 38                     {
 39                         if (count[i] / rowCount >= maxRate)
 40                         {
 41                             maxRate = count[i] / rowCount;
 42                             maxResult = i + 1;
 43                         }
 44                     }
 45                     node.feature_Type=("result");
 46                     node.features=(new List<String> { maxResult + "" });
 47                     return new Object[] { true, node };
 48                 }
 49                 #endregion
 50                 #region 屬性已經分完
 51                 //int[] isUsed = node.;
 52                 bool flag = true;
 53                 for (int i = 0; i < isUsed.Length - 1; i++)
 54                 {
 55                     if (isUsed[i] == 0)
 56                     {
 57                         flag = false;
 58                         break;
 59                     }
 60                 }
 61                 if (flag)
 62                 {
 63                     maxRate = count[0] / rowCount;
 64                     maxResult = 1;
 65                     for (int i = 1; i < count.Length; i++)
 66                     {
 67                         if (count[i] / rowCount >= maxRate)
 68                         {
 69                             maxRate = count[i] / rowCount;
 70                             maxResult = i + 1;
 71                         }
 72                     }
 73                     node.feature_Type=("result");
 74                     node.features=(new List<String> { "" + (maxResult) });
 75                     return new Object[] { true, node };
 76                 }
 77                 #endregion
 78                 #region 數據量少於100
 79                 if (rowCount < Limit_Node)
 80                 {
 81                     maxRate = count[0] / rowCount;
 82                     maxResult = 1;
 83                     for (int i = 1; i < count.Length; i++)
 84                     {
 85                         if (count[i] / rowCount >= maxRate)
 86                         {
 87                             maxRate = count[i] / rowCount;
 88                             maxResult = i + 1;
 89                         }
 90                     }
 91                     node.feature_Type=("result");
 92                     node.features=(new List<String> { "" + (maxResult) });
 93                     return new Object[] { true, node };
 94                 }
 95                 #endregion
 96                 return new Object[] { false, node };
 97             }
 98             catch (Exception e)
 99             {
100                 return new Object[] { false, node };
101             }
102         }
View Code

2)尋找最優的分裂屬性

尋找最優的分裂屬性須要計算每個分裂屬性分裂後的信息增益,計算公式上文已給出,其中熵的計算代碼以下:

 1         public static double CalEntropy(double[] counts, int countAll)
 2         {
 3             try
 4             {
 5                 double allShang = 0;
 6                 for (int i = 0; i < counts.Length; i++)
 7                 {
 8                     if (counts[i] == 0)
 9                     {
10                         continue;
11                     }
12                     double rate = counts[i] / countAll;
13                     allShang = allShang + rate * Math.Log(rate, 2);
14                 }
15                 return -allShang;
16             }
17             catch (Exception e)
18             {
19                 return 0;
20             }
21         }
View Code

3)進行分裂,同時子節點也執行相同的分類步驟

其實就是遞歸的過程,對每個子節點執行findBestSplit方法進行分裂。

所有源代碼:

  1         #region ID3核心算法
  2         /// <summary>
  3         /// 測試
  4         /// </summary>
  5         /// <param name="node"></param>
  6         /// <param name="data"></param>
  7         public static String findResult(Node node, String[] data)
  8         {
  9             List<String> featrues = node.features;
 10             String type = node.feature_Type;
 11             if (type == "result")
 12             {
 13                 return featrues[0];
 14             }
 15             int split = Convert.ToInt32(node.SplitFeature);
 16             List<Node> childNodes = node.childNodes;
 17             double[] resultCount = node.ClassCount;
 18             if (type == "連續")
 19             {
 20                 
 21 
 22                 for (int i = 0; i < featrues.Count; i++)
 23                 {
 24                     double value = Convert.ToDouble(featrues[i]);
 25                     if (Convert.ToDouble(data[split]) <= value)
 26                     {
 27                         return findResult(childNodes[i], data);
 28                     }
 29                 }
 30                 return findResult(childNodes[featrues.Count], data);
 31             }
 32             else
 33             {
 34                 for (int i = 0; i < featrues.Count; i++)
 35                 {
 36                     if (data[split] == featrues[i])
 37                     {
 38                         return findResult(childNodes[i], data);
 39                     }
 40                     if (i == featrues.Count - 1)
 41                     {
 42                         double count = resultCount[0];
 43                         int maxInt = 0;
 44                         for (int j = 1; j < resultCount.Length; j++)
 45                         {
 46                             if (count < resultCount[j])
 47                             {
 48                                 count = resultCount[j];
 49                                 maxInt = j;
 50                             }
 51                         }
 52                         return findResult(childNodes[0], data);
 53                     }
 54                 }
 55             }
 56             return null;
 57         }
 58         /// <summary>
 59         /// 判斷是否還須要分裂
 60         /// </summary>
 61         /// <param name="node"></param>
 62         /// <returns></returns>
 63         public static Object[] ifEnd(Node node, double entropy,int[] isUsed)
 64         {
 65             try
 66             {
 67                 double[] count = node.ClassCount;
 68                 int rowCount = node.rowCount;
 69                 int maxResult = 0;
 70                 double maxRate = 0;
 71                 #region 數達到某一深度
 72                 int deep = node.deep;
 73                 if (deep >= maxDeep)
 74                 {
 75                     maxResult = node.result + 1;
 76                     node.feature_Type=("result");
 77                     node.features=(new List<String>() { maxResult + "" });
 78                     return new Object[] { true, node };
 79                 }
 80                 #endregion
 81                 #region 純度(其實跟後面的有點重了,記得要修改)
 82                 //maxResult = 1;
 83                 //for (int i = 1; i < count.Length; i++)
 84                 //{
 85                 //    if (count[i] / rowCount >= 0.95)
 86                 //    {
 87                 //        node.setFeatureType("result");
 88                 //        node.setFeatures(new List<String> { "" + (i + 1) });
 89                 //        return new Object[] { true, node };
 90                 //    }
 91                 //}
 92                 //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1]));
 93                 #endregion
 94                 #region 熵爲0
 95                 if (entropy == 0)
 96                 {
 97                     maxRate = count[0] / rowCount;
 98                     maxResult = 1;
 99                     for (int i = 1; i < count.Length; i++)
100                     {
101                         if (count[i] / rowCount >= maxRate)
102                         {
103                             maxRate = count[i] / rowCount;
104                             maxResult = i + 1;
105                         }
106                     }
107                     node.feature_Type=("result");
108                     node.features=(new List<String> { maxResult + "" });
109                     return new Object[] { true, node };
110                 }
111                 #endregion
112                 #region 屬性已經分完
113                 //int[] isUsed = node.;
114                 bool flag = true;
115                 for (int i = 0; i < isUsed.Length - 1; i++)
116                 {
117                     if (isUsed[i] == 0)
118                     {
119                         flag = false;
120                         break;
121                     }
122                 }
123                 if (flag)
124                 {
125                     maxRate = count[0] / rowCount;
126                     maxResult = 1;
127                     for (int i = 1; i < count.Length; i++)
128                     {
129                         if (count[i] / rowCount >= maxRate)
130                         {
131                             maxRate = count[i] / rowCount;
132                             maxResult = i + 1;
133                         }
134                     }
135                     node.feature_Type=("result");
136                     node.features=(new List<String> { "" + (maxResult) });
137                     return new Object[] { true, node };
138                 }
139                 #endregion
140                 #region 數據量少於100
141                 if (rowCount < Limit_Node)
142                 {
143                     maxRate = count[0] / rowCount;
144                     maxResult = 1;
145                     for (int i = 1; i < count.Length; i++)
146                     {
147                         if (count[i] / rowCount >= maxRate)
148                         {
149                             maxRate = count[i] / rowCount;
150                             maxResult = i + 1;
151                         }
152                     }
153                     node.feature_Type=("result");
154                     node.features=(new List<String> { "" + (maxResult) });
155                     return new Object[] { true, node };
156                 }
157                 #endregion
158                 return new Object[] { false, node };
159             }
160             catch (Exception e)
161             {
162                 return new Object[] { false, node };
163             }
164         }
165         #region 排序算法
166         public static void InsertSort(double[] values, List<int> arr, int StartIndex, int endIndex)
167         {
168             for (int i = StartIndex + 1; i <= endIndex; i++)
169             {
170                 int key = arr[i];
171                 double init = values[i];
172                 int j = i - 1;
173                 while (j >= StartIndex && values[j] > init)
174                 {
175                     arr[j + 1] = arr[j];
176                     values[j + 1] = values[j];
177                     j--;
178                 }
179                 arr[j + 1] = key;
180                 values[j + 1] = init;
181             }
182         }
183         static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
184         {
185             int mid = low + ((high - low) >> 1);//計算數組中間的元素的下標  
186 
187             //使用三數取中法選擇樞軸  
188             if (values[mid] > values[high])//目標: arr[mid] <= arr[high]  
189             {
190                 swap(values, arr, mid, high);
191             }
192             if (values[low] > values[high])//目標: arr[low] <= arr[high]  
193             {
194                 swap(values, arr, low, high);
195             }
196             if (values[mid] > values[low]) //目標: arr[low] >= arr[mid]  
197             {
198                 swap(values, arr, mid, low);
199             }
200             //此時,arr[mid] <= arr[low] <= arr[high]  
201             return low;
202             //low的位置上保存這三個位置中間的值  
203             //分割時能夠直接使用low位置的元素做爲樞軸,而不用改變分割函數了  
204         }
205         static void swap(double[] values, List<int> arr, int t1, int t2)
206         {
207             double temp = values[t1];
208             values[t1] = values[t2];
209             values[t2] = temp;
210             int key = arr[t1];
211             arr[t1] = arr[t2];
212             arr[t2] = key;
213         }
214         static void QSort(double[] values, List<int> arr, int low, int high)
215         {
216             int first = low;
217             int last = high;
218 
219             int left = low;
220             int right = high;
221 
222             int leftLen = 0;
223             int rightLen = 0;
224 
225             if (high - low + 1 < 10)
226             {
227                 InsertSort(values, arr, low, high);
228                 return;
229             }
230 
231             //一次分割 
232             int key = SelectPivotMedianOfThree(values, arr, low, high);//使用三數取中法選擇樞軸 
233             double inti = values[key];
234             int currentKey = arr[key];
235 
236             while (low < high)
237             {
238                 while (high > low && values[high] >= inti)
239                 {
240                     if (values[high] == inti)//處理相等元素  
241                     {
242                         swap(values, arr, right, high);
243                         right--;
244                         rightLen++;
245                     }
246                     high--;
247                 }
248                 arr[low] = arr[high];
249                 values[low] = values[high];
250                 while (high > low && values[low] <= inti)
251                 {
252                     if (values[low] == inti)
253                     {
254                         swap(values, arr, left, low);
255                         left++;
256                         leftLen++;
257                     }
258                     low++;
259                 }
260                 arr[high] = arr[low];
261                 values[high] = values[low];
262             }
263             arr[low] = currentKey;
264             values[low] = values[key];
265             //一次快排結束  
266             //把與樞軸key相同的元素移到樞軸最終位置周圍  
267             int i = low - 1;
268             int j = first;
269             while (j < left && values[i] != inti)
270             {
271                 swap(values, arr, i, j);
272                 i--;
273                 j++;
274             }
275             i = low + 1;
276             j = last;
277             while (j > right && values[i] != inti)
278             {
279                 swap(values, arr, i, j);
280                 i++;
281                 j--;
282             }
283             QSort(values, arr, first, low - 1 - leftLen);
284             QSort(values, arr, low + 1 + rightLen, last);
285         }
286         #endregion
287         /// <summary>
288         /// 尋找最佳的分裂點
289         /// </summary>
290         /// <param name="num"></param>
291         /// <param name="node"></param>
292         public static Node findBestSplit(Node node, int lastCol,List<int> nums,int[] isUsed)
293         {
294             try
295             {
296                 //判斷是否繼續分裂
297                 double totalShang = CalEntropy(node.ClassCount, nums.Count);
298                 Object[] check = ifEnd(node, totalShang, isUsed);
299                 if ((bool)check[0])
300                 {
301                     node = (Node)check[1];
302                     return node;
303                 }
304                 #region 變量聲明
305                 SplitInfo info = new SplitInfo();
306                 //int[] isUsed = node.getUsed();              //連續變量or離散變量
307                 //List<int> nums = node.getNum();             //樣本的標號
308                 int RowCount = nums.Count;                  //樣本總數
309                 double jubuMax = 0;                         //局部最大熵
310                 #endregion
311                 for (int i = 0; i < isUsed.Length - 1; i++)
312                 {
313                     if (isUsed[i] == 1)
314                     {
315                         continue;
316                     }
317                     #region 離散變量
318                     if (type[i] == 0)
319                     {
320                         int[] allFeatureCount = new int[0];         //全部類別的數量
321                         double[][] allCount = new double[allNum[i]][];
322                         for (int j = 0; j < allCount.Length; j++)
323                         {
324                             allCount[j] = new double[classCount];
325                         }
326                         int[] countAllFeature = new int[allNum[i]];
327                         List<int>[] temp = new List<int>[allNum[i]];
328                         for (int j = 0; j < temp.Length; j++)
329                         {
330                             temp[j] = new List<int>();
331                         }
332                         for (int j = 0; j < nums.Count; j++)
333                         {
334                             int index = Convert.ToInt32(allData[nums[j]][i]);
335                             temp[index - 1].Add(nums[j]);
336                             countAllFeature[index - 1]++;
337                             allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
338                         }
339                         double allShang = 0;
340                         for (int j = 0; j < allCount.Length; j++)
341                         {
342                             allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount;
343                         }
344                         allShang = (totalShang - allShang);
345                         if (allShang > jubuMax)
346                         {
347                             info.features=new List<String>();
348                             info.type=0;
349                             info.temp=(temp);
350                             info.splitIndex=(i);
351                             info.class_Count=(allCount);
352                             jubuMax = allShang;
353                             allFeatureCount = countAllFeature;
354                         }
355                     }
356                     #endregion
357                     #region 連續變量
358                     else
359                     {
360                         double[] leftCount = new double[classCount];          //作節點各個類別的數量
361                         double[] rightCount = new double[classCount];        //右節點各個類別的數量
362                         double[] values = new double[nums.Count];
363                         List<String> List_Feature = new List<string>();
364                         for (int j = 0; j < values.Length; j++)
365                         {
366                             values[j] = allData[nums[j]][i];
367                         }
368                         QSort(values, nums, 0, nums.Count - 1);
369                         int eachNum = nums.Count / 5;
370                         double lianxuMax = 0;                                   //連續型屬性的最大熵
371                         int index = 1;
372                         double[][] counts = new double[5][];
373                         List<int>[] temp = new List<int>[5];
374                         for (int j = 0; j < 5; j++)
375                         {
376                             counts[j] = new double[classCount];
377                             temp[j] = new List<int>();
378                         }
379                         for (int j = 0; j < nums.Count - 1; j++)
380                         {
381                             if (j >= index * eachNum&&index<5)
382                             {
383                                 List_Feature.Add(allData[nums[j]][i]+"");
384                                 lianxuMax += eachNum*CalEntropy(counts[index - 1], eachNum)/RowCount;
385                                 index++;
386                             }
387                             temp[index-1].Add(nums[j]);
388                             counts[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1])-1]++;
389                         }
390                         lianxuMax += ((eachNum + nums.Count % 5)*CalEntropy(counts[index - 1], eachNum + nums.Count % 5) / RowCount);
391                         lianxuMax = totalShang - lianxuMax;
392                         if (lianxuMax > jubuMax)
393                         {
394                             info.splitIndex=(i);
395                             info.features=(List_Feature);
396                             info.type=(1);
397                             jubuMax = lianxuMax;
398                             info.temp=(temp);
399                             info.class_Count=(counts);
400                         }
401                     }
402                     #endregion
403                 }
404                 #region 如何找不到最佳的分裂屬性,則設爲葉節點
405                 if (info.splitIndex == -1)
406                 {
407                     double[] finalCount = node.ClassCount;
408                     double max = finalCount[0];
409                     int result = 1;
410                     for (int i = 1; i < finalCount.Length; i++)
411                     {
412                         if (finalCount[i] > max)
413                         {
414                             max = finalCount[i];
415                             result = (i + 1);
416                         }
417                     }
418                     node.feature_Type=("result");
419                     node.features=(new List<String> { "" + result });
420                     return node;
421                 }
422                 #endregion
423                 int deep = node.deep;
424                 #region 分裂
425                 node.SplitFeature=("" + info.splitIndex);
426                 
427                 List<Node> childNode = new List<Node>();
428                 int[] used = new int[isUsed.Length];
429                 for (int i = 0; i < used.Length; i++)
430                 {
431                     used[i] = isUsed[i];
432                 }
433                 if (info.type == 0)
434                 {
435                     used[info.splitIndex] = 1;
436                     node.feature_Type=("離散");
437                 }
438                 else
439                 {
440                     used[info.splitIndex] = 0;
441                     node.feature_Type=("連續");
442                 }
443                 int sumLeaf = 0;
444                 int sumWrong = 0;
445                 List<int>[] rowIndex = info.temp;
446                 List<String> features = info.features;
447                 for (int j = 0; j < rowIndex.Length; j++)
448                 {
449                     if (rowIndex[j].Count == 0)
450                     {
451                         continue;
452                     }
453                     if (info.type == 0)
454                         features.Add(""+(j+1));
455                     Node node1 = new Node();
456                     //node1.setNum(info.getTemp()[j]);
457                     node1.setClassCount(info.class_Count[j]);
458                     //node1.setUsed(used);
459                     node1.deep=(deep + 1);
460                     node1.rowCount = info.temp[j].Count;
461                     node1 = findBestSplit(node1, info.splitIndex,info.temp[j], used);
462                     childNode.Add(node1);
463                 }
464                 node.features=(features);
465                 node.childNodes=(childNode);
466                 
467                 #endregion
468                 return node;
469             }
470             catch (Exception e)
471             {
472                 Console.WriteLine(e.StackTrace);
473                 return node;
474             }
475         }
476         /// <summary>
477         /// 計算熵
478         /// </summary>
479         /// <param name="counts"></param>
480         /// <param name="countAll"></param>
481         /// <returns></returns>
482         public static double CalEntropy(double[] counts, int countAll)
483         {
484             try
485             {
486                 double allShang = 0;
487                 for (int i = 0; i < counts.Length; i++)
488                 {
489                     if (counts[i] == 0)
490                     {
491                         continue;
492                     }
493                     double rate = counts[i] / countAll;
494                     allShang = allShang + rate * Math.Log(rate, 2);
495                 }
496                 return -allShang;
497             }
498             catch (Exception e)
499             {
500                 return 0;
501             }
502         }
503         #endregion
View Code

 (注:上述代碼只是ID3的核心代碼,數據預處理的代碼並無給出,只要將預處理後的數據輸入到主方法findBestSplit中,就能夠獲得最終的結果)

總結

     ID3是基本的決策樹構建算法,做爲決策樹經典的構建算法,其具備結構簡單、清晰易懂的特色。雖然ID3比較靈活方便,可是有如下幾個缺點:

 (1)採用信息增益進行分裂,分裂的精確度可能沒有采用信息增益率進行分裂高

   (2)不能處理連續型數據,只能經過離散化將連續性數據轉化爲離散型數據

   (3)不能處理缺省值

   (4)沒有對決策樹進行剪枝處理,極可能會出現過擬合的問題

相關文章
相關標籤/搜索