Java編寫ID3決策樹

說明:每一個樣本都會裝入Data樣本對象,決策樹生成算法接收的是一個List&lt;Data&gt;樣本列表,因此構建測試數據時也要符合格式,最後生成的決策樹是樹的根節點,經過裏面提供的showTree()方法可查看整個樹結構,下面奉上源碼。

 

Data.java

package ai.tree.data;

import java.util.HashMap;

/**
 * 樣本類
 * @author ChenLuyang
 * @date 2019/2/21
 */
/**
 * A single training sample: a map of feature name to feature value, plus
 * the sample's classification result.
 */
public class Data implements Cloneable{
    /**
     * Feature map: key is the feature name, value is the feature value.
     */
    private HashMap<String,String> feature = new HashMap<String, String>();

    /**
     * Classification result of this sample.
     */
    private String result;

    /**
     * Creates a sample from its feature map and classification result.
     *
     * @param feature feature name to feature value
     * @param result  classification of this sample
     */
    public Data(HashMap<String,String> feature,String result){
        this.feature = feature;
        this.result = result;
    }

    public HashMap<String, String> getFeature() {
        return feature;
    }

    public String getResult() {
        return result;
    }

    private void setFeature(HashMap<String, String> feature) {
        this.feature = feature;
    }

    /**
     * Returns a copy of this sample whose feature map is itself copied, so
     * mutating the copy's map never affects the original.
     */
    @Override
    public Data clone()
    {
        try {
            Data copy = (Data) super.clone();
            // Replace the shared map reference with an independent copy.
            copy.setFeature(new HashMap<String, String>(this.feature));
            return copy;
        } catch (CloneNotSupportedException e) {
            // Unreachable: this class implements Cloneable.
            e.printStackTrace();
            return null;
        }
    }
}

  

DecisionTree.java

package ai.tree.algorithm;

import ai.tree.data.Data;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;

/**
 * @author ChenLuyang
 * @date 2019/2/21
 */
/**
 * ID3 decision-tree builder. Splits recursively on the feature with the
 * highest information gain (Shannon entropy reduction).
 *
 * @author ChenLuyang
 * @date 2019/2/21
 */
public class DecisionTree {
    /**
     * Recursively builds a decision (sub)tree for the given sample set.
     *
     * @param dataList training samples; must be non-empty and all samples
     *                 must share the same set of feature keys
     * @return the root node of the subtree: a leaf when all samples agree
     *         on a class or no features remain, otherwise a decision node
     *         with one child per observed value of the best feature
     * @author ChenLuyang
     * @date 2019/2/21 16:05
     */
    public TreeNode createTree(List<Data> dataList) {
        TreeNode<String, String, String> nowTreeNode = new TreeNode<String, String, String>();

        // Distinct classification results present in this sample set.
        Set<String> resultSet = new HashSet<String>();
        for (Data data : dataList) {
            resultSet.add(data.getResult());
        }

        // Base case 1: every sample has the same class -> pure leaf.
        if (resultSet.size() == 1) {
            nowTreeNode.setResultNode(resultSet.iterator().next());
            return nowTreeNode;
        }

        // Base case 2: no features left to split on -> majority-vote leaf.
        if (dataList.get(0).getFeature().size() == 0) {
            nowTreeNode.setResultNode(majorityResult(dataList));
            return nowTreeNode;
        }

        // Split on the feature with the highest information gain.
        String bestLabel = chooseBestFeatureToSplit(dataList);

        // All values the best feature takes in this sample set.
        Set<String> bestLabelInfoSet = new HashSet<String>();
        for (Data data : dataList) {
            bestLabelInfoSet.add(data.getFeature().get(bestLabel));
        }

        // Build one child subtree per observed value of the best feature.
        // (The original code had a dead, empty inner loop here; removed.)
        Map<String, TreeNode> featureDecisionMap = new HashMap<String, TreeNode>();
        for (String labelInfo : bestLabelInfoSet) {
            List<Data> branchDataList = splitDataList(dataList, bestLabel, labelInfo);
            featureDecisionMap.put(labelInfo, createTree(branchDataList));
        }

        nowTreeNode.setDecisionNode(bestLabel, featureDecisionMap);

        return nowTreeNode;
    }

