說明:每一個樣本都會裝入Data樣本對象,決策樹生成算法接收的是一個List<Data>樣本列表,因此構建測試數據時也要符合格式,最後生成的決策樹是樹的根節點,經過裏面提供的showTree()方法可查看整個樹結構,下面奉上源碼。
Data.java
package ai.tree.data;

import java.util.HashMap;

/**
 * A single training sample: a map of feature-name to feature-value plus the
 * classification result observed for that sample.
 *
 * @author ChenLuyang
 * @date 2019/2/21
 */
public class Data implements Cloneable {
    /** Key is the feature description, value is that feature's value. */
    private HashMap<String, String> feature;
    /** Classification result of this sample. */
    private String result;

    public Data(HashMap<String, String> feature, String result) {
        this.feature = feature;
        this.result = result;
    }

    public HashMap<String, String> getFeature() {
        return feature;
    }

    public String getResult() {
        return result;
    }

    private void setFeature(HashMap<String, String> feature) {
        this.feature = feature;
    }

    /**
     * Returns a copy whose feature map is independent of this sample's map,
     * so callers may remove entries from the copy without affecting the
     * original (the tree-building split relies on this).
     */
    @Override
    public Data clone() {
        try {
            Data copy = (Data) super.clone();
            copy.setFeature((HashMap<String, String>) this.feature.clone());
            return copy;
        } catch (CloneNotSupportedException e) {
            // Impossible: this class implements Cloneable, so super.clone()
            // cannot throw. Never return null to callers.
            throw new AssertionError(e);
        }
    }
}
DecisionTree.java
package ai.tree.algorithm;

import ai.tree.data.Data;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;

/**
 * ID3-style decision tree: recursively picks the feature with the highest
 * information gain, splits the sample set on that feature's values, and
 * builds one child node per value.
 *
 * @author ChenLuyang
 * @date 2019/2/21
 */
public class DecisionTree {

    /**
     * Recursively builds a decision (sub)tree from the given sample set.
     *
     * @param dataList training samples; must be non-empty and every sample
     *                 must carry the same feature keys
     * @return the root node of the tree built from these samples
     */
    public TreeNode createTree(List<Data> dataList) {
        TreeNode<String, String, String> node = new TreeNode<String, String, String>();

        // Collect every distinct classification present in the sample set.
        Set<String> resultSet = new HashSet<String>();
        for (Data data : dataList) {
            resultSet.add(data.getResult());
        }

        // Base case 1: all samples share one class -> leaf node.
        if (resultSet.size() == 1) {
            node.setResultNode(resultSet.iterator().next());
            return node;
        }

        // Base case 2: no features left to split on -> leaf node labelled
        // with the majority class of the remaining samples.
        if (dataList.get(0).getFeature().size() == 0) {
            node.setResultNode(majorityResult(dataList));
            return node;
        }

        // Split on the feature with the largest information gain.
        String bestLabel = chooseBestFeatureToSplit(dataList);

        // Distinct values the best feature takes within this sample set.
        Set<String> bestLabelValues = new HashSet<String>();
        for (Data data : dataList) {
            bestLabelValues.add(data.getFeature().get(bestLabel));
        }

        // One child subtree per feature value.
        Map<String, TreeNode> featureDecisionMap = new HashMap<String, TreeNode>();
        for (String labelInfo : bestLabelValues) {
            List<Data> branchDataList = splitDataList(dataList, bestLabel, labelInfo);
            featureDecisionMap.put(labelInfo, createTree(branchDataList));
        }
        node.setDecisionNode(bestLabel, featureDecisionMap);
        return node;
    }

