本章主要講關於分類的一些機器學習知識點。我會按照如下關鍵點來總結本身的學習心得:(本文源碼在文末,請自行獲取)html
什麼是MNIST數據集git
MNIST數據集是一組由美國高中生和人口調查局員工手寫的70,000個數字圖片數據集。官方連接爲:http://yann.lecun.com/exdb/mnist/github
這組數據集X標籤是28*28大小的像素強度數值,y標籤是一個該圖像對應的一個真實數字。算法
咱們經過sklearn提供的函數能夠對該數據集進行下載:api
這個fetch_mldata會將名字爲MNIST original的數據集經過官方庫中的數據集下載下來,返回的是一個dict對象。(dict對象介紹)數組
通常童鞋應該是下不下來,文章底部的連接中有對應數據供你們直接下載。緩存
你們下載完畢後,執行下方代碼:安全
from sklearn.datasets.base import get_data_home
print (get_data_home())dom
查找到sklearn對於你的機器上的數據集緩存地址,將下載的文件中DataSets中的mnist-original.mat文件直接複製到顯示位置便可。機器學習
如今假定,咱們數據集已經處理完畢。
每次咱們針對數據集的觀察是必不可少的操做,所以先看看返回的dict對象中都有什麼對咱們是有好處的:
經過加載sklearn的數據集,一般包括:DESCR(描述數據集)、data(包含一個數組,每一個實例爲一行,每一個特徵爲一列,即咱們的x)、target(包含一個帶有標記的數組,即咱們的y)
加載其中某些內容:
看出X是一個70000*784的矩陣,也能得出咱們能夠訓練的實例有7萬個,每一個實例均可以表示成爲一個28*28的矩陣(784開根號爲28),對應是一個圖像。y是一個標籤,相應也有7萬個。
咱們將其中某一個數字畫出來,能夠更加直觀的表示:
matplotlib上一章心得說了一些使用介紹,這裏不詳細講了。這裏只說matplotlib中的imshow()函數。
imshow()函數功能就是針對提供的像素點,生成一張2維圖片。該函數的參數很是多,具體請看連接
從顯示的狀況看,這個圖像看上去更像是數字5,驗證一下:
猜的沒錯,是5。
在深刻研究該數據集以前,咱們首先應該爲之後的分類算法劃分測試與訓練集。
這裏須要說明的是,MNIST數據集已經幫助咱們分好了一個數據集劃分,前60,000個數據是訓練集,後10,000爲測試集,因此咱們直接劃分便可:
可是須要注意的是,這裏的每一個集合裏面的數據劃分很是有規律,按照0-9的順序排列,這對於咱們的訓練是存在問題的,咱們應該隨機將集合從新洗牌:numpy中的random類中的permutation函數能夠達到從新洗牌的目的:
具體使用請看連接。
二分類
若是咱們只須要檢測其中的某一個數字,那麼咱們能夠將「識別某個數字與否」這個問題當作是一個二分類問題。假設咱們這裏須要識別數字5,那麼最後識別出的結果就只有兩個:5或非5。咱們至關於須要構建一個數字5的檢測器。
首先,對該任務建立目標向量(其實就是咱們所說的y,這裏所作的操做是一個方便之後的處理,不作這樣的處理也是能夠的),這個是用來標識識別結果是否正確,標籤爲5的表示爲true,其餘爲False:
構建好目標向量以後,選擇一個分類算法構建咱們的分類器,這裏選擇隨機梯度降低分類器(SGDClassifier):
從以前全部引入的估算器,無論是決策樹、隨機森林、仍是線性迴歸、邏輯迴歸仍是這裏的隨即梯度降低,sklearn針對這些模型的初始化都很是的相似,都是首先導入,而後初始化該估算器類的構造函數,這裏判斷咱們是否須要針對算法的某些參數進行修改,若是採用默認則直接無參初始化,若是須要,則傳入須要修改的參數值,以後採用fit()函數,傳入訓練集的x與y,進行訓練,最後獲得一個訓練有素的分類器。
而後咱們經過predict()函數進行預測:
因爲some_digit對應的數字以前咱們發現是5,所以預測結果爲Ture。
二分類的性能評估與權衡
首先,採用交叉驗證測量精度。該方法的含義是使用將訓練集分紅K份,每次用K-1份進行訓練,剩下的1份作測試,每次算出一個精確度並輸出,咱們首先本身編寫:
還能夠用cross_val_score()函數來實現:
咱們能夠發現,每次的精確度都在90%以上,最後一次居然能到96%?個人模型這麼好嗎?!爲了讓咱們放下心來,咱們能夠作以下操做,設置一個很是笨的分類器,直接預測全部數字都不是5,咱們來看看精度:
看見了嗎,咱們認爲的最笨的分類器,居然都能達到90%的機率!有問題,絕對有問題!
我想,你已經明白了,這是由於咱們的數據中,每一個數字大概佔總數的10%左右,所以猜一張圖不是5的機率均可以達到90%!所以,咱們採用精確度這個指標是存在問題的。特別是,當咱們處理偏斜數據集(某類型數據很是多,數據分佈不平衡的數據集)時,精確度這個性能指標是絕對不能夠的。
咱們應該採用混淆矩陣的方法!
混淆矩陣,整體思路就是構成一個矩陣,這個矩陣記錄統計出A類別的實例被錯誤的分紅B類別實例的次數。具體看下圖我進行解釋:
(圖片源自:https://blog.csdn.net/Orange_Spotty_Cat/article/details/80520839)
咱們來看這張圖,咱們想看原本是貓但被誤認爲狗的狀況有多少次,從矩陣中,咱們就找真實值爲貓,預測值爲狗的對應行列便可,找到能夠看出是3次,這就是混淆矩陣。
咱們再昇華一下這個矩陣,從理論的高度解釋一下:
(圖片源自:https://blog.csdn.net/Orange_Spotty_Cat/article/details/80520839)
這裏的右下角的每一個白格子,都表明一種預測狀況。其實對應多分類的問題,也能夠轉換爲二分類問題,只須要將待分類的類別歸爲一類,剩下歸爲一類便可。
好了,這裏的TP、FP、FN、TN要解釋一下編碼的構成:預測是否錯誤+預測的類別:
先說這麼多,咱們先把概念記住,先上一波代碼,構成咱們須要的混淆矩陣吧。
要計算混淆矩陣,須要一組預測才能比較,但測試集永遠記住都是放在項目啓動後再進行,所以咱們仍是採用交叉驗證的方式進行,此次用cross_val_predict()函數:
這裏註明一下,再sklearn中生成的混淆矩陣中,行表示實際類別,列表示預測類別,與以前給出的那個矩陣行列正好相反,不過咱們只要能分出TP、TN、FP、FN就好。
從這個混淆矩陣中,咱們能夠發現53517張被正確預測爲「非5」的類別,4236張被正確預測爲「5」的類別,1185張被錯誤的分爲「非5」的類別,1062張被錯誤的分爲「5」類別。
最優的混淆矩陣應該是隻存在T開頭的屬性,其餘狀況的數字應該是0。
混淆矩陣能夠給咱們提供大量的信息,但若是指標若是能更簡介些,咱們能夠更加直觀的瞭解一些最須要的信息,接下來介紹精度、召回率的概念。
首先,給出精度的概念:
該公式計算的是正確被分爲正類佔預測器所有預測爲正類的數量的比率。爲何叫精度呢?咱們能夠這麼想,預測正類有這麼多,其中預測的對的數量又有這麼多,這不就是計算了一個精度嘛?
接下來,咱們談談召回率:
該公式計算的是正確被分爲正類的數量佔總正類數量(由於FN對應這些狀況的真實值應該是正類呀)的比率。爲何叫召回率呢?咱們能夠舉個例子來思考這個問題。好比車吧,車廠在出售車以前,須要對車進行一個檢測,咱們這裏定義的正類就是「存在問題的車」。固然存在有些車自己存在問題,但沒檢測出來就售出了,而後呢車廠發起召回問題車,這時候咱們就須要召回率來幫忙了!固然,咱們可想而知,召回率過高,估計老闆得氣死。。-。-言歸正傳,召回率也叫查全率,就是計算的是總共的正類中,我一共能正確分出多少的正類的比率。
上代碼,計算吧:
發現沒,這個分類器沒有以前那麼亮眼了,我就知道!!哼!!
咱們還有一個指標能夠結合上述兩個指標,那就是,,,,,F1分數!
上公式:
怎麼理解F1分數呢?F1分數是精度和召回率的諧波平均值。正常的平均值平等對待全部值,而諧波平均值會給予較低的值更高的權重。所以當召回率和精度都很高時,分類器才能獲得較高的F1分數,所以F1分數越高,能說明,咱們的系統更加穩健。
可是,咱們不該該把追求F1分數做爲咱們的最終目標,咱們要看實際的要求權衡召回率和精度,舉個例子讓你們明白:
假設,須要訓練一個分類器來檢測兒童能夠放心觀看的視頻,咱們應該本着寧肯錯殺100不可放過一個的目的,可能會攔截好多好視頻(低召回率),但確保保留下的視頻都是安全的(高精度)。相反,若是須要訓練一個分類器經過圖像監控檢測小偷,這個分類器應該本着無論這我的是不是不是小偷,當他作相似小偷的行爲時,就應該發出警報。因此咱們固然但願的是儘量抓住更多的小偷咯,所以咱們要求召回率要達到99%以上(可能會誤報不少次,但幾乎竊賊都在劫難逃!),但這樣的話,咱們的精度是會降低的。
接下來解釋一下爲什麼精度和召回率不能兼得:
讓咱們想一個問題,在極端的狀況下,假設你在查找一個問題的答案,若是要求精度很是高,那麼返回的結果就會不多,但都是你要的,若是要求召回率(查全率)很高,那麼返回的結果不少,但其中有不少的結果是你不須要的。所以,須要高精度就會致使低召回率,反之亦然。
返回到咱們的數字分類的問題上,咱們看一下SGDClassifier如何進行分類決策的。
這個分類器,對於每一個實例,它會基於決策函數計算一個分值,若是該值大於閾值,則斷定爲正類,不然爲負類。放個圖,解釋一下:
假設,咱們閾值在中間箭頭位置,在閾值的右側能夠找到4個真正類(四個5),一個假正類(一個6)。所以,在該閾值下,精度爲80%(4/5),召回率爲67%(4/6)。當咱們提升閾值,假正類的6就會變成負類,那麼精度會提高到100%(3/3),但一個真正類變成了一個假負類,召回率變爲50%(3/6)。
sklearn不容許直接設置閾值,可是能夠訪問它用於預測的決策分數,咱們能夠基於分數,使用閾值預測,上代碼:
咱們如何選擇閾值呢?咱們先獲取訓練集中全部實例的分數:
再使用precision_recall_curve()函數計算全部可能的閾值的精度和召回率,最後繪製精度和召回率相對於閾值的函數圖像:
還有一種與二元分類器一塊兒使用的工具:受試者工做特徵曲線(ROC),該曲線繪製的是召回率和假正類率(FPR:被錯誤分爲正類的負類實例比率,等於1-真負類率[被正確分類爲弗雷德負類實例比率,也稱爲特異度])。
使用roc_curve()函數計算多種閾值的TPR和FPR,而後繪製FPR對於TPR的曲線:
一樣,召回率越高,分類器產生的假正類越多,虛線表示純隨機分類器的ROC曲線,一個優秀的分類器應該離這條線越遠越好。有一個比較分類器的方法是測量曲線下面積(AUC):
那麼,問題來了,如何選擇指標呢?這裏直接引用書上的原話:當正類很是少見或者你更關注假正類而不是假負類時,你應該選擇PR曲線,反之ROC。
訓練一個隨機森林的分類器,並比較SGD分類器:
計算它的AUC得分:
它的精度和召回率以下:
從二元分類到多分類
咱們的數字分類問題,能夠分爲10個二分類問題,在檢測圖片分類時,能夠獲取每一個分類器的決策分數,哪一個分高就決定時哪一個數字。這是OvA策略,一對多。還有一種狀況能夠爲每一對數字訓練分類器,這稱爲OvO,一對一策略。
有些算法在數據規模擴大時,表現糟糕,對於這類算法,一對一是優先選擇,若是不是的話就一對多。
sklearn會檢測到你嘗試使用二元分類算法進行多類別的分類任務,它會自動進行OvA:
每一個類別得出的分數以及分類器分出的類別咱們均可以知道:
咱們也能夠進行一對一的分類策略,預測器的個數也能顯示出來:
至於評估分類器,提高準確率,這裏再也不贅述。只放代碼,本身看就好:
錯誤分析
這裏,假設咱們已經找到了一個有潛力的模型,咱們但願對該模型進行改進,咱們能夠分析其錯誤類型,幫助咱們。
首先能夠看看混淆矩陣:
我擦,數字有點多,看的不清楚,怎麼才能更形象的表示呢?咱們能夠將該矩陣可視化:
越白是表示數字越多,越黑表示數字越少。因爲大多數白色都在對角線上,因此咱們能夠認爲大部分的數字與圖片能夠正確分類。從局部上,咱們能夠發現,白色的部分也存在差別,數字2的白色就很好,而數字五、數字8看起來比其餘的要差一些,可能的緣由有:數據集中圖片較少,或者數字5在執行效果上不如其餘數字。
咱們把問題集中在錯誤上,將混淆矩陣中的每一個值除以相應類別中的圖片數,得出錯誤率:
而後用0填充對角線,只保留錯誤率,從新繪製:
目前,就能夠看出分類器產生的錯誤種類了。記住行爲實際類別,列爲預測類別。咱們能夠發現,第6列第4行和第4列第5行兩個格子很亮,說明分類器針對數字三、數字5容易混淆,其餘的白色格子,咱們能夠進行一樣的分析。
下面,咱們來具體看一下數字3和5分錯的圖片混淆矩陣。首先定義一個數字圖像顯示的函數:
上述定義的函數你們直接拷貝便可,若是想深刻研究其中的運行原理,請參考具體的文檔進行查看,不過也不會難,只是須要必定的邏輯就好。
而後對圖片進行一個顯示:
上圖中,左側兩個5*5的矩陣顯示的是被分類爲數字3的圖,右側兩個5*5的矩陣顯示了被分類爲5的圖。分類器弄錯的數字就是左下方和右上方的圖片。經過對比,咱們能夠發現有些數字用人腦來分辨也真的很容易分錯,所以知道分類器在具體的分類問題上的差別後,咱們就應該根據具體的問題,經過具體的手段進行修正。好比多采集數字3和數字5的數據,或者針對數字3和5的形狀結構,開發新的特徵來改進分類器,對圖像預處理,或者採用更高級的算法等等來解決問題。
多標籤分類、多輸出分類
咱們以前所做的分類器都會將一個實例分在一個類別中,而在某些狀況下,須要分類器爲每個實例產出多個類別。好比分類照片中的人像,一張照片可能存在不少我的,分類器就能夠輸出照片中的人對應的姓名。這種輸出多個二元標籤的分類系統成爲多標籤分類系統。
接下來,咱們針對這個問題,建立一個多標籤數組,這個數組包含數字圖片的目標標籤:第一個數字表示是否大數(7,8,9),另外一個表示該數是不是奇數。咱們這裏須要注意的是,不是全部算法都能支持多標籤的分類,咱們這裏選擇KNeighborsClassifier分類器進行分類:
咱們會發現,最後輸出的結果就是一個多標籤的結果。
最後,咱們討論一下多輸出分類。
這種類型的任務全稱應該叫多輸出—多類別分類,簡單來講是多標籤分類的泛化,其標籤也能夠是多種類別的。用一個例子來解釋:咱們如今須要針對含有噪聲的圖片進行去噪處理,就用咱們的手寫圖像爲例。咱們構造的這個系統輸出就是多個標籤(一個像素點一個標籤),每一個標籤能夠存在多個值(0-255)。這個任務比起分類任務來講,更像迴歸任務,可是咱們須要注意的是分類和迴歸的界限有時很模糊,咱們其實沒必要要必定要分的很細,適合咱們的目的就好,靈活掌握。
來,最後上一波代碼,說明上述的這個問題。固然,咱們首先要添加噪聲,而後編寫圖像顯示的函數,最後對比一下輸出就好:
第四章的心得總結就到這裏,我經過第四章的學習,將以前對召回率、精度、ROC曲線、F1分數等一系列概念有了更深的瞭解,但願經過本身的心得總結,可以幫助更多的人,謝謝。
源碼獲取:
一、個人Github:https://github.com/niufuquan1/MyStudy_For_sklearn_tensorflow/tree/master/ML_MNIST
二、百度雲:https://pan.baidu.com/s/1sX2ulOE7xgjJTzfVmA0N0Q (rxfv)