數據挖掘-決策樹ID3分類算法的C++實現

數據挖掘課上面老師介紹了下決策樹ID3算法,我抽空餘時間把這個算法用C++實現了一遍。 node

決策樹算法是很是經常使用的分類算法,是逼近離散目標函數的方法,學習獲得的函數以決策樹的形式表示。其基本思路是不斷選取產生信息增益最大的屬性來劃分樣例集和,構造決策樹。信息增益定義爲結點與其子結點的信息熵之差。信息熵是香農提出的,用於描述信息不純度(不穩定性),其計算公式是 ios

Pi爲子集合中不一樣性(而二元分類即正樣例和負樣例)的樣例的比例。這樣信息收益能夠定義爲樣本按照某屬性劃分時形成熵減小的指望,能夠區分訓練樣本中正負樣本的能力,其計算公司是 算法

我實現該算法針對的樣例集合以下 編程

該表記錄了在不一樣氣候條件下是否去打球的狀況,要求根據該表用程序輸出決策樹 數據結構

C++代碼以下,程序中有詳細註釋 函數

[cpp]  view plain copy
  1. #include <iostream>  
  2. #include <string>  
  3. #include <vector>  
  4. #include <map>  
  5. #include <algorithm>  
  6. #include <cmath>  
  7. using namespace std;  
  8. #define MAXLEN 6//輸入每行的數據個數  
  9.   
  10. //多叉樹的實現   
  11. //1 廣義表  
  12. //2 父指針表示法,適於常常找父結點的應用  
  13. //3 子女鏈表示法,適於常常找子結點的應用  
  14. //4 左長子,右兄弟表示法,實現比較麻煩  
  15. //5 每一個結點的全部孩子用vector保存  
  16. //教訓:數據結構的設計很重要,本算法採用5比較合適,同時  
  17. //注意維護剩餘樣例和剩餘屬性信息,建樹時橫向遍歷考循環屬性的值,  
  18. //縱向遍歷靠遞歸調用  
  19.   
  20. vector <vector <string> > state;//實例集  
  21. vector <string> item(MAXLEN);//對應一行實例集  
  22. vector <string> attribute_row;//保存首行即屬性行數據  
  23. string end("end");//輸入結束  
  24. string yes("yes");  
  25. string no("no");  
  26. string blank("");  
  27. map<string,vector < string > > map_attribute_values;//存儲屬性對應的全部的值  
  28. int tree_size = 0;  
  29. struct Node{//決策樹節點  
  30.     string attribute;//屬性值  
  31.     string arrived_value;//到達的屬性值  
  32.     vector<Node *> childs;//全部的孩子  
  33.     Node(){  
  34.         attribute = blank;  
  35.         arrived_value = blank;  
  36.     }  
  37. };  
  38. Node * root;  
  39.   
  40. //根據數據實例計算屬性與值組成的map  
  41. void ComputeMapFrom2DVector(){  
  42.     unsigned int i,j,k;  
  43.     bool exited = false;  
  44.     vector<string> values;  
  45.     for(i = 1; i < MAXLEN-1; i++){//按照列遍歷  
  46.         for (j = 1; j < state.size(); j++){  
  47.             for (k = 0; k < values.size(); k++){  
  48.                 if(!values[k].compare(state[j][i])) exited = true;  
  49.             }  
  50.             if(!exited){  
  51.                 values.push_back(state[j][i]);//注意Vector的插入都是從前面插入的,注意更新it,始終指向vector頭  
  52.             }  
  53.             exited = false;  
  54.         }  
  55.         map_attribute_values[state[0][i]] = values;  
  56.         values.erase(values.begin(), values.end());  
  57.     }     
  58. }  
  59.   
  60. //根據具體屬性和值來計算熵  
  61. double ComputeEntropy(vector <vector <string> > remain_state, string attribute, string value,bool ifparent){  
  62.     vector<int> count (2,0);  
  63.     unsigned int i,j;  
  64.     bool done_flag = false;//哨兵值  
  65.     for(j = 1; j < MAXLEN; j++){  
  66.         if(done_flag) break;  
  67.         if(!attribute_row[j].compare(attribute)){  
  68.             for(i = 1; i < remain_state.size(); i++){  
  69.                 if((!ifparent&&!remain_state[i][j].compare(value)) || ifparent){//ifparent記錄是否算父節點  
  70.                     if(!remain_state[i][MAXLEN - 1].compare(yes)){  
  71.                         count[0]++;  
  72.                     }  
  73.                     else count[1]++;  
  74.                 }  
  75.             }  
  76.             done_flag = true;  
  77.         }  
  78.     }  
  79.     if(count[0] == 0 || count[1] == 0 ) return 0;//所有是正實例或者負實例  
  80.     //具體計算熵 根據[+count[0],-count[1]],log2爲底經過換底公式換成天然數底數  
  81.     double sum = count[0] + count[1];  
  82.     double entropy = -count[0]/sum*log(count[0]/sum)/log(2.0) - count[1]/sum*log(count[1]/sum)/log(2.0);  
  83.     return entropy;  
  84. }  
  85.       
  86. //計算按照屬性attribute劃分當前剩餘實例的信息增益  
  87. double ComputeGain(vector <vector <string> > remain_state, string attribute){  
  88.     unsigned int j,k,m;  
  89.     //首先求不作劃分時的熵  
  90.     double parent_entropy = ComputeEntropy(remain_state, attribute, blank, true);  
  91.     double children_entropy = 0;  
  92.     //而後求作劃分後各個值的熵  
  93.     vector<string> values = map_attribute_values[attribute];  
  94.     vector<double> ratio;  
  95.     vector<int> count_values;  
  96.     int tempint;  
  97.     for(m = 0; m < values.size(); m++){  
  98.         tempint = 0;  
  99.         for(k = 1; k < MAXLEN - 1; k++){  
  100.             if(!attribute_row[k].compare(attribute)){  
  101.                 for(j = 1; j < remain_state.size(); j++){  
  102.                     if(!remain_state[j][k].compare(values[m])){  
  103.                         tempint++;  
  104.                     }  
  105.                 }  
  106.             }  
  107.         }  
  108.         count_values.push_back(tempint);  
  109.     }  
  110.       
  111.     for(j = 0; j < values.size(); j++){  
  112.         ratio.push_back((double)count_values[j] / (double)(remain_state.size()-1));  
  113.     }  
  114.     double temp_entropy;  
  115.     for(j = 0; j < values.size(); j++){  
  116.         temp_entropy = ComputeEntropy(remain_state, attribute, values[j], false);  
  117.         children_entropy += ratio[j] * temp_entropy;  
  118.     }  
  119.     return (parent_entropy - children_entropy);   
  120. }  
  121.   
  122. int FindAttriNumByName(string attri){  
  123.     for(int i = 0; i < MAXLEN; i++){  
  124.         if(!state[0][i].compare(attri)) return i;  
  125.     }  
  126.     cerr<<"can't find the numth of attribute"<<endl;   
  127.     return 0;  
  128. }  
  129.   
  130. //找出樣例中佔多數的正/負性  
  131. string MostCommonLabel(vector <vector <string> > remain_state){  
  132.     int p = 0, n = 0;  
  133.     for(unsigned i = 0; i < remain_state.size(); i++){  
  134.         if(!remain_state[i][MAXLEN-1].compare(yes)) p++;  
  135.         else n++;  
  136.     }  
  137.     if(p >= n) return yes;  
  138.     else return no;  
  139. }  
  140.   
  141. //判斷樣例是否正負性都爲label  
  142. bool AllTheSameLabel(vector <vector <string> > remain_state, string label){  
  143.     int count = 0;  
  144.     for(unsigned int i = 0; i < remain_state.size(); i++){  
  145.         if(!remain_state[i][MAXLEN-1].compare(label)) count++;  
  146.     }  
  147.     if(count == remain_state.size()-1) return true;  
  148.     else return false;  
  149. }  
  150.   
  151. //計算信息增益,DFS構建決策樹  
  152. //current_node爲當前的節點  
  153. //remain_state爲剩餘待分類的樣例  
  154. //remian_attribute爲剩餘尚未考慮的屬性  
  155. //返回根結點指針  
  156. Node * BulidDecisionTreeDFS(Node * p, vector <vector <string> > remain_state, vector <string> remain_attribute){  
  157.     //if(remain_state.size() > 0){  
  158.         //printv(remain_state);  
  159.     //}  
  160.     if (p == NULL)  
  161.         p = new Node();  
  162.     //先看搜索到樹葉的狀況  
  163.     if (AllTheSameLabel(remain_state, yes)){  
  164.         p->attribute = yes;  
  165.         return p;  
  166.     }  
  167.     if (AllTheSameLabel(remain_state, no)){  
  168.         p->attribute = no;  
  169.         return p;  
  170.     }  
  171.     if(remain_attribute.size() == 0){//全部的屬性均已經考慮完了,尚未分盡  
  172.         string label = MostCommonLabel(remain_state);  
  173.         p->attribute = label;  
  174.         return p;  
  175.     }  
  176.   
  177.     double max_gain = 0, temp_gain;  
  178.     vector <string>::iterator max_it = remain_attribute.begin();  
  179.     vector <string>::iterator it1;  
  180.     for(it1 = remain_attribute.begin(); it1 < remain_attribute.end(); it1++){  
  181.         temp_gain = ComputeGain(remain_state, (*it1));  
  182.         if(temp_gain > max_gain) {  
  183.             max_gain = temp_gain;  
  184.             max_it = it1;  
  185.         }  
  186.     }  
  187.     //下面根據max_it指向的屬性來劃分當前樣例,更新樣例集和屬性集  
  188.     vector <string> new_attribute;  
  189.     vector <vector <string> > new_state;  
  190.     for(vector <string>::iterator it2 = remain_attribute.begin(); it2 < remain_attribute.end(); it2++){  
  191.         if((*it2).compare(*max_it)) new_attribute.push_back(*it2);  
  192.     }  
  193.     //肯定了最佳劃分屬性,注意保存  
  194.     p->attribute = *max_it;  
  195.     vector <string> values = map_attribute_values[*max_it];  
  196.     int attribue_num = FindAttriNumByName(*max_it);  
  197.     new_state.push_back(attribute_row);  
  198.     for(vector <string>::iterator it3 = values.begin(); it3 < values.end(); it3++){  
  199.         for(unsigned int i = 1; i < remain_state.size(); i++){  
  200.             if(!remain_state[i][attribue_num].compare(*it3)){  
  201.                 new_state.push_back(remain_state[i]);  
  202.             }  
  203.         }  
  204.         Node * new_node = new Node();  
  205.         new_node->arrived_value = *it3;  
  206.         if(new_state.size() == 0){//表示當前沒有這個分支的樣例,當前的new_node爲葉子節點  
  207.             new_node->attribute = MostCommonLabel(remain_state);  
  208.         }  
  209.         else   
  210.             BulidDecisionTreeDFS(new_node, new_state, new_attribute);  
  211.         //遞歸函數返回時即回溯時須要1 將新結點加入父節點孩子容器 2清除new_state容器  
  212.         p->childs.push_back(new_node);  
  213.         new_state.erase(new_state.begin()+1,new_state.end());//注意先清空new_state中的前一個取值的樣例,準備遍歷下一個取值樣例  
  214.     }  
  215.     return p;  
  216. }  
  217.   
  218. void Input(){  
  219.     string s;  
  220.     while(cin>>s,s.compare(end) != 0){//-1爲輸入結束  
  221.         item[0] = s;  
  222.         for(int i = 1;i < MAXLEN; i++){  
  223.             cin>>item[i];  
  224.         }  
  225.         state.push_back(item);//注意首行信息也輸入進去,即屬性  
  226.     }  
  227.     for(int j = 0; j < MAXLEN; j++){  
  228.         attribute_row.push_back(state[0][j]);  
  229.     }  
  230. }  
  231.   
  232. void PrintTree(Node *p, int depth){  
  233.     for (int i = 0; i < depth; i++) cout << '\t';//按照樹的深度先輸出tab  
  234.     if(!p->arrived_value.empty()){  
  235.         cout<<p->arrived_value<<endl;  
  236.         for (int i = 0; i < depth+1; i++) cout << '\t';//按照樹的深度先輸出tab  
  237.     }  
  238.     cout<<p->attribute<<endl;  
  239.     for (vector<Node*>::iterator it = p->childs.begin(); it != p->childs.end(); it++){  
  240.         PrintTree(*it, depth + 1);  
  241.     }  
  242. }  
  243.   
  244. void FreeTree(Node *p){  
  245.     if (p == NULL)  
  246.         return;  
  247.     for (vector<Node*>::iterator it = p->childs.begin(); it != p->childs.end(); it++){  
  248.         FreeTree(*it);  
  249.     }  
  250.     delete p;  
  251.     tree_size++;  
  252. }  
  253.   
  254. int main(){  
  255.     Input();  
  256.     vector <string> remain_attribute;  
  257.       
  258.     string outlook("Outlook");  
  259.     string Temperature("Temperature");  
  260.     string Humidity("Humidity");  
  261.     string Wind("Wind");  
  262.     remain_attribute.push_back(outlook);  
  263.     remain_attribute.push_back(Temperature);  
  264.     remain_attribute.push_back(Humidity);  
  265.     remain_attribute.push_back(Wind);  
  266.     vector <vector <string> > remain_state;  
  267.     for(unsigned int i = 0; i < state.size(); i++){  
  268.         remain_state.push_back(state[i]);   
  269.     }  
  270.     ComputeMapFrom2DVector();  
  271.     root = BulidDecisionTreeDFS(root,remain_state,remain_attribute);  
  272.     cout<<"the decision tree is :"<<endl;  
  273.     PrintTree(root,0);  
  274.     FreeTree(root);  
  275.     cout<<endl;  
  276.     cout<<"tree_size:"<<tree_size<<endl;  
  277.     return 0;  
  278. }  