    /** Returns the most frequent classification result in the sample set. */
    private String majorityResult(List<Data> dataList) {
        Map<String, Integer> countMap = new HashMap<String, Integer>();
        for (Data data : dataList) {
            Integer num = countMap.get(data.getResult());
            countMap.put(data.getResult(), num == null ? 1 : num + 1);
        }
        String bestResult = "";
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : countMap.entrySet()) {
            // Strict ">" keeps the first maximum encountered in map order,
            // matching the original behaviour.
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                bestResult = entry.getKey();
            }
        }
        return bestResult;
    }

    /**
     * Computes the best feature to split on: the one whose split yields the
     * largest information gain (entropy before minus weighted entropy after).
     *
     * @param dataList sample set; must be non-empty
     * @return the description (key) of the best feature
     */
    public String chooseBestFeatureToSplit(List<Data> dataList) {
        // Feature keys currently present in the data set.
        Set<String> featureSet = dataList.get(0).getFeature().keySet();
        // Entropy of the unsplit sample set.
        BigDecimal baseEntropy = calcShannonEnt(dataList);
        BigDecimal bestInfoGain = BigDecimal.ZERO;
        String bestFeature = "";
        for (String feature : featureSet) {
            // Weighted entropy after splitting on this feature.
            BigDecimal featureEntropy = BigDecimal.ZERO;
            // Distinct values of this feature.
            Set<String> valueSet = new HashSet<String>();
            for (Data data : dataList) {
                valueSet.add(data.getFeature().get(feature));
            }
            for (String value : valueSet) {
                List<Data> subset = splitDataList(dataList, feature, value);
                // Fraction of samples in this branch (5-digit scale, HALF_DOWN,
                // same rounding as the original ROUND_HALF_DOWN constant).
                BigDecimal prob = BigDecimal.valueOf(subset.size())
                        .divide(BigDecimal.valueOf(dataList.size()), 5, RoundingMode.HALF_DOWN);
                featureEntropy = featureEntropy.add(prob.multiply(calcShannonEnt(subset)));
            }
            BigDecimal infoGain = baseEntropy.subtract(featureEntropy);
            // ">=" keeps the last equally-good feature, as before.
            if (infoGain.compareTo(bestInfoGain) >= 0) {
                bestInfoGain = infoGain;
                bestFeature = feature;
            }
        }
        return bestFeature;
    }

    /**
     * Computes the Shannon entropy of the sample set:
     * {@code -sum(p_i * log2(p_i))} over all classification results i.
     *
     * @param dataList sample set; must be non-empty
     * @return the entropy
     */
    public BigDecimal calcShannonEnt(List<Data> dataList) {
        BigDecimal total = BigDecimal.valueOf(dataList.size());
        // Count samples per classification result.
        Map<String, Integer> resultCountMap = new HashMap<String, Integer>();
        for (Data data : dataList) {
            Integer count = resultCountMap.get(data.getResult());
            resultCountMap.put(data.getResult(), count == null ? 1 : count + 1);
        }
        BigDecimal shannonEnt = BigDecimal.ZERO;
        for (Integer count : resultCountMap.values()) {
            BigDecimal prob = BigDecimal.valueOf(count.intValue())
                    .divide(total, 5, RoundingMode.HALF_DOWN);
            // log2(p) = ln(p) / ln(2), computed in double precision exactly
            // as the original did (Double.toString round-trip).
            BigDecimal log2p =
                    new BigDecimal(Double.toString(Math.log(prob.doubleValue()) / Math.log(2)));
            shannonEnt = shannonEnt.subtract(prob.multiply(log2p));
        }
        return shannonEnt;
    }

    /**
     * Filters the sample set down to samples whose given feature has the given
     * value, returning cloned samples with that feature removed.
     *
     * @param dataList sample set to split
     * @param future   feature key to filter on
     * @param info     feature value to keep
     * @return cloned, filtered samples without the consumed feature
     */
    public List<Data> splitDataList(List<Data> dataList, String future, String info) {
        List<Data> resultDataList = new ArrayList<Data>();
        for (Data data : dataList) {
            if (data.getFeature().get(future).equals(info)) {
                Data newData = data.clone();
                newData.getFeature().remove(future);
                resultDataList.add(newData);
            }
        }
        return resultDataList;
    }

    /**
     * A node of the decision tree. Static nested class: a node does not need
     * a reference to its enclosing DecisionTree.
     *
     * @param <L> type of the feature label a branch node splits on
     * @param <F> type of the feature values
     * @param <R> type of the final classification result
     */
    public static class TreeNode<L, F, R> {
        /** Feature this node splits on (branch nodes only). */
        private L label;
        /** Maps each feature value to the child node it leads to. */
        private Map<F, TreeNode> featureDecisionMap;
        /** Whether this node is a leaf carrying a final classification. */
        private boolean isFinal;
        /** Final classification result (leaf nodes only). */
        private R resultClassify;

        /**
         * Marks this node as a leaf with the given classification result.
         *
         * @param resultClassify final classification
         */
        public void setResultNode(R resultClassify) {
            this.isFinal = true;
            this.resultClassify = resultClassify;
        }

        /**
         * Marks this node as a branch splitting on {@code label}.
         *
         * @param label              the feature this node tests
         * @param featureDecisionMap child node per feature value
         */
        public void setDecisionNode(L label, Map<F, TreeNode> featureDecisionMap) {
            this.isFinal = false;
            this.label = label;
            this.featureDecisionMap = featureDecisionMap;
        }

        /**
         * Renders the subtree rooted at this node as a nested map string.
         * NOTE(review): output ordering follows HashMap iteration order and is
         * therefore not guaranteed to be stable across JVM versions.
         *
         * @return textual representation of the subtree
         */
        public String showTree() {
            HashMap<String, String> treeMap = new HashMap<String, String>();
            if (isFinal) {
                treeMap.put("result", resultClassify.toString());
            } else {
                HashMap<F, String> children = new HashMap<F, String>();
                for (Map.Entry<F, TreeNode> entry : featureDecisionMap.entrySet()) {
                    children.put(entry.getKey(), entry.getValue().showTree());
                }
                treeMap.put(label.toString(), children.toString());
            }
            return treeMap.toString();
        }

        public L getLabel() {
            return label;
        }

        public Map<F, TreeNode> getFeatureDecisionMap() {
            return featureDecisionMap;
        }

        public R getResultClassify() {
            return resultClassify;
        }

        public boolean getFinal() {
            return isFinal;
        }
    }
}
Start.java
package ai.tree.algorithm;

