樸素貝葉斯在文本分類中的應用之 伯努利

貝葉斯定理:算法

           對於隨機事件A和B:A發生的機率爲P(A),B發生的機率爲P(B),在B發生的狀況下,A發生的機率爲P(A|B)。A和B一塊兒發生的聯合機率爲P(AB)。有:P(A|B) X P(B) = P(AB) = P(B|A) X P(A),則有:優化

P(A|B) = P(B|A)P(A) / P(B)ui

      文本分類(Text Categorization)是指計算機將一片文檔歸於預先給定的某一類或幾類的過程。文本分類的特徵提取過程是分詞。目前比較好的中文分詞器有中科院的ictclas,庖丁,IK等等。通過分詞後,每一個詞就是一個特徵。分詞中能夠本身配置停用詞庫,擴展詞庫等。特徵選擇有諸如TF-IDF,CHI等特徵選擇算法,就不在此贅述。spa

      樸素貝葉斯計算先驗機率P(C)和條件機率P(X|C)的方法有兩種:多項式模型伯努利模型。二者在計算的時候有兩點差異:多項式會統計詞頻,而伯努利認爲單詞出現就記爲1,沒出現記爲0,能夠看到一個是基於詞頻,一個是基於文檔頻率;伯努利在分類時,將詞庫中的沒有出如今待分類文本的詞做爲反方考慮orm

      在計算條件機率時,當待分類文本中的某個詞沒有出如今詞庫中時,機率爲0,會致使很嚴重的問題,須要考慮拉普拉斯平滑(laplace smoothing):它是將全部詞出現的次數+1,再進行統計。索引

      再一個問題就是機率過小而詞數太多,會超double,用log將乘法轉成加法事件



 伯努利樸素貝葉斯算法僞代碼以下:內存



伯努利樸素貝葉斯代碼文檔

 

