[譯] Sklearn 中的樸素貝葉斯分類器

用豆機實現的高斯分佈html

這篇教程詳述了樸素貝葉斯分類器的算法、它的原理優缺點,並提供了一個使用 Sklearn 庫的示例。前端

背景

以著名的泰坦尼克號遇難者數據集爲例。它收集了泰坦尼克號的乘客的我的信息以及是否從那場海難中生還。讓咱們試着用乘客的船票費用來預測一下他可否生還。python

泰坦尼克號上的 500 名乘客android

假設你隨機取了 500 名乘客。在這些樣本中,30% 的人倖存下來。倖存乘客的平均票價爲 100 美圓,而遇難乘客的平均票價爲 50 美圓。如今,假設你有了一個新的乘客。你不知道他是否倖存,但你知道他買了一張 30 美圓的票穿越大西洋。請你預測一下這個乘客是否倖存。ios

原理

好吧,你可能回答說這個乘客沒能倖存。爲何?由於根據上文所取的乘客的隨機子集中所包含的信息,原本的生還概率就很低(30%),而窮人的生還概率則更低。你會把這個乘客放在最可能的組別(低票價組)。這就是樸素貝葉斯分類器所要實現的。git

分析

樸素貝葉斯分類器利用條件機率來彙集信息,並假設特徵之間相對獨立。這是什麼意思呢?舉個例子,這意味着咱們必須假定泰坦尼克號的房間溫馨度與票價無關。顯然這個假設是錯誤的,這就是爲何咱們將這個假設稱爲樸素(Naive)的緣由。樸素假設使得計算得以簡化,即便在很是大的數據集上也是如此。讓咱們來一探究竟。github

樸素貝葉斯分類器本質上是尋找能描述給定特徵條件下屬於某個類別的機率的函數,這個函數寫做 P(Survival | f1,…, fn)。咱們使用貝葉斯定理來簡化計算:算法

式 1:貝葉斯定理後端

P(Survival) 很容易計算,而咱們構建分類器也不須要用到 P(f1,…, fn),所以問題回到計算 P(f1,…, fn | Survival) 上來。咱們應用條件機率公式來再一次簡化計算:bash

式 2:初步拓展

上式最後一行的每一項的計算都須要一個包含全部條件的數據集。爲了計算 {Survival, f_1, …, f_n-1} 條件下 fn 的機率(即 P(fn | Survival, f_1, …, f_n-1)),咱們須要有足夠多不一樣的知足條件 {Survival, f_1, …, f_n-1} 的 fn 值。這會須要大量的數據,並致使維度災難。這時樸素假設(Naive Assumption)的好處就凸顯出來了。假設特徵是獨立的,咱們能夠認爲條件 {Survival, f_1, …, f_n-1} 的機率等於 {Survival} 的機率,以此來簡化計算:

式 3:應用樸素假設

最後,爲了分類,新建一個特徵向量,咱們只須要選擇是否生還的值(1 或 0),令 P(f1, …, fn|Survival) 最高,即爲最終的分類結果:

式 4:argmax 分類器

注意:常見的錯誤是認爲分類器輸出的機率是對的。事實上,樸素貝葉斯被稱爲差估計器,因此不要太認真地看待這些輸出機率。

找出合適的分佈函數

最後一步就是實現分類器。怎樣爲機率函數 P(f_i| Survival) 創建模型呢?在 Sklearn 庫中有三種模型:

正態分佈

二項式分佈

Python 代碼

接下來,基於泰坦尼克遇難者數據集,咱們實現了一個經典的高斯樸素貝葉斯。咱們將使用船艙等級、性別、年齡、兄弟姐妹數目、父母/子女數量、票價和登船口岸這些信息。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB

# 導入數據集
data = pd.read_csv("data/train.csv")

# 將分類變量轉換爲數字
data["Sex_cleaned"]=np.where(data["Sex"]=="male",0,1)
data["Embarked_cleaned"]=np.where(data["Embarked"]=="S",0,
                                  np.where(data["Embarked"]=="C",1,
                                           np.where(data["Embarked"]=="Q",2,3)
                                          )
                                 )
# 清除數據集中的非數字值(NaN)
data=data[[
    "Survived",
    "Pclass",
    "Sex_cleaned",
    "Age",
    "SibSp",
    "Parch",
    "Fare",
    "Embarked_cleaned"
]].dropna(axis=0, how='any')

# 將數據集拆分紅訓練集和測試集
X_train, X_test = train_test_split(data, test_size=0.5, random_state=int(time.time()))
複製代碼
# 實例化分類器
gnb = GaussianNB()
used_features =[
    "Pclass",
    "Sex_cleaned",
    "Age",
    "SibSp",
    "Parch",
    "Fare",
    "Embarked_cleaned"
]

# 訓練分類器
gnb.fit(
    X_train[used_features].values,
    X_train["Survived"]
)
y_pred = gnb.predict(X_test[used_features])

# 打印結果
print("Number of mislabeled points out of a total {} points : {}, performance {:05.2f}%"
      .format(
          X_test.shape[0],
          (X_test["Survived"] != y_pred).sum(),
          100*(1-(X_test["Survived"] != y_pred).sum()/X_test.shape[0])
))
複製代碼

Number of mislabeled points out of a total 357 points: 68, performance 80.95%

這個分類器的正確率爲 80.95%

使用單個特徵說明

讓咱們試着只使用票價信息來約束分類器。下面咱們計算 P(Survival = 1) 和 P(Survival = 0) 的機率:

