樸素貝葉斯在文本分類中的應用之「多項式」

  1.    貝葉斯分類算法基於托馬斯貝葉斯發明的貝葉斯定理,他提出的貝葉斯定理對於現代機率論和數理統計的發展有重要的影響。算法

  2.       貝葉斯定理:ui

  3.            對於隨機事件A和B:A發生的機率爲P(A),B發生的機率爲P(B),在B發生的狀況下,A發生的機率爲P(A|B)。A和B一塊兒發生的聯合機率爲P(AB)。有:P(A|B) X P(B) = P(AB) = P(B|A) X P(A),則有:this

  4. P(A|B) = P(B|A)P(A) / P(B)spa

  5.       文本分類(Text Categorization)是指計算機將一片文檔歸於預先給定的某一類或幾類的過程。文本分類的特徵提取過程是分詞。目前比較好的中文分詞器有中科院的ictclas,庖丁,IK等等。通過分詞後,每一個詞就是一個特徵。分詞中能夠本身配置停用詞庫,擴展詞庫等。特徵選擇有諸如TF-IDF,CHI等特徵選擇算法,就不在此贅述。orm

  6.       樸素貝葉斯計算先驗機率P(C)和條件機率P(X|C)的方法有兩種:多項式模型伯努利模型。二者在計算的時候有兩點差異:多項式會統計詞頻,而伯努利認爲單詞出現就記爲1,沒出現記爲0,能夠看到一個是基於詞頻,一個是基於文檔頻率;伯努利在分類時,將詞庫中的沒有出如今待分類文本的詞做爲反方考慮索引

  7.       在計算條件機率時,當待分類文本中的某個詞沒有出如今詞庫中時,機率爲0,會致使很嚴重的問題,須要考慮拉普拉斯平滑(laplace smoothing):它是將全部詞出現的次數+1,再進行統計。事件

  8.       再一個問題就是機率過小而詞數太多,會超double,用log將乘法轉成加法文檔

  9.       多項式樸素貝葉斯算法僞代碼以下:get






  10. public class NaiveBayesManager {  it

  11.   

  12.     /**關鍵詞索引 關鍵詞-索引*/  

  13.     private Map<String, Integer> termIndex;  

  14.     /**類的索引 類名稱-索引*/  

  15.     private Map<String, Integer> classIndex;  

  16.     /** 類名 */  

  17.     private List<String>         className;  

  18.   

  19.     /**某類的文檔中全部特徵出現的總次數*/  

  20.     private int                  classTermsCount[];  

  21.   

  22.     /**某類的文檔中某特徵出現的次數之和*/  

  23.     private int                  classKeyMap[][];  

  24.   

  25.     /**類的個數*/  

  26.     private int                  numClasses      = 0;  

  27.   

  28.     /**訓練樣本的全部特徵(出現屢次只算一個)*/  

  29.     private int                  vocabulary      = 0;  

  30.   

  31.     /**訓練樣本的特徵總次數*/  

  32.     private int                  totalTermsCount = 0;  

  33.   

  34.     /** 創建類名和特徵名的索引 */  

  35.     private void buildIndex(List<Corpus> orignCorpus) {  

  36.         classIndex = new HashMap<String, Integer>();  

  37.         termIndex = new HashMap<String, Integer>();  

  38.         className = new ArrayList<String>();  

  39.         Integer idTerm = new Integer(-1);  

  40.         Integer idClass = new Integer(-1);  

  41.         for (int i = 0; i < orignCorpus.size(); ++i) {  

  42.             Corpus corpus = orignCorpus.get(i);  

  43.             List<String> terms = corpus.getSegments();  

  44.             String label = corpus.getCat();  

  45.             if (!classIndex.containsKey(label)) {  

  46.                 idClass++;  

  47.                 classIndex.put(label, idClass);  

  48.                 className.add(label);  

  49.             }  

  50.             for (String term : terms) {  

  51.                 totalTermsCount++;  

  52.                 if (!termIndex.containsKey(term)) {  

  53.                     idTerm++;  

  54.                     termIndex.put(term, idTerm);  

  55.                 }  

  56.             }  

  57.         }  

  58.         vocabulary = termIndex.size();  

  59.         numClasses = classIndex.size();  

  60.     }  

  61.   

  62.     /** 

  63.      * 訓練 

  64.      * */  

  65.     public void startTraining(List<Corpus> orignCorpus) {  

  66.         buildIndex(orignCorpus);  

  67.         classTermsCount = new int[numClasses + 1];  

  68.         classKeyMap = new int[numClasses + 1][vocabulary + 1];  

  69.         for (int i = 0; i < orignCorpus.size(); ++i) {  

  70.             Corpus corpus = orignCorpus.get(i);  

  71.             List<String> terms = corpus.getSegments();  

  72.             String label = corpus.getCat();  

  73.             Integer labelIndex = classIndex.get(label);  

  74.             for (String term : terms) {  

  75.                 Integer wordIndex = termIndex.get(term);  

  76.                 classTermsCount[labelIndex]++;  

  77.                 classKeyMap[labelIndex][wordIndex]++;  

  78.             }  

  79.         }  

  80.     }  

  81.   

  82.     public String classify(List<String> terms) {  

  83.         int result = 0;  

  84.         double maxPro = Double.NEGATIVE_INFINITY;  

  85.         for (int cIndex = 0; cIndex < numClasses; ++cIndex) {  

  86.             double pro = Math.log10(getPreProbability(cIndex));  

  87.             for (String term : terms) {  

  88.                 pro += Math.log10(getClassConditonalProbability(cIndex, term));  

  89.             }  

  90.             if (maxPro < pro) {  

  91.                 maxPro = pro;  

  92.                 result = cIndex;  

  93.             }  

  94.         }  

  95.         return className.get(result);  

  96.     }  

  97.   

  98.     private double getPreProbability(int classIndex) {  

  99.         double ret = 0;  

  100.         int NC = classTermsCount[classIndex];  

  101.         int N = totalTermsCount;  

  102.         ret = 1.0 * NC / N;  

  103.         return ret;  

  104.     }  

  105.   

  106.     private double getClassConditonalProbability(int classIndex, String term) {  

  107.         double ret = 0;  

  108.         int NCX = 0;  

  109.         int N = 0;  

  110.         int V = 0;  

  111.   

  112.         Integer wordIndex = termIndex.get(term);  

  113.         if (wordIndex != null)  

  114.             NCX = classKeyMap[classIndex][wordIndex];  

  115.   

  116.         N = classTermsCount[classIndex];  

  117.   

  118.         V = vocabulary;  

  119.   

  120.         ret = (NCX + 1.0) / (N + V); //laplace smoothing. 拉普拉斯平滑處理   

  121.         return ret;  

  122.     }  

  123.   

  124.     public Map<String, Integer> getTermIndex() {  

  125.         return termIndex;  

  126.     }  

  127.   

  128.     public void setTermIndex(Map<String, Integer> termIndex) {  

  129.         this.termIndex = termIndex;  

  130.     }  

  131.   

  132.     public Map<String, Integer> getClassIndex() {  

  133.         return classIndex;  

  134.     }  

  135.   

  136.     public void setClassIndex(Map<String, Integer> classIndex) {  

  137.         this.classIndex = classIndex;  

  138.     }  

  139.   

  140.     public List<String> getClassName() {  

  141.         return className;  

  142.     }  

  143.   

  144.     public void setClassName(List<String> className) {  

  145.         this.className = className;  

  146.     }  

  147.   

  148.     public int[] getClassTermsCount() {  

  149.         return classTermsCount;  

  150.     }  

  151.   

  152.     public void setClassTermsCount(int[] classTermsCount) {  

  153.         this.classTermsCount = classTermsCount;  

  154.     }  

  155.   

  156.     public int[][] getClassKeyMap() {  

  157.         return classKeyMap;  

  158.     }  

  159.   

  160.     public void setClassKeyMap(int[][] classKeyMap) {  

  161.         this.classKeyMap = classKeyMap;  

  162.     }  

  163.   

  164.     public int getNumClasses() {  

  165.         return numClasses;  

  166.     }  

  167.   

  168.     public void setNumClasses(int numClasses) {  

  169.         this.numClasses = numClasses;  

  170.     }  

  171.   

  172.     public int getVocabulary() {  

  173.         return vocabulary;  

  174.     }  

  175.   

  176.     public void setVocabulary(int vocabulary) {  

  177.         this.vocabulary = vocabulary;  

  178.     }  

  179.   

  180.     public int getTotalTermsCount() {  

  181.         return totalTermsCount;  

  182.     }  

  183.   

  184.     public void setTotalTermsCount(int totalTermsCount) {  

  185.         this.totalTermsCount = totalTermsCount;  

  186.     }  

  187.   

  188.     public static String getSplitword() {  

  189.         return splitWord;  

  190.     }  

  191.   

  192. }  

相關文章
相關標籤/搜索