機器學習之決策樹：熵與信息增益的求解算法實現

此文不對理論作相關闡述，僅涉及代碼實現（Java）。

1. 熵計算公式：

設 P 爲樣本集 S 中正例所佔的比例，Q 爲反例所佔的比例（P + Q = 1），則：

$$\mathrm{Entropy}(S) = -P\log_2 P - Q\log_2 Q$$

（約定 $0 \cdot \log_2 0 = 0$，對應代碼中的 pLog2。）

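2. 信息增益計算：設 A 爲某一屬性，$S_v$ 爲 S 中屬性 A 取值爲 v 的樣本子集，則

$$\mathrm{Gain}(S, A) = \mathrm{Entropy}(S) - \sum_{v \in \mathrm{Values}(A)} \frac{|S_v|}{|S|}\,\mathrm{Entropy}(S_v)$$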

 

舉例：

轉化後的輸入數據：第一行爲列數（5）與樣本數（14）；隨後每行爲一列，行首爲列名，後接 14 個樣本值；最後一行依次列出各屬性名，最後一個爲結果列（PlayTennis）。

 5  14
 Outlook       Sunny  Sunny  Overcast  Rain  Rain    Rain    Overcast  Sunny  Sunny    Rain    Sunny   Overcast   Overcast    Rain
 Temperature   Hot    Hot    Hot       Mild  Cool    Cool        Cool   Mild  Cool     Mild    Mild    Mild       Hot         Mild
 Humidity      High   High   High      High  Normal  Normal  Normal     High  Normal   Normal  Normal  High       Normal      High
 Wind          Weak   Strong Weak      Weak  Weak    Strong  Strong    Weak   Weak     Weak    Strong  Strong     Weak        Strong
 PlayTennis    No     No     Yes       Yes   Yes     No      Yes       No     Yes      Yes     Yes     Yes        Yes         No
 Outlook Temperature Humidity Wind PlayTennis

 

 1 package com.qunar.data.tree;
 2 
 3 /**
 4  * *********************************************************
 5  * <p/>
 6  * Author:     XiJun.Gong
 7  * Date:       2016-09-02 15:28
 8  * Version:    default 1.0.0
 9  * Class description:
10  * <p>統計該類型出現的次數</p>
11  * <p/>
12  * *********************************************************
13  */
14 public class CountMap<T> {
15 
16     private T key;     //類型
17     private int value;   //出現的次數
18 
19     public CountMap() {
20         this(null, 0);
21     }
22 
23     public CountMap(T key, int value) {
24         this.key = key;
25         this.value = value;
26     }
27 
28     public T getKey() {
29         return key;
30     }
31 
32     public void setKey(T key) {
33         this.key = key;
34     }
35 
36     public int getValue() {
37         return value;
38     }
39 
40     public void setValue(int value) {
41         this.value = value;
42     }
43 }
View Code
  1 package com.qunar.data.tree;
  2 
  3 import com.google.common.collect.ArrayListMultimap;
  4 import com.google.common.collect.Maps;
  5 import com.google.common.collect.Multimap;
  6 import com.google.common.collect.Sets;
  7 
  8 import java.util.*;
  9 
 10 /**
 11  * *********************************************************
 12  * <p/>
 13  * Author:     XiJun.Gong
 14  * Date:       2016-09-02 14:24
 15  * Version:    default 1.0.0
 16  * Class description:
 17  * <p>決策樹</p>
 18  * <p/>
 19  * *********************************************************
 20  */
 21 
 22 public class DecisionTree<T, K> {
 23 
 24     private static String positiveExampleType = "Yes";
 25     private static String counterExampleType = "No";
 26 
 27 
 28     public double pLog2(final double p) {
 29         if (0 == p) return 0;
 30         return p * (Math.log(p) / Math.log(2));
 31     }
 32 
 33     /**
 34      * 熵計算
 35      *
 36      * @param positiveExample 正例個數
 37      * @param counterExample  反例個數
 38      * @return 熵值
 39      */
 40     public double entropy(final double positiveExample, final double counterExample) {
 41 
 42         double total = positiveExample + counterExample;
 43         double positiveP = positiveExample / total;
 44         double counterP = counterExample / total;
 45         return -1d * (pLog2(positiveP) + pLog2(counterP));
 46     }
 47 
 48     /**
 49      * @param features 特徵列表
 50      * @param results  對應結果
 51      * @return 將信息整合成新的格式
 52      */
 53     public Multimap<T, CountMap<K>> merge(final List<T> features, final List<T> results) {
 54         //數據轉化
 55         Multimap<T, CountMap<K>> InfoMap = ArrayListMultimap.create();
 56         Iterator result = results.iterator();
 57         for (T feature : features) {
 58             K res = (K) result.next();
 59             boolean tag = false;
 60             Collection<CountMap<K>> countMaps = InfoMap.get(feature);
 61             for (CountMap countMap : countMaps) {
 62                 if (countMap.getKey().equals(res)) {
 63                     /*修改值*/
 64                     int num = countMap.getValue() + 1;
 65                     InfoMap.remove(feature, countMap);
 66                     InfoMap.put(feature, new CountMap<K>(res, num));
 67                     tag = true;
 68                     break;
 69                 }
 70             }
 71             if (!tag)
 72                 InfoMap.put(feature, new CountMap<K>(res, 1));
 73         }
 74 
 75         return InfoMap;
 76     }
 77 
 78     /**
 79      * 信息增益
 80      *
 81      * @param infoMap   因素(Outlook,Temperature,Humidity,Wind)對應的結果
 82      * @param dataTable 輸入的數據表
 83      * @param type      因素中的類型(Outlook{Sunny,Overcast,Rain})
 84      * @param entropyS  總的熵值
 85      * @param totalSize 總的樣本數
 86      * @return 信息增益
 87      */
 88     public double gain(Multimap<T, CountMap<K>> infoMap,
 89                        Map<K, List<T>> dataTable,
 90                        final String type,
 91                        double entropyS,
 92                        final int totalSize) {
 93         //去重
 94         Set<T> subTypes = Sets.newHashSet();
 95         subTypes.addAll(dataTable.get(type));
 96         /*計算*/
 97         for (T subType : subTypes) {
 98             Collection<CountMap<K>> countMaps = infoMap.get(subType);
 99             double subSize = 0;
100             double positiveExample = 0;
101             double counterExample = 0;
102             for (CountMap<K> countMap : countMaps) {
103                 subSize += countMap.getValue();
104                 if (positiveExampleType.equals(countMap.getKey()))
105                     positiveExample = countMap.getValue();
106                 else
107                     counterExample = countMap.getValue();
108             }
109             entropyS -= (subSize / totalSize) * entropy(positiveExample, counterExample);
110         }
111         return entropyS;
112     }
113 
114     /**
115      * 計算
116      *
117      * @param dataTable  數據表
118      * @param types      因素列表{Outlook,Temperature,Humidity,Wind}
119      * @param resultType 結果(PlayTennis)
120      * @return 返回信息增益集合
121      */
122     public Map<String, Double> calculate(Map<K, List<T>> dataTable, List<K> types, K resultType) {
123 
124         Map<String, Double> answer = Maps.newHashMap();
125         List<T> results = dataTable.get(resultType);
126         int totalSize = results.size();
127         int positiveExample = 0;
128         int counterExample = 0;
129         double entropyS = 0d;
130         for (T ExampleType : results) {
131             if (positiveExampleType.equals(ExampleType)) {
132                 ++positiveExample;
133                 continue;
134             }
135             ++counterExample;
136         }
137         /*計算總的熵*/
138         entropyS = entropy(positiveExample, counterExample);
139 
140         Multimap<T, CountMap<K>> infoMap;
141         for (K type : types) {
142             infoMap = merge(dataTable.get(type), results);
143             double _gain = gain(infoMap, dataTable, (String) type, entropyS, totalSize);
144             answer.put((String) type, _gain);
145         }
146         return answer;
147     }
148 
149 }   1package com.qunar.data.tree;
 2 
 3 import com.google.common.collect.Lists;
 4 import com.google.common.collect.Maps;
 5 
 6 import java.util.*;
 7 
 8 /**
 9  * *********************************************************
10  * <p/>
11  * Author:     XiJun.Gong
12  * Date:       2016-09-02 16:43
13  * Version:    default 1.0.0
14  * Class description:
15  * <p/>
16  * *********************************************************
17  */
18 public class Main {
19 
20     public static void main(String args[]) {
21 
22         Scanner scanner = new Scanner(System.in);
23         while (scanner.hasNext()) {
24             DecisionTree<String, String> dt = new DecisionTree();
25             Map<String, List<String>> dataTable = Maps.newHashMap();
26             /*Map<String, List<String>> dataTable = Maps.newHashMap();*/
27             List<String> types = Lists.newArrayList();
28             String resultType;
29             int factorSize = scanner.nextInt();
30             int demoSize = scanner.nextInt();
31             String type;
32 
33             for (int i = 0; i < factorSize; i++) {
34                 List<String> demos = Lists.newArrayList();
35                 type = scanner.next();
36                 for (int j = 0; j < demoSize; j++) {
37                     demos.add(scanner.next());
38                 }
39                 dataTable.put(type, demos);
40             }
41             for (int i = 1; i < factorSize; i++) {
42                 types.add(scanner.next());
43             }
44             resultType = scanner.next();
45             Map<String, Double> ans = dt.calculate(dataTable, types, resultType);
46             List<Map.Entry<String, Double>> list = new ArrayList<Map.Entry<String, Double>>(ans.entrySet());
47             Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
48 
49 
50                 @Override
51                 public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
52                     return (o2.getValue() > o1.getValue() ? 1 : -1);
53                 }
54             });
55 
56             for (Map.Entry<String, Double> iterator : list) {
57                 System.out.println(iterator.getKey() + "= " + iterator.getValue());
58             }
59         }
60     }
61 
62 }
63 /**
64  *使用舉例:*
65  5  14
66  Outlook       Sunny  Sunny  Overcast  Rain  Rain    Rain    Overcast  Sunny  Sunny    Rain    Sunny   Overcast   Overcast    Rain
67  Temperature   Hot    Hot    Hot       Mild  Cool    Cool        Cool   Mild  Cool     Mild    Mild    Mild       Hot         Mild
68  Humidity      High   High   High      High  Normal  Normal  Normal     High  Normal   Normal  Normal  High       Normal      High
69  Wind          Weak   Strong Weak      Weak  Weak    Strong  Strong    Weak   Weak     Weak    Strong  Strong     Weak        Strong
70  PlayTennis    No     No     Yes       Yes   Yes     No      Yes       No     Yes      Yes     Yes     Yes        Yes         No
71  Outlook Temperature Humidity Wind PlayTennis
72  */

結果：

Outlook= 0.2467498197744391
Humidity= 0.15183550136234136
Wind= 0.04812703040826927
Temperature= 0.029222565658954647
相關文章
相關標籤/搜索