mean_survival=np.mean(X_train["Survived"])
mean_not_survival=1-mean_survival
print("Survival prob = {:03.2f}%, Not survival prob = {:03.2f}%"
      .format(100*mean_survival,100*mean_not_survival))
複製代碼

Survival prob = 39.50%, Not survival prob = 60.50%

而後,根據式 3,咱們只須要得出機率分佈函數 P(fare| Survival = 0) 和 P(fare| Survival = 1)。咱們選用高斯樸素貝葉斯分類器,所以,必須假設數據按高斯分佈。

式 5:高斯公式(σ:標準差 / μ:均值)

而後,咱們須要算出是否生還值不一樣的狀況下,票價數據集的均值和標準差。咱們獲得如下結果:

mean_fare_survived = np.mean(X_train[X_train["Survived"]==1]["Fare"])
std_fare_survived = np.std(X_train[X_train["Survived"]==1]["Fare"])
mean_fare_not_survived = np.mean(X_train[X_train["Survived"]==0]["Fare"])
std_fare_not_survived = np.std(X_train[X_train["Survived"]==0]["Fare"])

print("mean_fare_survived = {:03.2f}".format(mean_fare_survived))
print("std_fare_survived = {:03.2f}".format(std_fare_survived))
print("mean_fare_not_survived = {:03.2f}".format(mean_fare_not_survived))
print("std_fare_not_survived = {:03.2f}".format(std_fare_not_survived))
複製代碼
mean_fare_survived = 54.75
std_fare_survived = 66.91
mean_fare_not_survived = 24.61
std_fare_not_survived = 36.29
複製代碼

讓咱們看看關於生還未生還的直方圖的結果分佈:

圖 1:各個是否生還值的票價直方圖和高斯分佈(縮放等級並不對應)

能夠發現,分佈與數據集並無很好地擬合。在實現模型以前,最好驗證特徵分佈是否遵循上述三種模型中的一種。若是連續特徵不具備正態分佈,則應使用變換或不一樣的方法將其轉換成正態分佈。爲了便於說明,這咱們將分佈看做是正態的。應用式 1 貝葉斯定理,可得如下這個分類器:

圖 2:高斯分類器

若是票價分類器的值超過 78(classifier(Fare) ≥ ~78),則 P(fare| Survival = 1) ≥ P(fare| Survival = 0),咱們將這我的歸類爲生還。不然咱們就將他歸爲未生還。咱們獲得了一個正確率爲 64.15% 的分類器。

若是咱們在同一數據集上訓練 Sklearn 高斯樸素貝葉斯分類器,將會獲得徹底相同的結果:

from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
used_features =["Fare"]
y_pred = gnb.fit(X_train[used_features].values, X_train["Survived"]).predict(X_test[used_features])
print("Number of mislabeled points out of a total {} points : {}, performance {:05.2f}%"
      .format(
          X_test.shape[0],
          (X_test["Survived"] != y_pred).sum(),
          100*(1-(X_test["Survived"] != y_pred).sum()/X_test.shape[0])
))
print("Std Fare not_survived {:05.2f}".format(np.sqrt(gnb.sigma_)[0][0]))
print("Std Fare survived: {:05.2f}".format(np.sqrt(gnb.sigma_)[1][0]))
print("Mean Fare not_survived {:05.2f}".format(gnb.theta_[0][0]))
print("Mean Fare survived: {:05.2f}".format(gnb.theta_[1][0]))
複製代碼
Number of mislabeled points out of a total 357 points: 128, performance 64.15%
Std Fare not_survived 36.29
Std Fare survived: 66.91
Mean Fare not_survived 24.61
Mean Fare survived: 54.75
複製代碼

樸素貝葉斯分類器的優缺點

優勢:

  • 計算迅速
  • 實現簡單
  • 在小數據集上表現良好
  • 在高維度數據上表現良好
  • 即便樸素假設沒有徹底知足,也能表現良好。在許多狀況下,創建一個好的分類器只須要近似的數據就夠了。

缺點:

  • 須要移除相關特徵,由於它們會在模型中被計算兩次,這將致使該特徵的重要性被高估。
  • 若是測試集中,某分類變量的一個類別沒有在訓練集中出現過,那麼模型會把這種狀況設爲零機率。它將沒法作出預測。這一般被稱爲『零位頻率』。咱們可使用平滑技術來解決這個問題。最簡單的平滑技術之一稱爲拉普拉斯平滑。當你訓練一個樸素貝葉斯分類器時,Sklearn 會默認使用拉普拉斯平滑算法。

結語

很是感謝你閱讀這篇文章。我但願它能幫助你理解樸素貝葉斯分類器的概念以及它的優勢

致謝 Antoine ToubhansFlavian HautboisAdil BaajRaphaël Meudec

若是發現譯文存在錯誤或其餘須要改進的地方,歡迎到 掘金翻譯計劃 對譯文進行修改並 PR,也可得到相應獎勵積分。文章開頭的 本文永久連接 即爲本文在 GitHub 上的 MarkDown 連接。


掘金翻譯計劃 是一個翻譯優質互聯網技術文章的社區,文章來源爲 掘金 上的英文分享文章。內容覆蓋 AndroidiOS前端後端區塊鏈產品設計人工智能等領域,想要查看更多優質譯文請持續關注 掘金翻譯計劃官方微博知乎專欄

相關文章
相關標籤/搜索