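As described in the previous article, ID3 has three main shortcomings. First, it splits on information gain, which is biased toward attributes with many values and is less reliable than the gain ratio. Second, it cannot handle continuous data directly; such data must first be discretized. Third, it does not prune, so the tree can become overly complex and overfit.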
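C4.5 improves on ID3 in each of these three respects:
a) C4.5 uses the gain ratio as the criterion for splitting a node;
b) it can handle continuous data;
c) C4.5 prunes the fully grown tree, which reduces overfitting to some extent.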
1. Using the gain ratio as the splitting criterion
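The gain ratio is computed as:

$$\mathrm{GainRatio}(D, A) = \frac{\mathrm{Gain}(D, A)}{\mathrm{SplitInfo}(D, A)}$$

where $\mathrm{Gain}(D, A)$ is the information gain of splitting node $D$ on attribute $A$, and $\mathrm{SplitInfo}(D, A)$ is the information carried by the sizes of the child nodes, computed as:

$$\mathrm{SplitInfo}(D, A) = -\sum_{i=1}^{m} \frac{N_i}{N} \log_2 \frac{N_i}{N}$$

where $m$ is the number of child nodes, $N_i$ is the number of records in the $i$-th child, and $N$ is the number of records in the parent node. Put simply, $\mathrm{SplitInfo}$ is the entropy of the split itself, computed from the child-node sizes rather than from the class labels.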
The larger the gain ratio, the better the split.
The following example shows how C4.5 uses the gain ratio to choose the splitting attribute:
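Table 1 Raw data

| Weather | Temperature | Humidity | Day | Go shopping |
|---|---|---|---|---|
| Sunny | 25 | 50 | Weekday | No |
| Sunny | 21 | 48 | Weekday | Yes |
| Sunny | 18 | 70 | Weekend | Yes |
| Sunny | 28 | 41 | Weekend | Yes |
| Overcast | 8 | 65 | Weekday | Yes |
| Overcast | 18 | 43 | Weekday | No |
| Overcast | 24 | 56 | Weekend | Yes |
| Overcast | 18 | 76 | Weekend | No |
| Rainy | 31 | 61 | Weekend | No |
| Rainy | 6 | 43 | Weekend | Yes |
| Rainy | 15 | 55 | Weekday | No |
| Rainy | 4 | 58 | Weekday | No |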
Take the weather attribute as an example:
it has three values, sunny, overcast and rainy, so the node is split into three children.
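Applying the formulas above, with $H(D) = -\sum_k p_k \log_2 p_k$ the entropy of the class labels and $\mathrm{Gain}(D, A) = H(D) - \sum_i \frac{N_i}{N} H(D_i)$, the gain ratio works out as follows. The 12 records contain 6 "Yes" and 6 "No", so $H(D) = 1$. Sunny has 3 Yes / 1 No, Overcast 2 / 2, and Rainy 1 / 3:

$$H(D_{\text{sunny}}) = H(D_{\text{rainy}}) = -\tfrac{3}{4}\log_2\tfrac{3}{4} - \tfrac{1}{4}\log_2\tfrac{1}{4} \approx 0.811,\qquad H(D_{\text{overcast}}) = 1$$

$$\mathrm{Gain} = 1 - \tfrac{4}{12}\,(0.811 + 1 + 0.811) \approx 0.126,\qquad \mathrm{SplitInfo} = -3 \cdot \tfrac{4}{12}\log_2\tfrac{4}{12} = \log_2 3 \approx 1.585$$

$$\mathrm{GainRatio} \approx \frac{0.126}{1.585} \approx 0.079$$

So splitting on the weather attribute gives a gain ratio of about 0.079.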
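The formulas can be checked on the weather attribute with a short C# sketch. This is not the article's program (which works on the encoded array described later); the class GainRatioDemo and its methods are hypothetical names, and an attribute is described by a count matrix counts[i][k], the number of records in child i that have class k.

```csharp
using System;
using System.Linq;

// Hypothetical helper for the gain-ratio formulas; counts[i][k] = records in child i with class k.
public static class GainRatioDemo
{
    // H = -sum_k p_k * log2(p_k); classes with zero records are skipped.
    public static double Entropy(double[] classCounts)
    {
        double total = classCounts.Sum();
        double h = 0;
        foreach (double c in classCounts)
        {
            if (c == 0) continue;
            double p = c / total;
            h -= p * Math.Log(p, 2);
        }
        return h;
    }

    // Gain(D, A) = H(D) - sum_i (N_i / N) * H(D_i)
    public static double InfoGain(double[][] counts)
    {
        int classes = counts[0].Length;
        double[] parent = new double[classes];
        foreach (double[] child in counts)
            for (int k = 0; k < classes; k++)
                parent[k] += child[k];

        double n = parent.Sum();
        double conditional = 0;
        foreach (double[] child in counts)
        {
            double ni = child.Sum();
            if (ni > 0) conditional += ni / n * Entropy(child);
        }
        return Entropy(parent) - conditional;
    }

    // SplitInfo(D, A) = -sum_i (N_i / N) * log2(N_i / N): the entropy of the child sizes.
    public static double SplitInfo(double[][] counts) =>
        Entropy(counts.Select(child => child.Sum()).ToArray());

    // GainRatio = Gain / SplitInfo; 0 when all records fall into a single child.
    public static double GainRatio(double[][] counts)
    {
        double splitInfo = SplitInfo(counts);
        return splitInfo > 0 ? InfoGain(counts) / splitInfo : 0.0;
    }
}
```

For the weather attribute, `GainRatioDemo.GainRatio(new[] { new double[] { 3, 1 }, new double[] { 2, 2 }, new double[] { 1, 3 } })` returns about 0.079 (gain about 0.126, split information log2 3). The same call on the Day column (Weekday: 2 Yes / 4 No, Weekend: 4 Yes / 2 No) gives about 0.082.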
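2. Handling continuous attributes
C4.5 handles discrete attributes the same way as ID3; what it adds is the handling of continuous attributes. The records are first sorted by the continuous attribute, and then a single cut divides them into two parts.
How is the cut point chosen? Simply compute the information gain of every possible cut point and choose the one that gives the best split. Take temperature as an example: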
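As the figure above shows, N records give N-1 possible cut points in principle. Choosing the best one means computing the information gain of every cut, which is expensive. Can this be simplified? Yes: some cut points can obviously be ruled out. For example, records 2 and 3 on the right-hand side of the figure both have the class label (go shopping) "Yes". Cutting between them would separate two records of the same class, which can never be better than keeping them together. So every cut point whose neighboring records share the same class label can be discarded, which reduces the amount of computation, as shown in the figure below:
As the figure shows, this reduces the number of candidate cut points from 11 to 6.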
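Once the candidate cut points are fixed, the next step is to choose the best one. Note that, within a continuous attribute, the cut points are compared by information gain. Comparing them by gain ratio would favor the cut point that makes the two resulting nodes differ most in size; to avoid this, the cut point is chosen by information gain. After the best cut point is found, its gain ratio is computed and compared with the gain ratios of the other attributes to determine the best splitting attribute.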
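A minimal C# sketch of this procedure for one continuous attribute, under a few assumptions: class labels are 0-based integers, records with equal values are grouped so that no cut falls between two equal values, and a boundary is skipped only when the records on both sides of it all belong to the same single class. Because three records in Table 1 have a temperature of 18, the sketch sees 9 possible cuts rather than 11 and keeps 7 of them, so its counts differ slightly from the figure. The names ContinuousSplitDemo, BestCut and CutResult are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch: choose a binary cut for one continuous attribute.
// Records with value <= Threshold go left. Labels are 0-based class indices.
public static class ContinuousSplitDemo
{
    public sealed class CutResult
    {
        public double Threshold;
        public double Gain = -1;   // stays -1 if there is no valid cut
        public double GainRatio;
        public int Candidates;     // number of cut points actually evaluated
    }

    static double Entropy(int[] counts, int total)
    {
        double h = 0;
        foreach (int c in counts)
        {
            if (c == 0) continue;
            double p = (double)c / total;
            h -= p * Math.Log(p, 2);
        }
        return h;
    }

    // Returns the class index if every record in the group has the same class, otherwise -1.
    static int PureClass(int[] counts)
    {
        int cls = -1;
        for (int k = 0; k < counts.Length; k++)
        {
            if (counts[k] == 0) continue;
            if (cls != -1) return -1;
            cls = k;
        }
        return cls;
    }

    public static CutResult BestCut(double[] values, int[] labels, int classCount)
    {
        int n = values.Length;
        int[] total = new int[classCount];
        foreach (int y in labels) total[y]++;
        double parentEntropy = Entropy(total, n);

        // Sort by value and group equal values: a cut can only fall between distinct values.
        var groupValues = new List<double>();
        var groupCounts = new List<int[]>();
        foreach (int i in Enumerable.Range(0, n).OrderBy(r => values[r]))
        {
            if (groupValues.Count == 0 || groupValues[groupValues.Count - 1] != values[i])
            {
                groupValues.Add(values[i]);
                groupCounts.Add(new int[classCount]);
            }
            groupCounts[groupCounts.Count - 1][labels[i]]++;
        }

        var best = new CutResult();
        int[] left = new int[classCount];
        int nLeft = 0;
        for (int g = 0; g + 1 < groupValues.Count; g++)
        {
            // Move group g to the left side of the cut.
            for (int k = 0; k < classCount; k++)
            {
                left[k] += groupCounts[g][k];
                nLeft += groupCounts[g][k];
            }
            // Skip the cut if both neighbouring groups are pure and of the same class.
            int pure = PureClass(groupCounts[g]);
            if (pure != -1 && pure == PureClass(groupCounts[g + 1])) continue;

            best.Candidates++;
            int nRight = n - nLeft;
            int[] right = new int[classCount];
            for (int k = 0; k < classCount; k++) right[k] = total[k] - left[k];

            // Information gain picks the cut; strict '>' keeps the first of equally good cuts.
            double gain = parentEntropy
                - (double)nLeft / n * Entropy(left, nLeft)
                - (double)nRight / n * Entropy(right, nRight);
            if (gain > best.Gain)
            {
                double pl = (double)nLeft / n, pr = (double)nRight / n;
                double splitInfo = -(pl * Math.Log(pl, 2) + pr * Math.Log(pr, 2));
                best.Gain = gain;
                best.Threshold = (groupValues[g] + groupValues[g + 1]) / 2;
                best.GainRatio = gain / splitInfo;   // used to compete with the other attributes
            }
        }
        return best;
    }
}
```

On the temperature column of Table 1 (1 = Yes, 0 = No), the highest gain, about 0.089, is reached by the cut at 5 and equally by the cut at 29.5; the sketch keeps the first one, whose gain ratio is about 0.215.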
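3. Pruning
As mentioned earlier in this series on decision trees, pruning works on a fully grown tree: subtrees that classify poorly are cut back, which reduces the complexity of the tree and the effect of overfitting.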
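C4.5 uses pessimistic error pruning (PEP). PEP prunes a subtree when pruning does not hurt the accuracy of the tree. What counts as "not hurting"? If the error after pruning is smaller than the upper bound of the error before pruning, the pruned tree is considered at least as good, and the subtree is pruned.
A subtree is pruned when the following condition holds:

$$E_{\text{leaf}} < E_{\text{subtree}} + SE(E_{\text{subtree}})$$

where:
$E_{\text{subtree}}$ is the (corrected) error of the subtree;
$E_{\text{leaf}}$ is the (corrected) error of the node when the subtree is replaced by a single leaf.
Assume the number of misclassified records in the subtree follows a binomial distribution. By the properties of the binomial distribution (mean $Np$, variance $Np(1-p)$):

$$E_{\text{subtree}} = \sum_{i=1}^{L} e_i + 0.5L,\qquad SE(E_{\text{subtree}}) = \sqrt{E_{\text{subtree}}\left(1 - \frac{E_{\text{subtree}}}{N}\right)}$$

where $L$ is the number of leaves in the subtree, $e_i$ is the number of misclassified records at leaf $i$, and $N$ is the number of records in the subtree. Likewise, the error of the leaf is $E_{\text{leaf}} = e + 0.5$, where $e$ is the number of records misclassified when the node predicts its majority class.
The 0.5 in these formulas is a correction factor. Splitting a node always fits the training data at least as well as the parent does, so in theory the parent's error is never smaller than the total error of its children, and pruning would never happen; a correction is therefore needed. Every leaf is charged an extra 0.5, so because the children carry this correction, the corrected total error of the subtree is no longer guaranteed to be lower than that of the parent.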
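Worked example: (figure not available) in this example the pruning condition above holds, so the subtree should be pruned.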
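The numbers of the worked example are not available, so here is a C# sketch of the pruning test with made-up counts. PepDemo.ShouldPrune is a hypothetical helper that applies the inequality above to four counts: the errors when the node is a leaf, the total errors of the subtree's leaves, the number of leaves, and the number of records N.

```csharp
using System;

// Hypothetical helper applying the PEP condition E_leaf < E_subtree + SE(E_subtree).
public static class PepDemo
{
    // errorsAsLeaf:  training errors if the node predicts its majority class (e)
    // subtreeErrors: total training errors over the subtree's leaves (sum of e_i)
    // leafCount:     number of leaves in the subtree (L)
    // n:             number of records reaching the node (N)
    public static bool ShouldPrune(int errorsAsLeaf, int subtreeErrors, int leafCount, int n)
    {
        double eLeaf = errorsAsLeaf + 0.5;                      // E_leaf = e + 0.5
        double eSubtree = subtreeErrors + 0.5 * leafCount;      // E_subtree = sum e_i + 0.5 L
        double se = Math.Sqrt(eSubtree * (1 - eSubtree / n));   // binomial standard error
        return eLeaf < eSubtree + se;
    }
}
```

With N = 20, a subtree of 3 leaves making 2 errors is kept against a leaf making 5 errors (5.5 is not below 3.5 + 1.70), while a subtree of 6 leaves with no training errors is pruned in favor of a leaf making 3 errors (3.5 < 3 + 1.60): the 0.5 charged per leaf penalizes large subtrees even when they fit the training data perfectly.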
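Program design and source code (C# version)
Program design
(1) Data format
The raw data are first encoded as numbers and stored as a two-dimensional array: each row is one record, the first n-1 columns are the attributes, and the last column is the class label.
Data like that in Table 1 can be converted into the form of Table 2 (the example in Table 2 keeps the temperature and humidity columns but uses a season attribute, with tomorrow's weather as the class label):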
Table 2 Encoded data

| Weather | Temperature | Humidity | Season | Tomorrow's weather |
|---|---|---|---|---|
| 1 | 25 | 50 | 1 | 1 |
| 2 | 21 | 48 | 1 | 2 |
| 2 | 18 | 70 | 1 | 3 |
| 1 | 28 | 41 | 2 | 1 |
| 3 | 8 | 65 | 3 | 2 |
| 1 | 18 | 43 | 2 | 1 |
| 2 | 24 | 56 | 4 | 1 |
| 3 | 18 | 76 | 4 | 2 |
| 3 | 31 | 61 | 2 | 1 |
| 2 | 6 | 43 | 3 | 3 |
| 1 | 15 | 55 | 4 | 2 |
| 3 | 4 | 58 | 3 | 3 |
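For the "Weather" attribute, the codes {1, 2, 3} stand for {sunny, overcast, rainy}; for "Season", {1, 2, 3, 4} stand for {spring, summer, winter, autumn}; for the class label "Tomorrow's weather", {1, 2, 3} stand for {sunny, overcast, rainy}.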
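The code is as follows:

```csharp
static double[][] allData;             // training data: one row per record, class label in the last column
static List<String>[] featureValues;   // discrete values of each attribute
```

featureValues is an array of lists: its length is the number of attributes, and each element is the list of discrete values of that attribute.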
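The article does not show how allData and featureValues are filled. The following C# sketch is one possible way, assuming the raw table is available as string rows and the caller states which columns are continuous. TableEncoder and Encode are hypothetical names; discrete values are numbered 1, 2, 3, ... in order of first appearance, which may differ from the fixed codes used in Table 2, and featureValues stays empty for continuous columns.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;

// Hypothetical helper: turns string rows into the numeric layout used by the program
// (attributes first, class label last).
public static class TableEncoder
{
    public static double[][] Encode(string[][] rows, bool[] isContinuous, out List<string>[] featureValues)
    {
        int cols = isContinuous.Length;
        featureValues = new List<string>[cols];
        for (int c = 0; c < cols; c++) featureValues[c] = new List<string>();

        var data = new double[rows.Length][];
        for (int r = 0; r < rows.Length; r++)
        {
            data[r] = new double[cols];
            for (int c = 0; c < cols; c++)
            {
                string cell = rows[r][c];
                if (isContinuous[c])
                {
                    // Continuous columns are stored as parsed numbers.
                    data[r][c] = double.Parse(cell, CultureInfo.InvariantCulture);
                    continue;
                }
                // Discrete columns: 1-based code in order of first appearance.
                int index = featureValues[c].IndexOf(cell);
                if (index < 0)
                {
                    featureValues[c].Add(cell);
                    index = featureValues[c].Count - 1;
                }
                data[r][c] = index + 1;
            }
        }
        return data;
    }
}
```

Encoding the first rows of Table 1 this way gives Sunny = 1, Overcast = 2, Rainy = 3, Weekday = 1, Weekend = 2, and No = 1, Yes = 2 for the class label.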
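(2) Two classes: the node class and the split-information class
a) Node class
This class represents a tree node. Its properties include the splitting attribute chosen at the node, the node's output class, its child nodes, its depth, and so on. Compared with the ID3 version, it adds two properties, leafWrong and leafNode_Count, which hold the total number of misclassified records over the leaves of the node's subtree and the number of those leaves; they exist to make pruning easier.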
```csharp
using System;
using System.Collections.Generic;

class Node
{
    /// <summary>
    /// Attribute values leading to each child (continuous split: the threshold; leaf: the class label)
    /// </summary>
    public List<String> features { get; set; }
    /// <summary>
    /// Node type: "離散" (discrete split), "連續" (continuous split) or "result" (leaf)
    /// </summary>
    public String feature_Type { get; set; }
    /// <summary>
    /// Column index of the splitting attribute
    /// </summary>
    public String SplitFeature { get; set; }
    /// <summary>
    /// Number of records of each class in this node
    /// </summary>
    public double[] ClassCount { get; set; }
    /// <summary>
    /// Number of records in this node
    /// </summary>
    public int rowCount { get; set; }
    /// <summary>
    /// Child nodes
    /// </summary>
    public List<Node> childNodes { get; set; }
    /// <summary>
    /// Parent node
    /// </summary>
    public Node Parent { get; set; }
    /// <summary>
    /// Majority class of this node
    /// </summary>
    public String finalResult { get; set; }
    /// <summary>
    /// Depth of the node in the tree
    /// </summary>
    public int deep { get; set; }
    /// <summary>
    /// 0-based index of the majority class
    /// </summary>
    public int result { get; set; }
    /// <summary>
    /// Total number of misclassified records over the leaves of this subtree
    /// </summary>
    public int leafWrong { get; set; }
    /// <summary>
    /// Number of leaves in this subtree
    /// </summary>
    public int leafNode_Count { get; set; }

    // Training errors if this node were a leaf predicting its majority class
    public double getErrorCount()
    {
        return rowCount - ClassCount[result];
    }

    // Stores the class counts and records the index of the majority class
    public void setClassCount(double[] count)
    {
        this.ClassCount = count;
        double max = ClassCount[0];
        int result = 0;
        for (int i = 1; i < ClassCount.Length; i++)
        {
            if (max < ClassCount[i])
            {
                max = ClassCount[i];
                result = i;
            }
        }
        this.result = result;
    }
}
```
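b) The split-information class SplitInfo, which stores how a node is split: the row indices of each child node, the number of records of each class in each child, the splitting attribute, the attribute type, and so on.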
```csharp
using System;
using System.Collections.Generic;

class SplitInfo
{
    /// <summary>
    /// Column index of the splitting attribute (-1 = no split found)
    /// </summary>
    public int splitIndex { get; set; } = -1;
    /// <summary>
    /// Attribute type (0: discrete, 1: continuous)
    /// </summary>
    public int type { get; set; }
    /// <summary>
    /// Values of the splitting attribute
    /// </summary>
    public List<String> features { get; set; }
    /// <summary>
    /// Row indices of the records in each child node
    /// </summary>
    public List<int>[] temp { get; set; }
    /// <summary>
    /// Number of records of each class in each child node
    /// </summary>
    public double[][] class_Count { get; set; }
}
```
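(3) The main method
findBestSplit(Node node, List<int> nums, int[] isUsed) splits a node, where:
node is the node to be split;
nums is the list of row indices of the node's records;
isUsed records which attributes have already been used on the path from the root down to this node.
findBestSplit consists of the following parts:
1) Deciding when to stop splitting
A node stops splitting when it reaches the maximum depth maxDeep, when its entropy is 0 (all its records belong to one class), when all attributes have been used, or when it holds fewer than Limit_Node records. The source code is: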
```csharp
public static bool ifEnd(Node node, double entropy, int[] isUsed)
{
    try
    {
        double[] count = node.ClassCount;
        int rowCount = node.rowCount;

        // Turns the node into a leaf predicting its majority class and records its training errors
        bool MakeLeaf()
        {
            int maxResult = node.result + 1;
            node.feature_Type = "result";
            node.features = new List<String> { maxResult + "" };
            node.leafWrong = rowCount - Convert.ToInt32(count[maxResult - 1]);
            node.leafNode_Count = 1;
            return true;
        }

        #region Maximum depth reached
        if (node.deep >= maxDeep)
        {
            return MakeLeaf();
        }
        #endregion
        #region Purity (disabled: overlaps with the checks below and still needs revising)
        //maxResult = 1;
        //for (int i = 1; i < count.Length; i++)
        //{
        //    if (count[i] / rowCount >= 0.95)
        //    {
        //        node.feature_Type = "result";
        //        node.features = new List<String> { "" + (i + 1) };
        //        node.leafWrong = rowCount - Convert.ToInt32(count[maxResult - 1]);
        //        node.leafNode_Count = 1;
        //        return true;
        //    }
        //}
        #endregion
        #region Entropy is 0 (the node is pure)
        if (entropy == 0)
        {
            return MakeLeaf();
        }
        #endregion
        #region All attributes have been used
        bool flag = true;
        for (int i = 0; i < isUsed.Length - 1; i++)
        {
            if (isUsed[i] == 0)
            {
                flag = false;
                break;
            }
        }
        if (flag)
        {
            return MakeLeaf();
        }
        #endregion
        #region Fewer than Limit_Node (e.g. 100) records
        if (rowCount < Limit_Node)
        {
            return MakeLeaf();
        }
        #endregion
        return false;
    }
    catch (Exception)
    {
        return false;
    }
}
```
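2) Finding the best splitting attribute
Finding the best splitting attribute means computing the gain ratio of every candidate attribute with the formula given above. The entropy is computed as follows; note that CalEntropy returns the sum of p·log2(p), i.e. the negative of the entropy, and findBestSplit negates it where needed: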
```csharp
// Returns sum_k p_k * log2(p_k), i.e. the NEGATIVE of the entropy of the class counts
public static double CalEntropy(double[] counts, int countAll)
{
    try
    {
        double allShang = 0;
        for (int i = 0; i < counts.Length; i++)
        {
            if (counts[i] == 0)
            {
                continue;   // empty classes contribute nothing
            }
            double rate = counts[i] / countAll;
            allShang = allShang + rate * Math.Log(rate, 2);
        }
        return allShang;
    }
    catch (Exception)
    {
        return 0;
    }
}
```
3) Splitting and recursing on the child nodes
This is simply a recursive process: findBestSplit is called on every child node to split it in turn.
findBestSplit source code:
```csharp
// Relies on static fields defined elsewhere in the program: allData, type (0 = discrete,
// 1 = continuous), allNum (number of values of each discrete attribute), classCount,
// lieshu (number of columns, class label last), maxDeep and Limit_Node.
public static Node findBestSplit(Node node, List<int> nums, int[] isUsed)
{
    try
    {
        // Decide whether to keep splitting; totalShang is the negative entropy of the node
        double totalShang = CalEntropy(node.ClassCount, node.rowCount);
        if (ifEnd(node, totalShang, isUsed))
        {
            return node;
        }
        #region Variable declarations
        SplitInfo info = new SplitInfo();
        info.splitIndex = -1;          // -1 means no split has been found yet
        int RowCount = nums.Count;     // number of records in the node
        double jubuMax = 0;            // best gain ratio found so far
        #endregion
        for (int i = 0; i < isUsed.Length - 1; i++)
        {
            if (isUsed[i] == 1)
            {
                continue;
            }
            #region Discrete attribute
            if (type[i] == 0)
            {
                double[][] allCount = new double[allNum[i]][];   // class counts of each child
                for (int j = 0; j < allCount.Length; j++)
                {
                    allCount[j] = new double[classCount];
                }
                int[] countAllFeature = new int[allNum[i]];       // number of records in each child
                List<int>[] temp = new List<int>[allNum[i]];      // row indices of each child
                for (int j = 0; j < temp.Length; j++)
                {
                    temp[j] = new List<int>();
                }
                for (int j = 0; j < nums.Count; j++)
                {
                    int index = Convert.ToInt32(allData[nums[j]][i]);
                    temp[index - 1].Add(nums[j]);
                    countAllFeature[index - 1]++;
                    allCount[index - 1][Convert.ToInt32(allData[nums[j]][lieshu - 1]) - 1]++;
                }
                double allShang = 0;   // negative conditional entropy
                double chushu = 0;     // negative split information
                for (int j = 0; j < allCount.Length; j++)
                {
                    allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount;
                    if (countAllFeature[j] > 0)
                    {
                        double rate = countAllFeature[j] / Convert.ToDouble(RowCount);
                        chushu = chushu + rate * Math.Log(rate, 2);
                    }
                }
                double gain = -totalShang + allShang;                  // information gain
                double gainRatio = chushu < 0 ? gain / -chushu : 0;   // gain / split information
                if (gainRatio > jubuMax)
                {
                    info.features = new List<string>();
                    info.type = 0;
                    info.temp = temp;
                    info.splitIndex = i;
                    info.class_Count = allCount;
                    jubuMax = gainRatio;
                }
            }
            #endregion
            #region Continuous attribute
            else
            {
                double[] leftCount = new double[classCount];          // class counts of the left child at the best cut
                double[] rightCount = new double[classCount];         // class counts of the right child at the best cut
                double[] count1 = new double[classCount];             // running class counts of subset 1 (left)
                double[] count2 = (double[])node.ClassCount.Clone();  // running class counts of subset 2 (right)
                int all1 = 0;                  // size of subset 1
                int all2 = nums.Count;         // size of subset 2
                double lastValue = 0;          // class of the previous record
                double currentValue = 0;       // class of the current record
                double lastPoint = 0;          // attribute value of the previous record
                double currentPoint = 0;       // attribute value of the current record
                bool classChanged = false;     // has the class changed since the last candidate cut?
                int splitPoint = 0;
                double splitValue = 0;
                double[] values = new double[nums.Count];
                for (int j = 0; j < values.Length; j++)
                {
                    values[j] = allData[nums[j]][i];
                }
                QSort(values, nums, 0, nums.Count - 1);   // sort the rows by this attribute
                double lianxuMax = 0;          // best information gain over the cut points
                for (int j = 0; j < nums.Count; j++)
                {
                    currentValue = allData[nums[j]][lieshu - 1];
                    currentPoint = allData[nums[j]][i];
                    if (j > 0 && currentValue != lastValue)
                    {
                        classChanged = true;
                    }
                    // Candidate cut between rows j-1 and j: the value changes here and the class has
                    // changed since the last candidate (cuts inside a run of one class are skipped)
                    if (j > 0 && currentPoint != lastPoint && classChanged)
                    {
                        classChanged = false;
                        double shang1 = CalEntropy(count1, all1);
                        double shang2 = CalEntropy(count2, all2);
                        double allShang = shang1 * all1 / (all1 + all2) + shang2 * all2 / (all1 + all2);
                        allShang = -totalShang + allShang;      // information gain of this cut
                        if (lianxuMax < allShang)
                        {
                            lianxuMax = allShang;
                            Array.Copy(count1, leftCount, count1.Length);
                            Array.Copy(count2, rightCount, count1.Length);
                            splitPoint = j;                     // rows [0, j) go to the left child
                            splitValue = (currentPoint + lastPoint) / 2;
                        }
                    }
                    all1++;
                    count1[Convert.ToInt32(currentValue) - 1]++;
                    count2[Convert.ToInt32(currentValue) - 1]--;
                    all2--;
                    lastValue = currentValue;
                    lastPoint = currentPoint;
                }
                // Split information of the chosen cut, summed over all classes
                double leftTotal = 0, rightTotal = 0;
                for (int k = 0; k < classCount; k++)
                {
                    leftTotal += leftCount[k];
                    rightTotal += rightCount[k];
                }
                double chushu = 0;             // negative split information
                if (leftTotal > 0 && rightTotal > 0)
                {
                    double rate1 = leftTotal / (leftTotal + rightTotal);
                    double rate2 = rightTotal / (leftTotal + rightTotal);
                    chushu = rate1 * Math.Log(rate1, 2) + rate2 * Math.Log(rate2, 2);
                }
                // The cut point is chosen by gain; the attribute competes by gain ratio
                double gainRatio = chushu < 0 ? lianxuMax / -chushu : 0;
                if (gainRatio > jubuMax)
                {
                    info.splitIndex = i;
                    info.features = new List<String> { splitValue + "" };
                    info.type = 1;
                    jubuMax = gainRatio;
                    info.temp = new List<int>[] { nums.GetRange(0, splitPoint), nums.GetRange(splitPoint, nums.Count - splitPoint) };
                    info.class_Count = new double[][] { leftCount, rightCount };
                }
            }
            #endregion
        }
        #region No useful split found: make the node a leaf
        if (info.splitIndex == -1)
        {
            double[] finalCount = node.ClassCount;
            double max = finalCount[0];
            int result = 1;
            for (int i = 1; i < finalCount.Length; i++)
            {
                if (finalCount[i] > max)
                {
                    max = finalCount[i];
                    result = i + 1;
                }
            }
            node.feature_Type = "result";
            node.features = new List<String> { "" + result };
            node.leafWrong = node.rowCount - Convert.ToInt32(max);   // needed later by prune
            node.leafNode_Count = 1;
            return node;
        }
        #endregion
        #region Split
        int deep = node.deep;
        node.SplitFeature = "" + info.splitIndex;
        List<Node> childNode = new List<Node>();
        int[] used = (int[])isUsed.Clone();
        if (info.type == 0)
        {
            used[info.splitIndex] = 1;   // a discrete attribute is used once per path
            node.feature_Type = "離散";
        }
        else
        {
            used[info.splitIndex] = 0;   // a continuous attribute may be split again
            node.feature_Type = "連續";
        }
        int sumLeaf = 0;
        int sumWrong = 0;
        List<int>[] rowIndex = info.temp;
        List<String> features = info.features;
        for (int j = 0; j < rowIndex.Length; j++)
        {
            if (rowIndex[j].Count == 0)
            {
                continue;                // skip empty children
            }
            if (info.type == 0)
                features.Add("" + (j + 1));
            Node node1 = new Node();
            node1.setClassCount(info.class_Count[j]);
            node1.deep = deep + 1;
            node1.rowCount = rowIndex[j].Count;
            node1 = findBestSplit(node1, rowIndex[j], used);   // recurse on the child
            sumLeaf += node1.leafNode_Count;
            sumWrong += node1.leafWrong;
            childNode.Add(node1);
        }
        node.leafNode_Count = sumLeaf;
        node.leafWrong = sumWrong;
        node.features = features;
        node.childNodes = childNode;
        #endregion
        return node;
    }
    catch (Exception e)
    {
        Console.WriteLine(e.StackTrace);
        return node;
    }
}
```
(4) Pruning
Pessimistic error pruning (PEP):
```csharp
public static void prune(Node node)
{
    if (node.feature_Type == "result")
        return;
    double treeWrong = node.getErrorCount() + 0.5;                     // E_leaf: error if this node becomes a leaf
    double leafError = node.leafWrong + 0.5 * node.leafNode_Count;    // E_subtree: corrected error of the subtree
    double se = Math.Sqrt(leafError * (1 - leafError / node.rowCount)); // standard error of E_subtree
    double panbie = leafError + se - treeWrong;
    if (panbie > 0)
    {
        // Replace the subtree with a leaf predicting the majority class
        node.feature_Type = "result";
        node.childNodes = null;
        int result = node.result + 1;
        node.features = new List<String>() { "" + result };
    }
    else
    {
        foreach (Node child in node.childNodes)
        {
            prune(child);
        }
    }
}
```
Source code of the C4.5 core algorithm:
```csharp
#region C4.5 core algorithm
// This region also contains ifEnd, findBestSplit, CalEntropy and prune, as listed above.

/// <summary>
/// Classifies one record by walking from the given node down to a leaf
/// </summary>
/// <param name="node">root of the (sub)tree</param>
/// <param name="data">attribute values of the record, as strings, in column order</param>
/// <returns>the predicted class label</returns>
public static String findResult(Node node, String[] data)
{
    List<String> featrues = node.features;
    String nodeType = node.feature_Type;
    if (nodeType == "result")
    {
        return featrues[0];                       // leaf: the stored class label
    }
    int split = Convert.ToInt32(node.SplitFeature);
    List<Node> childNodes = node.childNodes;
    if (nodeType == "連續")
    {
        // Continuous split: value <= threshold goes left, otherwise right
        double value = Convert.ToDouble(featrues[0]);
        if (Convert.ToDouble(data[split]) <= value)
        {
            return findResult(childNodes[0], data);
        }
        return findResult(childNodes[1], data);
    }
    // Discrete split: follow the child whose attribute value matches
    for (int i = 0; i < featrues.Count; i++)
    {
        if (data[split] == featrues[i])
        {
            return findResult(childNodes[i], data);
        }
    }
    // Value not seen in training: predict the majority class of this node
    double[] resultCount = node.ClassCount;
    int maxInt = 0;
    for (int j = 1; j < resultCount.Length; j++)
    {
        if (resultCount[maxInt] < resultCount[j])
        {
            maxInt = j;
        }
    }
    return "" + (maxInt + 1);
}

#region Sorting
/// <summary>
/// Sorts values[low..high] in ascending order and applies the same reordering to arr,
/// so that arr[k] remains the row index whose attribute value is values[k]
/// </summary>
static void QSort(double[] values, List<int> arr, int low, int high)
{
    if (high <= low)
    {
        return;
    }
    int[] items = arr.ToArray();
    // Array.Sort(keys, items, index, length) sorts the keys and moves the items in step
    Array.Sort(values, items, low, high - low + 1);
    for (int k = low; k <= high; k++)
    {
        arr[k] = items[k];
    }
}
#endregion
#endregion
```
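Summary:
C4.5 is one of the most important classification-tree algorithms. Its idea is simple, yet it classifies accurately. C4.5 can be seen as an upgraded, strengthened version of ID3 that solves the problems ID3 left open. The key points to remember:
1. C4.5 chooses the splitting attribute by gain ratio, which fixes ID3's bias toward attributes with many values;
2. C4.5 can handle continuous data: it cuts a continuous attribute into two parts at a single cut point, and chooses that cut point by information gain;
3. C4.5 uses pessimistic pruning, which reduces overfitting to some extent.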