    /**
     * Returns the classification that occurs most often in the sample set.
     * On a tie the first maximum encountered in map iteration order wins,
     * matching the behavior of the original inline loop.
     */
    private String majorityResult(List<Data> dataList) {
        Map<String, Integer> countMap = new HashMap<String, Integer>();
        for (Data data : dataList) {
            Integer num = countMap.get(data.getResult());
            countMap.put(data.getResult(), num == null ? 1 : num + 1);
        }

        String bestResult = "";
        Integer bestCount = 0;
        for (Map.Entry<String, Integer> entry : countMap.entrySet()) {
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                bestResult = entry.getKey();
            }
        }
        return bestResult;
    }

    /**
     * Finds the feature whose split yields the largest information gain
     * over the given sample set.
     *
     * @param dataList samples to evaluate; must be non-empty
     * @return the name of the best feature to split on
     * @author ChenLuyang
     * @date 2019/2/21 14:12
     */
    public String chooseBestFeatureToSplit(List<Data> dataList) {
        // Features still present (all samples share the same keys).
        Set<String> featureSet = dataList.get(0).getFeature().keySet();

        // Entropy of the unsplit sample set.
        BigDecimal baseEntropy = calcShannonEnt(dataList);

        // Best information gain (entropy reduction) seen so far.
        BigDecimal bestInfoGain = new BigDecimal("0");
        String bestFeature = "";

        for (String feature : featureSet) {
            // Weighted entropy after splitting on this feature.
            BigDecimal featureEntropy = new BigDecimal("0");

            // Distinct values of this feature in the sample set.
            Set<String> valueSet = new HashSet<String>();
            for (Data data : dataList) {
                valueSet.add(data.getFeature().get(feature));
            }

            for (String value : valueSet) {
                List<Data> subset = splitDataList(dataList, feature, value);

                // Fraction of samples falling into this branch.
                BigDecimal prob = new BigDecimal(subset.size() + "")
                        .divide(new BigDecimal(dataList.size() + ""), 5, RoundingMode.HALF_DOWN);

                // Accumulate probability-weighted branch entropy.
                featureEntropy = featureEntropy.add(prob.multiply(calcShannonEnt(subset)));
            }

            BigDecimal infoGain = baseEntropy.subtract(featureEntropy);

            // ">=" keeps the last feature among equal gains, as before.
            if (infoGain.compareTo(bestInfoGain) >= 0) {
                bestInfoGain = infoGain;
                bestFeature = feature;
            }
        }

        return bestFeature;
    }

    /**
     * Computes the Shannon entropy of the sample set's class distribution.
     *
     * @param dataList samples to measure; must be non-empty
     * @return the entropy, computed with 5-digit decimal probabilities
     * @author ChenLuyang
     * @date 2019/2/22 9:41
     */
    public BigDecimal calcShannonEnt(List<Data> dataList) {
        // Total number of samples.
        BigDecimal sumEntries = new BigDecimal(dataList.size() + "");
        // Running entropy: -sum(p * log2(p)).
        BigDecimal shannonEnt = new BigDecimal("0");

        // Count how many samples fall into each classification.
        Map<String, Integer> resultCountMap = new HashMap<String, Integer>();
        for (Data data : dataList) {
            Integer dataResultCount = resultCountMap.get(data.getResult());
            resultCountMap.put(data.getResult(), dataResultCount == null ? 1 : dataResultCount + 1);
        }

        for (Map.Entry<String, Integer> entry : resultCountMap.entrySet()) {
            BigDecimal resultCountValue = new BigDecimal(entry.getValue().toString());

            BigDecimal prob = resultCountValue.divide(sumEntries, 5, RoundingMode.HALF_DOWN);
            // log2(p) via the change-of-base formula on doubles.
            shannonEnt = shannonEnt.subtract(prob.multiply(new BigDecimal(Math.log(prob.doubleValue()) / Math.log(2) + "")));
        }

        return shannonEnt;
    }

    /**
     * Selects the samples whose given feature has the given value, and
     * returns deep copies of them with that feature removed (it has been
     * consumed by the split).
     *
     * @param dataList samples to partition
     * @param future   name of the feature to filter on
     * @param info     feature value to keep
     * @return copies of the matching samples, without the split feature
     * @author ChenLuyang
     * @date 2019/2/21 18:26
     */
    public List<Data> splitDataList(List<Data> dataList, String future, String info) {
        List<Data> resultDataList = new ArrayList<Data>();
        for (Data data : dataList) {
            if (data.getFeature().get(future).equals(info)) {
                // Clone so removing the feature never mutates the input set.
                Data newData = (Data) data.clone();
                newData.getFeature().remove(future);
                resultDataList.add(newData);
            }
        }

        return resultDataList;
    }

    /**
     * A node of the decision tree. Static because it never touches the
     * enclosing DecisionTree instance.
     *
     * L: type of the feature label (description)
     * F: type of the feature values
     * R: type of the final classification result
     */
    public static class TreeNode<L, F, R> {
        /**
         * The feature this decision node splits on (null for leaves).
         */
        private L label;

        /**
         * Child nodes keyed by feature value: K is a feature value, V is
         * the subtree chosen when the sample has that value.
         */
        private Map<F, TreeNode> featureDecisionMap;

        /**
         * Whether this node is a leaf carrying a final classification.
         */
        private boolean isFinal;

        /**
         * Final classification result (only meaningful on leaves).
         */
        private R resultClassify;

        /**
         * Turns this node into a leaf.
         *
         * @param resultClassify the final classification result
         * @author ChenLuyang
         * @date 2019/2/22 18:31
         */
        public void setResultNode(R resultClassify) {
            this.isFinal = true;
            this.resultClassify = resultClassify;
        }

        /**
         * Turns this node into a decision (branch) node.
         *
         * @param label              the feature this node splits on
         * @param featureDecisionMap child node for each feature value
         * @author ChenLuyang
         * @date 2019/2/22 18:31
         */
        public void setDecisionNode(L label, Map<F, TreeNode> featureDecisionMap) {
            this.isFinal = false;
            this.label = label;
            this.featureDecisionMap = featureDecisionMap;
        }

        /**
         * Renders the subtree rooted at this node as a nested map string.
         *
         * @return e.g. {feature={value1={result=...}, value2=...}}
         * @author ChenLuyang
         * @date 2019/2/22 16:54
         */
        public String showTree() {
            HashMap<String, String> treeMap = new HashMap<String, String>();
            if (isFinal) {
                treeMap.put("result", resultClassify.toString());
            } else {
                // Recursively render each child keyed by its feature value.
                HashMap<F, String> showFutureMap = new HashMap<F, String>();
                for (F f : featureDecisionMap.keySet()) {
                    showFutureMap.put(f, featureDecisionMap.get(f).showTree());
                }
                treeMap.put(label.toString(), showFutureMap.toString());
            }

            return treeMap.toString();
        }

        public L getLabel() {
            return label;
        }

        public Map<F, TreeNode> getFeatureDecisionMap() {
            return featureDecisionMap;
        }

        public R getResultClassify() {
            return resultClassify;
        }

        public boolean getFinal() {
            return isFinal;
        }
    }
}

  