Java代碼  收藏代碼get

  1. /** 

  2.  * @author zhongmin.yzm 

  3.  * 語料訓練並載入內存 

  4.  * */  

  5. public class TrainingDataManager {  

  6.   

  7.     /** 特徵索引 */  

  8.     private Map<String, Integer> termIndex;  

  9.     /** 類索引 */  

  10.     private Map<String, Integer> classIndex;  

  11.     /** 索引-類名 */  

  12.     public List<String>          className;  

  13.   

  14.     /**類的個數*/  

  15.     private int                  numClasses = 0;  

  16.   

  17.     /**訓練樣本的全部特徵(出現屢次只算一個)*/  

  18.     private int                  vocabulary = 0;  

  19.   

  20.     /**訓練文本總數*/  

  21.     private int                  DocsNum    = 0;  

  22.   

  23.     /**屬於某類的文檔個數*/  

  24.     private int[]                classDocs;  

  25.   

  26.     /**類別c中包含屬性 x的訓練文本數量*/  

  27.     private int[][]              classKeyMap;  

  28.   

  29.     /** 標誌位: 分類時的優化 */  

  30.     private static boolean       flag[];  

  31.   

  32.     private void buildIndex(List<List<String>> contents, List<String> labels) {  

  33.         classIndex = new HashMap<String, Integer>();  

  34.         termIndex = new HashMap<String, Integer>();  

  35.         className = new ArrayList<String>();  

  36.         Integer idTerm = new Integer(-1);  

  37.         Integer idClass = new Integer(-1);  

  38.         DocsNum = labels.size();  

  39.         for (int i = 0; i < DocsNum; ++i) {  

  40.             List<String> content = contents.get(i);  

  41.             String label = labels.get(i);  

  42.             if (!classIndex.containsKey(label)) {  

  43.                 idClass++;  

  44.                 classIndex.put(label, idClass);  

  45.                 className.add(label);  

  46.             }  

  47.             for (String term : content) {  

  48.                 if (!termIndex.containsKey(term)) {  

  49.                     idTerm++;  

  50.                     termIndex.put(term, idTerm);  

  51.                 }  

  52.             }  

  53.         }  

  54.         vocabulary = termIndex.size();  

  55.         numClasses = classIndex.size();  

  56.     }  

  57.   

  58.     public void startTraining(List<List<String>> contents, List<String> labels) {  

  59.         buildIndex(contents, labels);  

  60.         //去重  

  61.         List<List<Integer>> contentsIndex = new ArrayList<List<Integer>>();  

  62.         for (int i = 0; i < DocsNum; ++i) {  

  63.             List<Integer> contentIndex = new ArrayList<Integer>();  

  64.             List<String> content = contents.get(i);  

  65.             for (String str : content) {  

  66.                 Integer wordIndex = termIndex.get(str);  

  67.                 contentIndex.add(wordIndex);  

  68.             }  

  69.             Collections.sort(contentIndex);  

  70.             int num = contentIndex.size();  

  71.             List<Integer> tmp = new ArrayList<Integer>();  

  72.             for (int j = 0; j < num; ++j) {  

  73.                 if (j == 0 || contentIndex.get(j - 1) != contentIndex.get(j)) {  

  74.                     tmp.add(contentIndex.get(j));  

  75.                 }  

  76.             }  

  77.             contentsIndex.add(tmp);  

  78.         }  

  79.         //  

  80.         classDocs = new int[numClasses];  

  81.         classKeyMap = new int[numClasses][vocabulary];  

  82.         flag = new boolean[vocabulary];  

  83.         for (int i = 0; i < DocsNum; ++i) {  

  84.             List<Integer> content = contentsIndex.get(i);  

  85.             String label = labels.get(i);  

  86.             Integer labelIndex = classIndex.get(label);  

  87.             classDocs[labelIndex]++;  

  88.             for (Integer wordIndex : content) {  

  89.                 classKeyMap[labelIndex][wordIndex]++;  

  90.             }  

  91.         }  

  92.     }  

  93.   

  94.     /** 分類 時間複雜度 O(c*v) */  

  95.     public String classify(List<String> text) {  

  96.         double maxPro = Double.NEGATIVE_INFINITY;  

  97.         int resultIndex = 0;  

  98.         //標記待分類文本中哪些特徵 屬於 特徵表  

  99.         for (int i = 0; i < vocabulary; ++i)  

  100.             flag[i] = false;  

  101.         for (String term : text) {  

  102.             Integer wordIndex = termIndex.get(term);  

  103.             if (wordIndex != null)  

  104.                 flag[wordIndex] = true;  

  105.         }  

  106.         //對特徵集中的每一個特徵: 若出如今待分類文本中,直接計算;不然做爲反方參與  

  107.         for (int classIndex = 0; classIndex < numClasses; ++classIndex) {  

  108.             double pro = Math.log10(getPreProbability(classIndex));  

  109.             for (int wordIndex = 0; wordIndex < vocabulary; ++wordIndex) {  

  110.                 if (flag[wordIndex])  

  111.                     pro += Math.log10(getClassConditionalProbability(classIndex, wordIndex));  

  112.                 else  

  113.                     pro += Math.log10(1 - getClassConditionalProbability(classIndex, wordIndex));  

  114.             }  

  115.             if (maxPro < pro) {  

  116.                 maxPro = pro;  

  117.                 resultIndex = classIndex;  

  118.             }  

  119.         }  

  120.         return className.get(resultIndex);  

  121.     }  

  122.   

  123.     /** 先驗機率: 類C包含的文檔數/總文檔數 */  

  124.     private double getPreProbability(int classIndex) {  

  125.         double ret = 0.0;  

  126.         ret = 1.0 * classDocs[classIndex] / DocsNum;  

  127.         return ret;  

  128.     }  

  129.   

  130.     /** 條件機率: 類C中包含關鍵字t的文檔個數/類C包含的文檔數 */  

  131.     private double getClassConditionalProbability(int classIndex, int termIndex) {  

  132.         int NCX = classKeyMap[classIndex][termIndex];  

  133.         int N = classDocs[classIndex];  

  134.         double ret = (NCX + 1.0) / (N + DocsNum);  

  135.         return ret;  

  136.     }  

  137.   

  138. }  

相關文章
相關標籤/搜索