預備知識:決策樹 html
初識ID3node
回顧決策樹的基本知識,其構建過程主要有下述三個重要的問題:算法
(1)數據是怎麼分裂的數組
(2)如何選擇分類的屬性ide
(3)何時中止分裂函數
從上述三個問題出發,以實際的例子對ID3算法進行闡述。測試
例:經過當天的天氣、溫度、溼度和季節預測明天的天氣this
表1 原始數據spa
當每天氣設計 |
溫度 |
溼度 |
季節 |
明每天氣 |
晴 |
25 |
50 |
春天 |
晴 |
陰 |
21 |
48 |
春天 |
陰 |
陰 |
18 |
70 |
春天 |
雨 |
晴 |
28 |
41 |
夏天 |
晴 |
雨 |
8 |
65 |
冬天 |
陰 |
晴 |
18 |
43 |
夏天 |
晴 |
陰 |
24 |
56 |
秋天 |
晴 |
雨 |
18 |
76 |
秋天 |
陰 |
雨 |
31 |
61 |
夏天 |
晴 |
陰 |
6 |
43 |
冬天 |
雨 |
晴 |
15 |
55 |
秋天 |
陰 |
雨 |
4 |
58 |
冬天 |
雨 |
1.數據分割
對於離散型數據,直接按照離散數據的取值進行分裂,每個取值對應一個子節點,以「當前天氣」爲例對數據進行分割,如圖1所示。
對於連續型數據,ID3本來是沒有處理能力的,只有經過離散化將連續性數據轉化成離散型數據再進行處理。
連續數據離散化是另一個課題,本文不深刻闡述,這裏直接採用等距離數據劃分的李算話方法。該方法先對數據進行排序,而後將連續型數據劃分爲多個區間,並使每個區間的數據量基本相同,以溫度爲例對數據進行分割,如圖2所示。
2. 選擇最優分裂屬性
ID3採用信息增益做爲選擇最優的分裂屬性的方法,選擇熵做爲衡量節點純度的標準,信息增益的計算公式以下:
其中, 表示父節點的熵; 表示節點i的熵,熵越大,節點的信息量越多,越不純; 表示子節點i的數據量與父節點數據量之比。 越大,表示分裂後的熵越小,子節點變得越純,分類的效果越好,所以選擇 最大的屬性做爲分裂屬性。
對上述的例子的跟節點進行分裂,分別計算每個屬性的信息增益,選擇信息增益最大的屬性進行分裂。
天氣屬性:(數據分割如上圖1所示)
溫度:(數據分割如上圖2所示)
溼度:
季節:
因爲最大,因此選擇屬性「季節」做爲根節點的分裂屬性。
3.中止分裂的條件
中止分裂的條件已經在決策樹中闡述,這裏再也不進行闡述。
(1)最小節點數
當節點的數據量小於一個指定的數量時,不繼續分裂。兩個緣由:一是數據量較少時,再作分裂容易強化噪聲數據的做用;二是下降樹生長的複雜性。提早結束分裂必定程度上有利於下降過擬合的影響。
(2)熵或者基尼值小於閥值。
由上述可知,熵和基尼值的大小表示數據的複雜程度,當熵或者基尼值太小時,表示數據的純度比較大,若是熵或者基尼值小於必定程度時,節點中止分裂。
(3)決策樹的深度達到指定的條件
節點的深度能夠理解爲節點與決策樹跟節點的距離,如根節點的子節點的深度爲1,由於這些節點與跟節點的距離爲1,子節點的深度要比父節點的深度大1。決策樹的深度是全部葉子節點的最大深度,當深度到達指定的上限大小時,中止分裂。
(4)全部特徵已經使用完畢,不能繼續進行分裂。
被動式中止分裂的條件,當已經沒有可分的屬性時,直接將當前節點設置爲葉子節點。
程序設計及源代碼(C#版本)
(1)數據處理
用二維數組存儲原始的數據,每一行表示一條記錄,前n-1列表示數據的屬性,第n列表示分類的標籤。
static double[][] allData;
爲了方便後面的處理,對離散屬性進行數字化處理,將離散值表示成數字,並用一個鏈表數組進行存儲,數組的第一個元素表示屬性1的離散值。
static List<String>[] featureValues;
那麼通過處理後的表1數據能夠轉化爲如表2所示的數據:
表2 初始化後的數據
當每天氣 |
溫度 |
溼度 |
季節 |
明每天氣 |
1 |
25 |
50 |
1 |
1 |
2 |
21 |
48 |
1 |
2 |
2 |
18 |
70 |
1 |
3 |
1 |
28 |
41 |
2 |
1 |
3 |
8 |
65 |
3 |
2 |
1 |
18 |
43 |
2 |
1 |
2 |
24 |
56 |
4 |
1 |
3 |
18 |
76 |
4 |
2 |
3 |
31 |
61 |
2 |
1 |
2 |
6 |
43 |
3 |
3 |
1 |
15 |
55 |
4 |
2 |
3 |
4 |
58 |
3 |
3 |
其中,對於當每天氣屬性,數字{1,2,3}分別表示{晴,陰,雨};對於季節屬性{1,2,3,4}分別表示{春天、夏天、冬天、秋天};對於明每天氣{1,2,3}分別表示{晴、陰、雨}。
(2)兩個類:節點類和分裂信息
a)節點類Node
該類存儲了節點的信息,包括節點的數據量、節點選擇的分裂屬性、節點輸出類、子節點的個數、子節點的分類偏差等。
1 class Node 2 { 3 /// <summary> 4 /// 各個子節點的取值 5 /// </summary> 6 public List<String> features { get; set; } 7 /// <summary> 8 /// 分裂屬性的類型 9 /// </summary> 10 public String feature_Type { get; set; } 11 /// <summary> 12 /// 分裂的屬性 13 /// </summary> 14 public String SplitFeature { get; set; } 15 /// <summary> 16 /// 節點對應各個分類的數目 17 /// </summary> 18 public double[] ClassCount { get; set; } 19 /// <summary> 20 /// 各個孩子節點 21 /// </summary> 22 public List<Node> childNodes { get; set; } 23 /// <summary> 24 /// 父親節點(未用到) 25 /// </summary> 26 public Node Parent { get; set; } 27 /// <summary> 28 /// 佔比最大的類別 29 /// </summary> 30 public String finalResult { get; set; } 31 /// <summary> 32 /// 數的深度 33 /// </summary> 34 public int deep { get; set; } 35 /// <summary> 36 /// 該節點佔比最大的類標號 37 /// </summary> 38 public int result { get; set; } 39 /// <summary> 40 /// 節點的數量 41 /// </summary> 42 public int rowCount{ get; set; } 43 44 45 public void setClassCount(double[] count) 46 { 47 this.ClassCount = count; 48 double max = ClassCount[0]; 49 int result = 0; 50 for (int i = 1; i < ClassCount.Length; i++) 51 { 52 if (max < ClassCount[i]) 53 { 54 max = ClassCount[i]; 55 result = i; 56 } 57 } 58 //wrong = Convert.ToInt32(nums.Count - ClassCount[result]); 59 this.result = result; 60 } 61 }
b)分裂信息類SplitInfo
該類存儲節點進行分裂的信息,包括各個子節點的行座標、子節點各個類的數目、該節點分裂的屬性、屬性的類型等。
1 class SplitInfo 2 { 3 /// <summary> 4 /// 分裂的列下標 5 /// </summary> 6 public int splitIndex { get; set; } 7 /// <summary> 8 /// 數據類型 9 /// </summary> 10 public int type { get; set; } 11 /// <summary> 12 /// 分裂屬性的取值 13 /// </summary> 14 public List<String> features { get; set; } 15 /// <summary> 16 /// 各個節點的行座標鏈表 17 /// </summary> 18 public List<int>[] temp { get; set; } 19 /// <summary> 20 /// 每一個節點各種的數目 21 /// </summary> 22 public double[][] class_Count { get; set; } 23 }
(3)節點分裂方法findBestSplit(Node node,List<int> nums,int[] isUsed),該方法對節點進行分裂,返回值Node
其中:
node表示即將進行分裂的節點;
nums表示節點數據對應的行座標列表;
isUsed表示到該節點位置全部屬性的使用狀況(1:表示該屬性不能再次使用,0:表示該屬性可使用);
findBestSplit主要有如下幾個組成部分:
1)節點分裂中止的斷定
判斷節點是否須要繼續分裂,分裂判斷條件如上文所述。源代碼以下:
1 public static Object[] ifEnd(Node node, double entropy,int[] isUsed) 2 { 3 try 4 { 5 double[] count = node.ClassCount; 6 int rowCount = node.rowCount; 7 int maxResult = 0; 8 double maxRate = 0; 9 #region 數達到某一深度 10 int deep = node.deep; 11 if (deep >= maxDeep) 12 { 13 maxResult = node.result + 1; 14 node.feature_Type=("result"); 15 node.features=(new List<String>() { maxResult + "" }); 16 return new Object[] { true, node }; 17 } 18 #endregion 19 #region 純度(其實跟後面的有點重了,記得要修改) 20 //maxResult = 1; 21 //for (int i = 1; i < count.Length; i++) 22 //{ 23 // if (count[i] / rowCount >= 0.95) 24 // { 25 // node.setFeatureType("result"); 26 // node.setFeatures(new List<String> { "" + (i + 1) }); 27 // return new Object[] { true, node }; 28 // } 29 //} 30 //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1])); 31 #endregion 32 #region 熵爲0 33 if (entropy == 0) 34 { 35 maxRate = count[0] / rowCount; 36 maxResult = 1; 37 for (int i = 1; i < count.Length; i++) 38 { 39 if (count[i] / rowCount >= maxRate) 40 { 41 maxRate = count[i] / rowCount; 42 maxResult = i + 1; 43 } 44 } 45 node.feature_Type=("result"); 46 node.features=(new List<String> { maxResult + "" }); 47 return new Object[] { true, node }; 48 } 49 #endregion 50 #region 屬性已經分完 51 //int[] isUsed = node.; 52 bool flag = true; 53 for (int i = 0; i < isUsed.Length - 1; i++) 54 { 55 if (isUsed[i] == 0) 56 { 57 flag = false; 58 break; 59 } 60 } 61 if (flag) 62 { 63 maxRate = count[0] / rowCount; 64 maxResult = 1; 65 for (int i = 1; i < count.Length; i++) 66 { 67 if (count[i] / rowCount >= maxRate) 68 { 69 maxRate = count[i] / rowCount; 70 maxResult = i + 1; 71 } 72 } 73 node.feature_Type=("result"); 74 node.features=(new List<String> { "" + (maxResult) }); 75 return new Object[] { true, node }; 76 } 77 #endregion 78 #region 數據量少於100 79 if (rowCount < Limit_Node) 80 { 81 maxRate = count[0] / rowCount; 82 maxResult = 1; 83 for (int i = 1; i < count.Length; i++) 84 { 85 if (count[i] / rowCount >= maxRate) 86 { 87 maxRate = count[i] / rowCount; 88 maxResult = i + 1; 89 } 90 } 91 node.feature_Type=("result"); 92 node.features=(new List<String> { "" + (maxResult) }); 93 return new Object[] { true, node }; 94 } 95 #endregion 96 return new Object[] { false, node }; 97 } 98 catch (Exception e) 99 { 100 return new Object[] { false, node }; 101 } 102 }
2)尋找最優的分裂屬性
尋找最優的分裂屬性須要計算每個分裂屬性分裂後的信息增益,計算公式上文已給出,其中熵的計算代碼以下:
1 public static double CalEntropy(double[] counts, int countAll) 2 { 3 try 4 { 5 double allShang = 0; 6 for (int i = 0; i < counts.Length; i++) 7 { 8 if (counts[i] == 0) 9 { 10 continue; 11 } 12 double rate = counts[i] / countAll; 13 allShang = allShang + rate * Math.Log(rate, 2); 14 } 15 return -allShang; 16 } 17 catch (Exception e) 18 { 19 return 0; 20 } 21 }
3)進行分裂,同時子節點也執行相同的分類步驟
其實就是遞歸的過程,對每個子節點執行findBestSplit方法進行分裂。
所有源代碼:
1 #region ID3核心算法 2 /// <summary> 3 /// 測試 4 /// </summary> 5 /// <param name="node"></param> 6 /// <param name="data"></param> 7 public static String findResult(Node node, String[] data) 8 { 9 List<String> featrues = node.features; 10 String type = node.feature_Type; 11 if (type == "result") 12 { 13 return featrues[0]; 14 } 15 int split = Convert.ToInt32(node.SplitFeature); 16 List<Node> childNodes = node.childNodes; 17 double[] resultCount = node.ClassCount; 18 if (type == "連續") 19 { 20 21 22 for (int i = 0; i < featrues.Count; i++) 23 { 24 double value = Convert.ToDouble(featrues[i]); 25 if (Convert.ToDouble(data[split]) <= value) 26 { 27 return findResult(childNodes[i], data); 28 } 29 } 30 return findResult(childNodes[featrues.Count], data); 31 } 32 else 33 { 34 for (int i = 0; i < featrues.Count; i++) 35 { 36 if (data[split] == featrues[i]) 37 { 38 return findResult(childNodes[i], data); 39 } 40 if (i == featrues.Count - 1) 41 { 42 double count = resultCount[0]; 43 int maxInt = 0; 44 for (int j = 1; j < resultCount.Length; j++) 45 { 46 if (count < resultCount[j]) 47 { 48 count = resultCount[j]; 49 maxInt = j; 50 } 51 } 52 return findResult(childNodes[0], data); 53 } 54 } 55 } 56 return null; 57 } 58 /// <summary> 59 /// 判斷是否還須要分裂 60 /// </summary> 61 /// <param name="node"></param> 62 /// <returns></returns> 63 public static Object[] ifEnd(Node node, double entropy,int[] isUsed) 64 { 65 try 66 { 67 double[] count = node.ClassCount; 68 int rowCount = node.rowCount; 69 int maxResult = 0; 70 double maxRate = 0; 71 #region 數達到某一深度 72 int deep = node.deep; 73 if (deep >= maxDeep) 74 { 75 maxResult = node.result + 1; 76 node.feature_Type=("result"); 77 node.features=(new List<String>() { maxResult + "" }); 78 return new Object[] { true, node }; 79 } 80 #endregion 81 #region 純度(其實跟後面的有點重了,記得要修改) 82 //maxResult = 1; 83 //for (int i = 1; i < count.Length; i++) 84 //{ 85 // if (count[i] / rowCount >= 0.95) 86 // { 87 // node.setFeatureType("result"); 88 // node.setFeatures(new List<String> { "" + (i + 1) }); 89 // return new Object[] { true, node }; 90 // } 91 //} 92 //node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1])); 93 #endregion 94 #region 熵爲0 95 if (entropy == 0) 96 { 97 maxRate = count[0] / rowCount; 98 maxResult = 1; 99 for (int i = 1; i < count.Length; i++) 100 { 101 if (count[i] / rowCount >= maxRate) 102 { 103 maxRate = count[i] / rowCount; 104 maxResult = i + 1; 105 } 106 } 107 node.feature_Type=("result"); 108 node.features=(new List<String> { maxResult + "" }); 109 return new Object[] { true, node }; 110 } 111 #endregion 112 #region 屬性已經分完 113 //int[] isUsed = node.; 114 bool flag = true; 115 for (int i = 0; i < isUsed.Length - 1; i++) 116 { 117 if (isUsed[i] == 0) 118 { 119 flag = false; 120 break; 121 } 122 } 123 if (flag) 124 { 125 maxRate = count[0] / rowCount; 126 maxResult = 1; 127 for (int i = 1; i < count.Length; i++) 128 { 129 if (count[i] / rowCount >= maxRate) 130 { 131 maxRate = count[i] / rowCount; 132 maxResult = i + 1; 133 } 134 } 135 node.feature_Type=("result"); 136 node.features=(new List<String> { "" + (maxResult) }); 137 return new Object[] { true, node }; 138 } 139 #endregion 140 #region 數據量少於100 141 if (rowCount < Limit_Node) 142 { 143 maxRate = count[0] / rowCount; 144 maxResult = 1; 145 for (int i = 1; i < count.Length; i++) 146 { 147 if (count[i] / rowCount >= maxRate) 148 { 149 maxRate = count[i] / rowCount; 150 maxResult = i + 1; 151 } 152 } 153 node.feature_Type=("result"); 154 node.features=(new List<String> { "" + (maxResult) }); 155 return new Object[] { true, node }; 156 } 157 #endregion 158 return new Object[] { false, node }; 159 } 160 catch (Exception e) 161 { 162 return new Object[] { false, node }; 163 } 164 } 165 #region 排序算法 166 public static void InsertSort(double[] values, List<int> arr, int StartIndex, int endIndex) 167 { 168 for (int i = StartIndex + 1; i <= endIndex; i++) 169 { 170 int key = arr[i]; 171 double init = values[i]; 172 int j = i - 1; 173 while (j >= StartIndex && values[j] > init) 174 { 175 arr[j + 1] = arr[j]; 176 values[j + 1] = values[j]; 177 j--; 178 } 179 arr[j + 1] = key; 180 values[j + 1] = init; 181 } 182 } 183 static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high) 184 { 185 int mid = low + ((high - low) >> 1);//計算數組中間的元素的下標 186 187 //使用三數取中法選擇樞軸 188 if (values[mid] > values[high])//目標: arr[mid] <= arr[high] 189 { 190 swap(values, arr, mid, high); 191 } 192 if (values[low] > values[high])//目標: arr[low] <= arr[high] 193 { 194 swap(values, arr, low, high); 195 } 196 if (values[mid] > values[low]) //目標: arr[low] >= arr[mid] 197 { 198 swap(values, arr, mid, low); 199 } 200 //此時,arr[mid] <= arr[low] <= arr[high] 201 return low; 202 //low的位置上保存這三個位置中間的值 203 //分割時能夠直接使用low位置的元素做爲樞軸,而不用改變分割函數了 204 } 205 static void swap(double[] values, List<int> arr, int t1, int t2) 206 { 207 double temp = values[t1]; 208 values[t1] = values[t2]; 209 values[t2] = temp; 210 int key = arr[t1]; 211 arr[t1] = arr[t2]; 212 arr[t2] = key; 213 } 214 static void QSort(double[] values, List<int> arr, int low, int high) 215 { 216 int first = low; 217 int last = high; 218 219 int left = low; 220 int right = high; 221 222 int leftLen = 0; 223 int rightLen = 0; 224 225 if (high - low + 1 < 10) 226 { 227 InsertSort(values, arr, low, high); 228 return; 229 } 230 231 //一次分割 232 int key = SelectPivotMedianOfThree(values, arr, low, high);//使用三數取中法選擇樞軸 233 double inti = values[key]; 234 int currentKey = arr[key]; 235 236 while (low < high) 237 { 238 while (high > low && values[high] >= inti) 239 { 240 if (values[high] == inti)//處理相等元素 241 { 242 swap(values, arr, right, high); 243 right--; 244 rightLen++; 245 } 246 high--; 247 } 248 arr[low] = arr[high]; 249 values[low] = values[high]; 250 while (high > low && values[low] <= inti) 251 { 252 if (values[low] == inti) 253 { 254 swap(values, arr, left, low); 255 left++; 256 leftLen++; 257 } 258 low++; 259 } 260 arr[high] = arr[low]; 261 values[high] = values[low]; 262 } 263 arr[low] = currentKey; 264 values[low] = values[key]; 265 //一次快排結束 266 //把與樞軸key相同的元素移到樞軸最終位置周圍 267 int i = low - 1; 268 int j = first; 269 while (j < left && values[i] != inti) 270 { 271 swap(values, arr, i, j); 272 i--; 273 j++; 274 } 275 i = low + 1; 276 j = last; 277 while (j > right && values[i] != inti) 278 { 279 swap(values, arr, i, j); 280 i++; 281 j--; 282 } 283 QSort(values, arr, first, low - 1 - leftLen); 284 QSort(values, arr, low + 1 + rightLen, last); 285 } 286 #endregion 287 /// <summary> 288 /// 尋找最佳的分裂點 289 /// </summary> 290 /// <param name="num"></param> 291 /// <param name="node"></param> 292 public static Node findBestSplit(Node node, int lastCol,List<int> nums,int[] isUsed) 293 { 294 try 295 { 296 //判斷是否繼續分裂 297 double totalShang = CalEntropy(node.ClassCount, nums.Count); 298 Object[] check = ifEnd(node, totalShang, isUsed); 299 if ((bool)check[0]) 300 { 301 node = (Node)check[1]; 302 return node; 303 } 304 #region 變量聲明 305 SplitInfo info = new SplitInfo(); 306 //int[] isUsed = node.getUsed(); //連續變量or離散變量 307 //List<int> nums = node.getNum(); //樣本的標號 308 int RowCount = nums.Count; //樣本總數 309 double jubuMax = 0; //局部最大熵 310 #endregion 311 for (int i = 0; i < isUsed.Length - 1; i++) 312 { 313 if (isUsed[i] == 1) 314 { 315 continue; 316 } 317 #region 離散變量 318 if (type[i] == 0) 319 { 320 int[] allFeatureCount = new int[0]; //全部類別的數量 321 double[][] allCount = new double[allNum[i]][]; 322 for (int j = 0; j < allCount.Length; j++) 323 { 324 allCount[j] = new double[classCount]; 325 } 326 int[] countAllFeature = new int[allNum[i]]; 327 List<int>[] temp = new List<int>[allNum[i]]; 328 for (int j = 0; j < temp.Length; j++) 329 { 330 temp[j] = new List<int>(); 331 } 332 for (int j = 0; j < nums.Count; j++) 333 { 334 int index = Convert.ToInt32(allData[nums[j]][i]); 335 temp[index - 1].Add(nums[j]); 336 countAllFeature[index - 1]++; 337 allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++; 338 } 339 double allShang = 0; 340 for (int j = 0; j < allCount.Length; j++) 341 { 342 allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount; 343 } 344 allShang = (totalShang - allShang); 345 if (allShang > jubuMax) 346 { 347 info.features=new List<String>(); 348 info.type=0; 349 info.temp=(temp); 350 info.splitIndex=(i); 351 info.class_Count=(allCount); 352 jubuMax = allShang; 353 allFeatureCount = countAllFeature; 354 } 355 } 356 #endregion 357 #region 連續變量 358 else 359 { 360 double[] leftCount = new double[classCount]; //作節點各個類別的數量 361 double[] rightCount = new double[classCount]; //右節點各個類別的數量 362 double[] values = new double[nums.Count]; 363 List<String> List_Feature = new List<string>(); 364 for (int j = 0; j < values.Length; j++) 365 { 366 values[j] = allData[nums[j]][i]; 367 } 368 QSort(values, nums, 0, nums.Count - 1); 369 int eachNum = nums.Count / 5; 370 double lianxuMax = 0; //連續型屬性的最大熵 371 int index = 1; 372 double[][] counts = new double[5][]; 373 List<int>[] temp = new List<int>[5]; 374 for (int j = 0; j < 5; j++) 375 { 376 counts[j] = new double[classCount]; 377 temp[j] = new List<int>(); 378 } 379 for (int j = 0; j < nums.Count - 1; j++) 380 { 381 if (j >= index * eachNum&&index<5) 382 { 383 List_Feature.Add(allData[nums[j]][i]+""); 384 lianxuMax += eachNum*CalEntropy(counts[index - 1], eachNum)/RowCount; 385 index++; 386 } 387 temp[index-1].Add(nums[j]); 388 counts[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1])-1]++; 389 } 390 lianxuMax += ((eachNum + nums.Count % 5)*CalEntropy(counts[index - 1], eachNum + nums.Count % 5) / RowCount); 391 lianxuMax = totalShang - lianxuMax; 392 if (lianxuMax > jubuMax) 393 { 394 info.splitIndex=(i); 395 info.features=(List_Feature); 396 info.type=(1); 397 jubuMax = lianxuMax; 398 info.temp=(temp); 399 info.class_Count=(counts); 400 } 401 } 402 #endregion 403 } 404 #region 如何找不到最佳的分裂屬性,則設爲葉節點 405 if (info.splitIndex == -1) 406 { 407 double[] finalCount = node.ClassCount; 408 double max = finalCount[0]; 409 int result = 1; 410 for (int i = 1; i < finalCount.Length; i++) 411 { 412 if (finalCount[i] > max) 413 { 414 max = finalCount[i]; 415 result = (i + 1); 416 } 417 } 418 node.feature_Type=("result"); 419 node.features=(new List<String> { "" + result }); 420 return node; 421 } 422 #endregion 423 int deep = node.deep; 424 #region 分裂 425 node.SplitFeature=("" + info.splitIndex); 426 427 List<Node> childNode = new List<Node>(); 428 int[] used = new int[isUsed.Length]; 429 for (int i = 0; i < used.Length; i++) 430 { 431 used[i] = isUsed[i]; 432 } 433 if (info.type == 0) 434 { 435 used[info.splitIndex] = 1; 436 node.feature_Type=("離散"); 437 } 438 else 439 { 440 used[info.splitIndex] = 0; 441 node.feature_Type=("連續"); 442 } 443 int sumLeaf = 0; 444 int sumWrong = 0; 445 List<int>[] rowIndex = info.temp; 446 List<String> features = info.features; 447 for (int j = 0; j < rowIndex.Length; j++) 448 { 449 if (rowIndex[j].Count == 0) 450 { 451 continue; 452 } 453 if (info.type == 0) 454 features.Add(""+(j+1)); 455 Node node1 = new Node(); 456 //node1.setNum(info.getTemp()[j]); 457 node1.setClassCount(info.class_Count[j]); 458 //node1.setUsed(used); 459 node1.deep=(deep + 1); 460 node1.rowCount = info.temp[j].Count; 461 node1 = findBestSplit(node1, info.splitIndex,info.temp[j], used); 462 childNode.Add(node1); 463 } 464 node.features=(features); 465 node.childNodes=(childNode); 466 467 #endregion 468 return node; 469 } 470 catch (Exception e) 471 { 472 Console.WriteLine(e.StackTrace); 473 return node; 474 } 475 } 476 /// <summary> 477 /// 計算熵 478 /// </summary> 479 /// <param name="counts"></param> 480 /// <param name="countAll"></param> 481 /// <returns></returns> 482 public static double CalEntropy(double[] counts, int countAll) 483 { 484 try 485 { 486 double allShang = 0; 487 for (int i = 0; i < counts.Length; i++) 488 { 489 if (counts[i] == 0) 490 { 491 continue; 492 } 493 double rate = counts[i] / countAll; 494 allShang = allShang + rate * Math.Log(rate, 2); 495 } 496 return -allShang; 497 } 498 catch (Exception e) 499 { 500 return 0; 501 } 502 } 503 #endregion
(注:上述代碼只是ID3的核心代碼,數據預處理的代碼並無給出,只要將預處理後的數據輸入到主方法findBestSplit中,就能夠獲得最終的結果)
總結
ID3是基本的決策樹構建算法,做爲決策樹經典的構建算法,其具備結構簡單、清晰易懂的特色。雖然ID3比較靈活方便,可是有如下幾個缺點:
(1)採用信息增益進行分裂,分裂的精確度可能沒有采用信息增益率進行分裂高
(2)不能處理連續型數據,只能經過離散化將連續性數據轉化爲離散型數據
(3)不能處理缺省值
(4)沒有對決策樹進行剪枝處理,極可能會出現過擬合的問題