分類算法-3.多分類中的混淆矩陣

加載手寫識別數字數據集git

import numpy
from sklearn import datasets
import matplotlib.pyplot as plt 

digits = datasets.load_digits()
x = digits.data
y = digits.target

from sklearn.model_selection import train_test_split

x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.8,random_state=666)

用邏輯迴歸訓練算法

from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()

# sklearn中默認使用OVR方式解決多分類問題
log_reg.fit(x_train,y_train)
y_predict = log_reg.predict(x_test)
log_reg.score(x_test,y_test)

查看多分類問題的混淆矩陣dom

from sklearn.metrics import confusion_matrix

cfm = confusion_matrix(y_test,y_predict)

將數據與灰度值對應起來:code

# cmap爲顏色映射,gray爲像素灰度值
plt.matshow(cfm,cmap=plt.cm.gray)

去除預測正確的對角線數據,查看混淆矩陣中的其餘值blog

row_sum = numpy.sum(cfm,axis=1)
err_matrix = cfm / row_sum
numpy.fill_diagonal(err_matrix,0)

plt.matshow(err_matrix,cmap=plt.cm.gray)

上圖不只能夠看出哪一個地方犯的錯誤多,還能夠看出是什麼樣的錯誤,例:算法會偏向於將值爲1的數據預測爲9,將值爲8的數預測爲1。
在算法方面,應該考慮調整一、八、9的決策閾值以加強算法的準確率。在手寫識別數據集方面,應該考慮處理數據,如消除數據集的噪點和干擾點,提升清晰度和可識別程度。get

相關文章
相關標籤/搜索