分析鳶尾花數據集

下面將結合Scikit-learn官網的邏輯迴歸模型分析鳶尾花示例,給你們進行詳細講解及拓展。因爲該數據集分類標籤劃分爲3類(0類、1類、2類),很好的適用於邏輯迴歸模型。python

1. 鳶尾花數據集

在Sklearn機器學習包中,集成了各類各樣的數據集,包括前面的糖尿病數據集,這裏引入的是鳶尾花卉(Iris)數據集,它是很經常使用的一個數據集。鳶尾花有三個亞屬,分別是山鳶尾(Iris-setosa)、變色鳶尾(Iris-versicolor)和維吉尼亞鳶尾(Iris-virginica)。算法

該數據集一共包含4個特徵變量,1個類別變量。共有150個樣本,iris是鳶尾植物,這裏存儲了其萼片和花瓣的長寬,共4個屬性,鳶尾植物分三類。如表17.2所示:

數組

 

iris裏有兩個屬性iris.data,iris.target。data是一個矩陣,每一列表明瞭萼片或花瓣的長寬,一共4列,每一列表明某個被測量的鳶尾植物,一共採樣了150條記錄。
from sklearn.datasets import load_iris   #導入數據集iris
iris = load_iris()  #載入數據集
print iris.data
輸出以下所示:
[[ 5.1  3.5  1.4  0.2]
 [ 4.9  3.   1.4  0.2]
 [ 4.7  3.2  1.3  0.2]
 [ 4.6  3.1  1.5  0.2]
 ....
 [ 6.7  3.   5.2  2.3]
 [ 6.3  2.5  5.   1.9]
 [ 6.5  3.   5.2  2. ]
 [ 6.2  3.4  5.4  2.3]
 [ 5.9  3.   5.1  1.8]]
target是一個數組,存儲了data中每條記錄屬於哪一類鳶尾植物,因此數組的長度是150,數組元素的值由於共有3類鳶尾植物,因此不一樣值只有3個。種類爲山鳶尾、雜色鳶尾、維吉尼亞鳶尾。
print iris.target          #輸出真實標籤
print len(iris.target)      #150個樣本 每一個樣本4個特徵
print iris.data.shape  

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
150
(150L, 4L)
從輸出結果能夠看到,類標共分爲三類,前面50個類標位0,中間50個類標位1,後面爲2。下面給詳細介紹使用決策樹進行對這個數據集進行測試的代碼。

2. 散點圖繪製

下列代碼主要是載入鳶尾花數據集,包括數據data和標籤target,而後獲取其中兩列數據或兩個特徵,核心代碼爲:X = [x[0] for x in DD],獲取的值賦值給X變量,最後調用scatter()函數繪製散點圖。機器學習

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris    #導入數據集iris
  
#載入數據集  
iris = load_iris()  
print iris.data          #輸出數據集  
print iris.target         #輸出真實標籤  
#獲取花卉兩列數據集  
DD = iris.data  
X = [x[0] for x in DD]  
print X  
Y = [x[1] for x in DD]  
print Y  
  
#plt.scatter(X, Y, c=iris.target, marker='x')
plt.scatter(X[:50], Y[:50], color='red', marker='o', label='setosa') #前50個樣本
plt.scatter(X[50:100], Y[50:100], color='blue', marker='x', label='versicolor') #中間50個
plt.scatter(X[100:], Y[100:],color='green', marker='+', label='Virginica') #後50個樣本
plt.legend(loc=2) #左上角
plt.show()
繪製散點圖如圖所示:

3. 邏輯迴歸分析

從圖中能夠看出,數據集線性可分的,能夠劃分爲3類,分別對應三種類型的鳶尾花,下面採用邏輯迴歸對其進行分類預測。前面使用X=[x[0] for x in DD]獲取第一列數據,Y=[x[1] for x in DD]獲取第二列數據,這裏採用另外一種方法,iris.data[:, :2]獲取其中兩列數據(兩個特徵),完整代碼以下:函數

 

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris   
from sklearn.linear_model import LogisticRegression 

#載入數據集
iris = load_iris()         
X = X = iris.data[:, :2]   #獲取花卉兩列數據集
Y = iris.target           

#邏輯迴歸模型
lr = LogisticRegression(C=1e5)  
lr.fit(X,Y)

#meshgrid函數生成兩個網格矩陣
h = .02
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

#pcolormesh函數將xx,yy兩個網格矩陣和對應的預測結果Z繪製在圖片上
Z = lr.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.figure(1, figsize=(8,6))
plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired)

#繪製散點圖
plt.scatter(X[:50,0], X[:50,1], color='red',marker='o', label='setosa')
plt.scatter(X[50:100,0], X[50:100,1], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:,0], X[100:,1], color='green', marker='s', label='Virginica') 

plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xticks(())
plt.yticks(())
plt.legend(loc=2) 
plt.show()
下面做者對導入數據集後的代碼進行詳細講解。

lr = LogisticRegression(C=1e5)  
lr.fit(X,Y)
初始化邏輯迴歸模型並進行訓練,C=1e5表示目標函數。

x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
獲取的鳶尾花兩列數據,對應爲花萼長度和花萼寬度,每一個點的座標就是(x,y)。 先取X二維數組的第一列(長度)的最小值、最大值和步長h(設置爲0.02)生成數組,再取X二維數組的第二列(寬度)的最小值、最大值和步長h生成數組, 最後用meshgrid函數生成兩個網格矩陣xx和yy,以下所示:

 

[[ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
 [ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
 ..., 
 [ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
 [ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]]
[[ 1.5   1.5   1.5  ...,  1.5   1.5   1.5 ]
 [ 1.52  1.52  1.52 ...,  1.52  1.52  1.52]
 ..., 
 [ 4.88  4.88  4.88 ...,  4.88  4.88  4.88]
 [ 4.9   4.9   4.9  ...,  4.9   4.9   4.9 ]]

Z = lr.predict(np.c_[xx.ravel(), yy.ravel()])
調用ravel()函數將xx和yy的兩個矩陣轉變成一維數組,因爲兩個矩陣大小相等,所以兩個一維數組大小也相等。np.c_[xx.ravel(), yy.ravel()]是獲取矩陣,即:工具

xx.ravel() 
[ 3.8   3.82  3.84 ...,  8.36  8.38  8.4 ]
yy.ravel() 
[ 1.5  1.5  1.5 ...,  4.9  4.9  4.9]
np.c_[xx.ravel(), yy.ravel()]
[[ 3.8   1.5 ]
 [ 3.82  1.5 ]
 [ 3.84  1.5 ]
 ..., 
 [ 8.36  4.9 ]
 [ 8.38  4.9 ]
 [ 8.4   4.9 ]]

總結下:上述操做是把第一列花萼長度數據按h取等分做爲行,並複製多行獲得xx網格矩陣;再把第二列花萼寬度數據按h取等分,做爲列,並複製多列獲得yy網格矩陣;最後將xx和yy矩陣都變成兩個一維數組,調用np.c_[]函數組合成一個二維數組進行預測。
調用predict()函數進行預測,預測結果賦值給Z。即:學習

Z = logreg.predict(np.c_[xx.ravel(), yy.ravel()])
[1 1 1 ..., 2 2 2]
size: 39501

Z = Z.reshape(xx.shape)
調用reshape()函數修改形狀,將其Z轉換爲兩個特徵(長度和寬度),則39501個數據轉換爲171*231的矩陣。Z = Z.reshape(xx.shape)輸出以下:測試

[[1 1 1 ..., 2 2 2]
 [1 1 1 ..., 2 2 2]
 [0 1 1 ..., 2 2 2]
 ..., 
 [0 0 0 ..., 2 2 2]
 [0 0 0 ..., 2 2 2]
 [0 0 0 ..., 2 2 2]]

plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired)
調用pcolormesh()函數將xx、yy兩個網格矩陣和對應的預測結果Z繪製在圖片上,能夠發現輸出爲三個顏色區塊,分佈表示分類的三類區域。cmap=plt.cm.Paired表示繪圖樣式選擇Paired主題。輸出的區域以下圖所示:

code

 

plt.scatter(X[:50,0], X[:50,1], color='red',marker='o', label='setosa')
調用scatter()繪製散點圖,第一個參數爲第一列數據(長度),第二個參數爲第二列數據(寬度),第3、四個參數爲設置點的顏色爲紅色,款式爲圓圈,最後標記爲setosa。

輸出以下圖所示,通過邏輯迴歸後劃分爲三個區域,左上角部分爲紅色的圓點,對應setosa鳶尾花;右上角部分爲綠色方塊,對應virginica鳶尾花;中間下部分爲藍色星形,對應versicolor鳶尾花。散點圖爲各數據點真實的花類型,劃分的三個區域爲數據點預測的花類型,預測的分類結果與訓練數據的真實結果結果基本一致,部分鳶尾花出現交叉。

orm

 

迴歸算法做爲統計學中最重要的工具之一,它經過創建一個迴歸方程用來預測目標值,並求解這個迴歸方程的迴歸係數。本篇文章詳細講解了邏輯迴歸模型的原理知識,結合Sklearn機器學習庫的LogisticRegression算法分析了鳶尾花分類狀況。更多知識點但願讀者下來後進行拓展,也推薦大學從Sklearn開源知識官網學習最新的實例。

相關文章
相關標籤/搜索