精準率和召回率是兩個不一樣的評價指標,不少時候它們之間存在着差別,具體在使用的時候如何解讀精準率和召回率,應該視具體使用場景而定算法
有些場景,人們可能更注重精準率,如股票預測系統,咱們定義股票升爲1,股票降爲0,咱們更關心的是將來升的股票的比例,而在另一些場景中,人們更加註重召回率,如癌症預測系統,定義健康爲1,患病爲0,咱們更關心癌症患者檢查的遺漏狀況。編程
F1 Score 兼顧精準率和召回率,它是二者的調和平均值app
\[\frac{1}{F1} = \frac{1}{2}(\frac{1}{Precision} + \frac{1}{recall})\]
\[F1 = \frac{2\cdot precision\cdot recall}{precision+recall}\]
定義F1 Score測試
def f1_score(precision,recall): try: return 2*precision*recall/(precision+recall) except: return 0
由上看出,F1 Score更偏向於分數小的那個指標spa
精準率和召回率是兩個互相矛盾的目標,提升一個指標,另外一個指標就會不可避免的降低。如何達到二者之間的一個平衡呢?code
回憶邏輯迴歸算法的原理:將一個結果發生的機率大於0.5,就把它分類爲1,發生的機率小於0.5,就把它分類爲0,決策邊界爲:\(\theta ^T \cdot X_b = 0\)blog
這條直線或曲線決定了分類的結果,平移決策邊界,使\(\theta ^T \cdot X_b\)不等於0而是一個閾值:\(\theta ^T \cdot X_b = threshold\)ci
圓形表明分類結果爲0,五角星表明分類結果爲1,由上圖能夠看出,精準率和召回率是兩個互相矛盾的指標,隨着閾值的逐漸增大,召回率逐漸下降,精準率逐漸增大。it
編程實現不一樣閥值下的預測結果及混淆矩陣io
from sklearn.linear_model import LogisticRegression # 數據使用前一節處理後的手寫識別數據集 log_reg = LogisticRegression() log_reg.fit(x_train,y_train)
求每一個測試數據在邏輯迴歸算法中的score值:
decision_score = log_reg.decision_function(x_test)
不一樣閥值下預測的結果
y_predict_1 = numpy.array(decision_score>=-5,dtype='int') y_predict_2 = numpy.array(decision_score>=0,dtype='int') y_predict_3 = numpy.array(decision_score>=5,dtype='int')
查看不一樣閾值下的混淆矩陣:
求出0.1步長下,閾值在[min,max]區間下的精準率和召回率,查看其曲線特徵:
threshold_scores = numpy.arange(numpy.min(decision_score),numpy.max(decision_score),0.1) precision_scores = [] recall_scores = [] # 求出每一個分類閾值下的預測值、精準率和召回率 for score in threshold_scores: y_predict = numpy.array(decision_score>=score,dtype='int') precision_scores.append(precision_score(y_test,y_predict)) recall_scores.append(recall_score(y_test,y_predict))
畫出精準率和召回率隨閾值變化的曲線
plt.plot(threshold_scores,precision_scores) plt.plot(threshold_scores,recall_scores) plt.show()
畫出精準率-召回率曲線
plt.plot(precision_scores,recall_scores) plt.show()
from sklearn.metrics import precision_recall_curve precisions,recalls,thresholds = precision_recall_curve(y_test,decision_score) # sklearn中最後一個精準率爲1,召回率爲0,沒有對應的threshold plt.plot(thresholds,precisions[:-1]) plt.plot(thresholds,recalls[:-1]) plt.show()