import ai.tree.data.Data;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/**
 * Demo entry point: builds a small seven-sample gender-classification data
 * set, trains the decision tree on it, and prints the tree structure.
 *
 * @author ChenLuyang
 * @date 2019/2/22
 */
public class Start {

    /** Feature names shared by every sample. */
    private static final String[] LABELS = new String[]{"是否戴眼鏡", "頭髮長短", "身材"};

    /**
     * Builds one sample from its three feature values and its classification.
     *
     * @param glasses value for LABELS[0]
     * @param hair    value for LABELS[1]
     * @param figure  value for LABELS[2]
     * @param result  classification result
     * @return the assembled sample
     */
    private static Data createData(String glasses, String hair, String figure, String result) {
        HashMap<String, String> feature = new HashMap<String, String>();
        feature.put(LABELS[0], glasses);
        feature.put(LABELS[1], hair);
        feature.put(LABELS[2], figure);
        return new Data(feature, result);
    }

    /**
     * Builds the test sample set:
     * 樣本特徵:{頭髮長短=短髮, 身材=胖, 是否戴眼鏡=有眼鏡} 分類:男
     * 樣本特徵:{頭髮長短=長髮, 身材=瘦, 是否戴眼鏡=有眼鏡} 分類:女
     * 樣本特徵:{頭髮長短=短髮, 身材=胖, 是否戴眼鏡=有眼鏡} 分類:女
     * 樣本特徵:{頭髮長短=長髮, 身材=胖, 是否戴眼鏡=沒眼鏡} 分類:男
     * 樣本特徵:{頭髮長短=短髮, 身材=瘦, 是否戴眼鏡=沒眼鏡} 分類:男
     * 樣本特徵:{頭髮長短=長髮, 身材=瘦, 是否戴眼鏡=有眼鏡} 分類:女
     * 樣本特徵:{頭髮長短=長髮, 身材=胖, 是否戴眼鏡=有眼鏡} 分類:男
     *
     * @return the sample list
     */
    public static List<Data> createDataList() {
        List<Data> dataList = new ArrayList<Data>();
        dataList.add(createData("有眼鏡", "短髮", "胖", "男"));
        dataList.add(createData("有眼鏡", "長髮", "瘦", "女"));
        dataList.add(createData("有眼鏡", "短髮", "胖", "女"));
        dataList.add(createData("沒眼鏡", "長髮", "胖", "男"));
        dataList.add(createData("沒眼鏡", "短髮", "瘦", "男"));
        dataList.add(createData("有眼鏡", "長髮", "瘦", "女"));
        dataList.add(createData("有眼鏡", "長髮", "胖", "男"));
        return dataList;
    }

    public static void main(String[] args) {
        DecisionTree decisionTree = new DecisionTree();
        // Train the decision tree on the test samples.
        DecisionTree.TreeNode tree = decisionTree.createTree(createDataList());
        // Print the resulting tree structure.
        System.out.println(tree.showTree());
    }
}
生成樹結構:{是否戴眼鏡={沒眼鏡={result=男}, 有眼鏡={身材={胖={頭髮長短={長髮={result=男}, 短髮={result=女}}}, 瘦={result=女}}}}}