Start.java

package ai.tree.algorithm;

import ai.tree.data.Data;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/**
 * @author ChenLuyang
 * @date 2019/2/22
 */
/**
 * Demo entry point: builds a small sample set and prints the resulting
 * ID3 decision tree.
 *
 * @author ChenLuyang
 * @date 2019/2/22
 */
public class Start {
    /**
     * 構建測試樣本集,測試樣本以下:
     樣本特徵:{頭髮長短=短髮, 身材=胖, 是否戴眼鏡=有眼鏡} 分類:男
     樣本特徵:{頭髮長短=長髮, 身材=瘦, 是否戴眼鏡=有眼鏡} 分類:女
     樣本特徵:{頭髮長短=短髮, 身材=胖, 是否戴眼鏡=有眼鏡} 分類:女
     樣本特徵:{頭髮長短=長髮, 身材=胖, 是否戴眼鏡=沒眼鏡} 分類:男
     樣本特徵:{頭髮長短=短髮, 身材=瘦, 是否戴眼鏡=沒眼鏡} 分類:男
     樣本特徵:{頭髮長短=長髮, 身材=瘦, 是否戴眼鏡=有眼鏡} 分類:女
     樣本特徵:{頭髮長短=長髮, 身材=胖, 是否戴眼鏡=有眼鏡} 分類:男
     * @author ChenLuyang
     * @date 2019/2/21 15:34
     * @return java.util.List&lt;ai.tree.data.Data&gt; the sample set
     */
    public static List<Data> createDataList() {
        // Feature names: glasses, hair length, build.
        String[] labels = new String[]{"是否戴眼鏡", "頭髮長短", "身材"};

        List<Data> dataList = new ArrayList<Data>();
        addSample(dataList, labels, "有眼鏡", "短髮", "胖", "男");
        addSample(dataList, labels, "有眼鏡", "長髮", "瘦", "女");
        addSample(dataList, labels, "有眼鏡", "短髮", "胖", "女");
        addSample(dataList, labels, "沒眼鏡", "長髮", "胖", "男");
        addSample(dataList, labels, "沒眼鏡", "短髮", "瘦", "男");
        addSample(dataList, labels, "有眼鏡", "長髮", "瘦", "女");
        addSample(dataList, labels, "有眼鏡", "長髮", "胖", "男");

        return dataList;
    }

    /**
     * Builds one sample from its three feature values and its class, and
     * appends it to the list.
     *
     * @param dataList list to append to
     * @param labels   feature names: [0]=glasses, [1]=hair, [2]=build
     * @param glasses  value of 是否戴眼鏡
     * @param hair     value of 頭髮長短
     * @param build    value of 身材
     * @param result   classification of the sample
     */
    private static void addSample(List<Data> dataList, String[] labels,
                                  String glasses, String hair, String build, String result) {
        HashMap<String, String> feature = new HashMap<String, String>();
        feature.put(labels[0], glasses);
        feature.put(labels[1], hair);
        feature.put(labels[2], build);
        dataList.add(new Data(feature, result));
    }

    public static void main(String[] args) {
        DecisionTree decisionTree = new DecisionTree();

        // Build the decision tree from the test samples.
        DecisionTree.TreeNode tree = decisionTree.createTree(createDataList());

        // Print the tree structure.
        System.out.println(tree.showTree());
    }
}

  

生成樹結構:{是否戴眼鏡={沒眼鏡={result=男}, 有眼鏡={身材={胖={頭髮長短={長髮={result=男}, 短髮={result=女}}}, 瘦={result=女}}}}}

相關文章
相關標籤/搜索