貝葉斯定理:算法
對於隨機事件A和B:A發生的機率爲P(A),B發生的機率爲P(B),在B發生的狀況下,A發生的機率爲P(A|B)。A和B一塊兒發生的聯合機率爲P(AB)。有:P(A|B) X P(B) = P(AB) = P(B|A) X P(A),則有:優化
P(A|B) = P(B|A)P(A) / P(B)ui
文本分類(Text Categorization)是指計算機將一片文檔歸於預先給定的某一類或幾類的過程。文本分類的特徵提取過程是分詞。目前比較好的中文分詞器有中科院的ictclas,庖丁,IK等等。通過分詞後,每一個詞就是一個特徵。分詞中能夠本身配置停用詞庫,擴展詞庫等。特徵選擇有諸如TF-IDF,CHI等特徵選擇算法,就不在此贅述。spa
樸素貝葉斯計算先驗機率P(C)和條件機率P(X|C)的方法有兩種:多項式模型和伯努利模型。二者在計算的時候有兩點差異:多項式會統計詞頻,而伯努利認爲單詞出現就記爲1,沒出現記爲0,能夠看到一個是基於詞頻,一個是基於文檔頻率;伯努利在分類時,將詞庫中的沒有出如今待分類文本的詞做爲反方考慮。orm
在計算條件機率時,當待分類文本中的某個詞沒有出如今詞庫中時,機率爲0,會致使很嚴重的問題,須要考慮拉普拉斯平滑(laplace smoothing):它是將全部詞出現的次數+1,再進行統計。索引
再一個問題就是機率過小而詞數太多,會超double,用log將乘法轉成加法。事件
伯努利樸素貝葉斯算法僞代碼以下:內存
伯努利樸素貝葉斯代碼:文檔
/**
* @author zhongmin.yzm
* 語料訓練並載入內存
* */
public class TrainingDataManager {
/** 特徵索引 */
private Map<String, Integer> termIndex;
/** 類索引 */
private Map<String, Integer> classIndex;
/** 索引-類名 */
public List<String> className;
/**類的個數*/
private int numClasses = 0;
/**訓練樣本的全部特徵(出現屢次只算一個)*/
private int vocabulary = 0;
/**訓練文本總數*/
private int DocsNum = 0;
/**屬於某類的文檔個數*/
private int[] classDocs;
/**類別c中包含屬性 x的訓練文本數量*/
private int[][] classKeyMap;
/** 標誌位: 分類時的優化 */
private static boolean flag[];
private void buildIndex(List<List<String>> contents, List<String> labels) {
classIndex = new HashMap<String, Integer>();
termIndex = new HashMap<String, Integer>();
className = new ArrayList<String>();
Integer idTerm = new Integer(-1);
Integer idClass = new Integer(-1);
DocsNum = labels.size();
for (int i = 0; i < DocsNum; ++i) {
List<String> content = contents.get(i);
String label = labels.get(i);
if (!classIndex.containsKey(label)) {
idClass++;
classIndex.put(label, idClass);
className.add(label);
}
for (String term : content) {
if (!termIndex.containsKey(term)) {
idTerm++;
termIndex.put(term, idTerm);
}
}
}
vocabulary = termIndex.size();
numClasses = classIndex.size();
}
public void startTraining(List<List<String>> contents, List<String> labels) {
buildIndex(contents, labels);
//去重
List<List<Integer>> contentsIndex = new ArrayList<List<Integer>>();
for (int i = 0; i < DocsNum; ++i) {
List<Integer> contentIndex = new ArrayList<Integer>();
List<String> content = contents.get(i);
for (String str : content) {
Integer wordIndex = termIndex.get(str);
contentIndex.add(wordIndex);
}
Collections.sort(contentIndex);
int num = contentIndex.size();
List<Integer> tmp = new ArrayList<Integer>();
for (int j = 0; j < num; ++j) {
if (j == 0 || contentIndex.get(j - 1) != contentIndex.get(j)) {
tmp.add(contentIndex.get(j));
}
}
contentsIndex.add(tmp);
}
//
classDocs = new int[numClasses];
classKeyMap = new int[numClasses][vocabulary];
flag = new boolean[vocabulary];
for (int i = 0; i < DocsNum; ++i) {
List<Integer> content = contentsIndex.get(i);
String label = labels.get(i);
Integer labelIndex = classIndex.get(label);
classDocs[labelIndex]++;
for (Integer wordIndex : content) {
classKeyMap[labelIndex][wordIndex]++;
}
}
}
/** 分類 時間複雜度 O(c*v) */
public String classify(List<String> text) {
double maxPro = Double.NEGATIVE_INFINITY;
int resultIndex = 0;
//標記待分類文本中哪些特徵 屬於 特徵表
for (int i = 0; i < vocabulary; ++i)
flag[i] = false;
for (String term : text) {
Integer wordIndex = termIndex.get(term);
if (wordIndex != null)
flag[wordIndex] = true;
}
//對特徵集中的每一個特徵: 若出如今待分類文本中,直接計算;不然做爲反方參與
for (int classIndex = 0; classIndex < numClasses; ++classIndex) {
double pro = Math.log10(getPreProbability(classIndex));
for (int wordIndex = 0; wordIndex < vocabulary; ++wordIndex) {
if (flag[wordIndex])
pro += Math.log10(getClassConditionalProbability(classIndex, wordIndex));
else
pro += Math.log10(1 - getClassConditionalProbability(classIndex, wordIndex));
}
if (maxPro < pro) {
maxPro = pro;
resultIndex = classIndex;
}
}
return className.get(resultIndex);
}
/** 先驗機率: 類C包含的文檔數/總文檔數 */
private double getPreProbability(int classIndex) {
double ret = 0.0;
ret = 1.0 * classDocs[classIndex] / DocsNum;
return ret;
}
/** 條件機率: 類C中包含關鍵字t的文檔個數/類C包含的文檔數 */
private double getClassConditionalProbability(int classIndex, int termIndex) {
int NCX = classKeyMap[classIndex][termIndex];
int N = classDocs[classIndex];
double ret = (NCX + 1.0) / (N + DocsNum);
return ret;
}
}