純Python實現鳶尾屬植物數據集神經網絡模型

摘要: 本文以Python代碼完成整個鸞尾花圖像分類任務,沒有調用任何的數據包,適合新手閱讀理解,並動手實踐體驗下機器學習方法的大體流程。

嘗試使用過各大公司推出的植物識別APP嗎?好比微軟識花、花伴侶等這些APP。當你看到一朵不知道學名的花時,只須要打開植物識別APP,拍攝一張你所想辨認的植物照片並上傳,APP會自動識別出該花的品種及詳細介紹,感受手機中裝了一個知識淵博的生物學家,是否是很神奇?其實,背後的原理很簡單,是一個圖像分類的過程,將上傳的圖像與手機中預存的數據集或聯網數據進行匹配,將其分類到對應的類別便可。隨着深度學習方法的應用,圖像分類的精度愈來愈高,在部分數據集上已經超越了人眼的能力。python

相對於傳統神經網絡的方法而言,深度學習方法通常對數據集規模、硬件平臺有着比較高的要求,若是隻是單純的想嘗試瞭解圖像分類任務的基本流程,建議採用小數據集樣本及傳統的神經網絡方法實現。本文將帶領讀者採用鳶尾屬植物數據集(Iris Data Set)來實現一個分類任務,整個鳶尾屬植物數據集是機器學習中歷史悠久的數據集,比如今經常使用的數字手寫體數據集(Mnist Data Set)數據集還要早得多,該數據集來源於英國著名的統計學家、生物學家Ronald Fiser。本文在不使用相關軟件庫的狀況下,從頭開始構建針對鳶尾屬植物數據的神經網絡模型,對其進行訓練並得到好的結果。算法

clipboard.png

鳶尾屬植物數據集是用於測試機器學習算法的最經常使用數據集。該數據包含四種特徵,萼片長度、萼片寬度、花瓣長度和花瓣寬度,用於鳶尾屬植物的不一樣物種(versicolor, virginica和setosa)。此外,每一個物種有50個實例(數據行),下面讓咱們看看樣本數據分佈狀況。數組

clipboard.png

咱們將在這個數據集上使用神經網絡構建分類模型。爲了簡單起見,使用花瓣長度和花瓣寬度做爲特徵,且只有兩類物種:versicolor和virginica。下面就讓咱們在Python中逐步訓練針對該樣本數據集的神經網絡:網絡

步驟1:準備鳶尾屬植物數據集

將Iris數據集導入python並對數據進行子集劃分以保留行之間的相關性:機器學習

clipboard.png

clipboard.png

藍色點表明Versicolor物種,紅色點表明Virginica物種。本文構建的神經網絡將在這些數據上進行訓練,以期最後能正確地分類物種。函數

步驟2:初始化參數(權重和偏置)

下面構建一個具備單個隱藏層的神經網絡。此外,將隱藏圖層的大小設置爲6:學習

clipboard.png

步驟3:前向傳播(forward propagation)

在前向傳播過程當中,使用tanh激活函數做爲第一層的激活函數,使用sigmoid激活函數做爲第二層的激活函數:測試

clipboard.png

步驟4:計算代價函數(cost function)

目標是使得計算的代價函數小化,本文采用交叉熵(cross-entropy)做爲代價函數:spa

clipboard.png

步驟5:反向傳播(back propagation)

計算反向傳播過程,主要是計算代價函數的導數:設計

clipboard.png

步驟6:更新參數

使用反向傳播過程當中計算的梯度來更新權重和偏置:

clipboard.png

步驟7:創建神經網絡

將以上全部函數組合起來以建立設計的神經網絡模型。總而言之,下面是模型函數的總體順序:

一、初始化參數

二、前向傳播

三、計算代價函數

四、反向傳播

五、更新參數

clipboard.png

步驟8:跑動模型

將隱藏層節點設置爲6,最大迭代次數設置爲10,000次,並每隔1000次打印出訓練的結果:

clipboard.png

clipboard.png

步驟9:畫出分類邊界

clipboard.png

clipboard.png

從圖中能夠觀察到,只有四個點被錯誤分類。雖然咱們能夠調整模型來進一步地提升模型訓練精度,但該些操做顯然會致使過擬合現象的出現。

資源

本文做者:【方向】

閱讀原文

本文爲雲棲社區原創內容,未經容許不得轉載。

相關文章
相關標籤/搜索