樸素貝葉斯算法java實現(多項式模型)

網上有不少對樸素貝葉斯算法的說明的文章,在對算法實現前,參考了一下幾篇文章:java

NLP系列(2)_用樸素貝葉斯進行文本分類(上)算法

NLP系列(3)_用樸素貝葉斯進行文本分類(下)安全

帶你搞懂樸素貝葉斯分類算法ide

其中「帶你搞懂樸素貝葉斯算法」在我看來比較容易理解,上面兩篇比較詳細,更深刻。函數

算法java實現測試

第一步對訓練集進行預處理,分詞並計算詞頻,獲得存儲訓練集的特徵集合編碼

/**
            * 全部訓練集分詞特徵集合
     * 第一個String表明分類標籤,也就是存儲該類別訓練集的文件名
     * 第二個String表明某條訓練集的路徑,這裏存儲的是該條語料的絕對路徑
     * Map<String, Integer>存儲的是該條訓練集的特徵詞和詞頻
     *
     */
    private static Map<String, Map<String, Map<String, Integer>>> allTrainFileSegsMap = new HashMap<String, Map<String, Map<String, Integer>>>();
    /**
     * 放大因子
     * 在計算中,因各個詞的先驗機率都比較小,咱們乘以固定的值放大,便於計算
     */
    private static BigDecimal zoomFactor = new BigDecimal(10);

    /**
     * 對傳入的訓練集進行分詞,獲取訓練集分詞後的詞和詞頻集合
     * @param trainFilePath  訓練集路徑
     */
    public static void getFeatureClassForTrainText(String trainFilePath){
        //經過將訓練集路徑字符串轉變成抽象路徑,建立一個File對象
        File trainFileDirs = new File(trainFilePath);
        //獲取該路徑下的全部分類路徑
        File[] trainFileDirList = trainFileDirs.listFiles();
        if (trainFileDirList == null){
            System.out.println("訓練數據集不存在");
        }
        for (File trainFileDir : trainFileDirList){
            //讀取該分類下的全部訓練文件
            List<String> fileList = null;
            try {
                fileList = FileOptionUtil.readDirs(trainFileDir.getAbsolutePath());
                if (fileList.size() != 0){
                    //遍歷訓練集目錄數據,進行分詞和類別標籤處理
                    for(String filePath : fileList){
                        System.out.println("開始對此訓練集進行分詞處理:" + filePath);
                        //分詞處理,獲取每條訓練集文本的詞和詞頻
                        //若知道文件編碼的話,不要用下述的判斷編碼格式了,效率過低
//                        Map<String, Integer> contentSegs = IKWordSegmentation.segString(FileOptionUtil.readFile(filePath, FileOptionUtil.getCodeString(filePath)));
                        Map<String, Integer> contentSegs = IKWordSegmentation.segString(FileOptionUtil.readFile(filePath, "gbk"));
                        if (allTrainFileSegsMap.containsKey(trainFileDir.getName())){
                            Map<String, Map<String, Integer>> allSegsMap = allTrainFileSegsMap.get(trainFileDir.getName());
                            allSegsMap.put(filePath, contentSegs);
                            allTrainFileSegsMap.put(trainFileDir.getName(), allSegsMap);
                        } else {
                            Map<String, Map<String, Integer>> allSegsMap = new HashMap<String, Map<String, Integer>>();
                            allSegsMap.put(filePath, contentSegs);
                            allTrainFileSegsMap.put(trainFileDir.getName(), allSegsMap);
                        }
                    }
                } else {
                    System.out.println("該分類下沒有待訓練語料");
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }
View Code

第二步計算類別的先驗機率spa

/**
     * 計算類別C的先驗機率
     * 先驗機率P(c)= 類c下單詞總數/整個訓練樣本的單詞總數
     * @param category
     * @return 類C的先驗機率
     */
    public static BigDecimal prioriProbability(String category){
        BigDecimal categoryWordsCount = new BigDecimal(categoryWordCount(category));
        BigDecimal allTrainFileWordCount = new BigDecimal(getAllTrainCategoryWordsCount());
        return categoryWordsCount.divide(allTrainFileWordCount, 10, BigDecimal.ROUND_CEILING);
    }
View Code

第三步計算特徵詞的條件機率.net

/**
     * 多項式樸素貝葉斯類條件機率
     * 類條件機率P(IK|c)=(類c下單詞IK在各個文檔中出現過的次數之和+1)/(類c下單詞總數+|V|)
     * V是訓練樣本的單詞表(即抽取單詞,單詞出現屢次,只算一個),
     * |V|則表示訓練樣本包含多少種單詞。 P(IK|c)能夠看做是單詞tk在證實d屬於類c上提供了多大的證據,
     * 而P(c)則能夠認爲是類別c在總體上佔多大比例(有多大可能性)
     * @param category
     * @param word
     * @return
     */
    public static BigDecimal categoryConditionalProbability(String category, String word){
        BigDecimal wordCount = new BigDecimal(wordInCategoryCount(word, category) + 1);
        BigDecimal categoryTrainFileWordCount = new BigDecimal(categoryWordCount(category) + getAllTrainCategoryWordCount());
        return wordCount.divide(categoryTrainFileWordCount, 10, BigDecimal.ROUND_CEILING);
    }
View Code

第四步計算給定文本的分類結果3d

/**
     * 多項式樸素貝葉斯分類結果
     * P(C_i|w_1,w_2...w_n) = P(w_1,w_2...w_n|C_i) * P(C_i) / P(w_1,w_2...w_n)
     * = P(w_1|C_i) * P(w_2|C_i)...P(w_n|C_i) * P(C_i) / (P(w_1) * P(w_2) ...P(w_n))
     * @param words
     * @return
     */
    public static Map<String, BigDecimal> classifyResult(Set<String> words){
        Map<String, BigDecimal> resultMap = new HashMap<String, BigDecimal>();
        //獲取訓練語料集全部的分類集合
        Set<String> categorySet = allTrainFileSegsMap.keySet();
        //循環計算每一個類別的機率
        for (String categorySetLabel : categorySet){
            BigDecimal probability = new BigDecimal(1.0);
            for (String word : words){
                probability = probability.multiply(categoryConditionalProbability(categorySetLabel, word)).multiply(zoomFactor);
            }
            resultMap.put(categorySetLabel, probability.multiply(prioriProbability(categorySetLabel)));
        }
        return resultMap;
    }
View Code

輔助函數

/**
     * 對分類結果進行比較,得出機率最大的類
     * @param classifyResult
     * @return
     */
    public static String getClassifyResultName(Map<String, BigDecimal> classifyResult){
        String classifyName = "";
        if (classifyResult.isEmpty()){
            return classifyName;
        }
        BigDecimal result = new BigDecimal(0);
        Set<String> classifyResultSet = classifyResult.keySet();
        for (String classifyResultSetString : classifyResultSet){
            if (classifyResult.get(classifyResultSetString).compareTo(result) >= 1){
                result = classifyResult.get(classifyResultSetString);
                classifyName = classifyResultSetString;
            }
        }
        return classifyName;
    }

    /**
     * 統計給定類別下的單詞總數(帶詞頻計算)
     * @param categoryLabel  指定類別參數
     * @return
     */
    public static Long categoryWordCount(String categoryLabel){
        Long sum = 0L;
        Map<String, Map<String, Integer>> categoryWordMap = allTrainFileSegsMap.get(categoryLabel);
        if (categoryWordMap == null){
            return sum;
        }
        Set<String> categoryWordMapKeySet = categoryWordMap.keySet();
        for (String categoryLabelString : categoryWordMapKeySet){
            Map<String, Integer> categoryWordMapDataMap = categoryWordMap.get(categoryLabelString);
            List<Map.Entry<String, Integer>> dataWordMapList = new ArrayList<Map.Entry<String, Integer>>(categoryWordMapDataMap.entrySet());
            for (int i=0; i<dataWordMapList.size(); i++){
                sum += dataWordMapList.get(i).getValue();
            }
        }
        return sum;
    }

    /**
     * 獲取訓練樣本全部詞的總數(詞總數計算是帶上詞頻的,也就是能夠重複算數)
     * @return
     */
    public static Long getAllTrainCategoryWordsCount(){
        Long sum = 0L;
        //獲取全部分類
        Set<String> categoryLabels = allTrainFileSegsMap.keySet();
        //循環相加每一個類下的詞總數
        for (String categoryLabel : categoryLabels){
            sum += categoryWordCount(categoryLabel);
        }
        return sum;
    }

    /**
     * 獲取訓練樣本下各個類別不重複詞的總詞數,區別於getAllTrainCategoryWordsCount()方法,此處計算不計算詞頻
     * 備註:此處並非嚴格意義上的進行全量詞表生成後的計算,也就是加入類別1有"中國=6"、類別2有"中國=2",總詞數算中國兩次,
     * 也就是說,咱們在計算的時候並無生成全局詞表(將全部詞都做爲出現一次)
     * @return
     */
    public static Long getAllTrainCategoryWordCount(){
        Long sum = 0L;
        //獲取全部分類
        Set<String> categoryLabels = allTrainFileSegsMap.keySet();
        for (String cateGoryLabelsLabel : categoryLabels){
            Map<String, Map<String, Integer>> categoryWordMap = allTrainFileSegsMap.get(cateGoryLabelsLabel);
            List<Map.Entry<String, Map<String, Integer>>> categoryWordMapList = new ArrayList<Map.Entry<String, Map<String, Integer>>>(categoryWordMap.entrySet());
            for (int i=0; i<categoryWordMapList.size(); i++){
                sum += categoryWordMapList.get(i).getValue().size();
            }
        }
        return sum;
    }

    /**
     * 計算測試數據的每一個單詞在每一個類下出現的總數
     * @param word
     * @param categoryLabel
     * @return
     */
    public static Long wordInCategoryCount(String word, String categoryLabel){
        Long sum = 0L;
        Map<String, Map<String, Integer>> categoryWordMap = allTrainFileSegsMap.get(categoryLabel);
        Set<String> categoryWordMapKeySet = categoryWordMap.keySet();
        for (String categoryWordMapKeySetFile : categoryWordMapKeySet){
            Map<String, Integer> categoryWordMapDataMap = categoryWordMap.get(categoryWordMapKeySetFile);
            Integer value = categoryWordMapDataMap.get(word);
            if (value!=null && value>0){
                sum += value;
            }
        }
        return sum;
    }

    /**
     * 獲取全部分類類別
     * @return
     */
    public Set<String> getAllCategory(){
        return allTrainFileSegsMap.keySet();
    }
View Code

main函數測試

//main方法
    public static void main(String[] args){
        BayesNB.getFeatureClassForTrainText("/Users/zhouyh/work/yanfa/xunlianji/train/");
        String s = "全國假日旅遊部際協調會議的各成員單位和中央各有關部門圍繞一個目標,積極配合,主動工做,抓得深刻,抓得紮實。主要有如下幾個特色:一是安全工做有部署有檢查有跟蹤。國務院安委會辦公室節前深刻部署全面檢查,節中及時總結,下發關於黃金週後期安全工做的緊急通知;鐵路、民航、交通等部門針對黃金週先後期旅客集中返程交通壓力較大狀況,及時調遣應急運力;質檢總局進一步強化節日期間質量安全監管工做;旅遊部門每日及時發佈旅遊信息通報,有效引導遊客。二是各方面主動協調密切配合。各省區市增強了在安全事故問題上的協調與溝通,化解了一些跨省區矛盾和問題;鐵道、民航部門準時準確報送信息;中宣部和中央文明辦以黃金週旅遊爲載體," +
                "部署精神文明建設和踐行社會主義榮辱觀的宣傳活動;中國氣象局及時將黃金週每日氣象分析送交各有關部門;公安部專門部署警力,爲協調遊客流動大的城市及景區作了大量工做;旅遊部門密切配合有關部門作好各種事故處理和投訴調解工做。三是政府各部門的社會服務意識大爲加強。外交部及其駐外領事館及時提供境外安全信息爲旅遊者服務;中央電視臺、地方電視臺和各大媒體及各地方媒體提供的旅遊信息十分豐富;氣象信息服務充分具體;中消協提出多項旅遊警示。各部門的密切配合和主動服務配合,確保了本次黃金週的順利平穩運行。";
        Set<String> words = IKWordSegmentation.segString(s).keySet();

        Map<String, BigDecimal> resultMap = BayesNB.classifyResult(words);
        String category = BayesNB.getClassifyResultName(resultMap);
        System.out.println(category);
    }
View Code

通過上述步驟便可實現簡單的多項式模型算法,有部分代碼參考了網上的算法代碼。

相關文章
相關標籤/搜索