輸入的訓練數據以下
[plain]  view plain copy
  1. Day Outlook Temperature Humidity Wind PlayTennis  
  2. 1 Sunny Hot High Weak no  
  3. 2 Sunny Hot High Strong no  
  4. 3 Overcast Hot High Weak yes  
  5. 4 Rainy Mild High Weak yes  
  6. 5 Rainy Cool Normal Weak yes  
  7. 6 Rainy Cool Normal Strong no  
  8. 7 Overcast Cool Normal Strong yes  
  9. 8 Sunny Mild High Weak no  
  10. 9 Sunny Cool Normal Weak yes  
  11. 10 Rainy Mild Normal Weak yes  
  12. 11 Sunny Mild Normal Strong yes  
  13. 12 Overcast Mild High Strong yes  
  14. 13 Overcast Hot Normal Weak yes  
  15. 14 Rainy Mild High Strong no  
  16. end  

程序輸出決策樹以下

能夠用圖形表示爲 學習


有了決策樹後,就能夠根據氣候條件作預測了 優化

例如若是氣候數據是{Sunny,Cool,Normal,Strong} ,根據決策樹到左側的yes葉節點,能夠斷定會去游泳。 spa

另外在編寫這個程序時在數據結構的設計上面走了彎路,多叉樹的實現有不少方法,本算法採用每一個結點的全部孩子用vector保存比較合適,同時注意維護剩餘樣例和剩餘屬性信息,建樹時橫向遍歷靠循環屬性的值,縱向遍歷靠遞歸調用 ,整體是DFS,樹和圖的遍歷在編程時常常遇到,得熟練掌握。程序有些地方的效率還得優化,有不足的點地方還望你們拍磚。 .net

相關文章
相關標籤/搜索