【scikit-learn】01:使用案例對sklearn庫進行簡單介紹

sklearn學習筆記:Quick Start 
源地址:http://scikit-learn.org/stable/tutorial/basic/tutorial.htmlhtml

# -*-coding:utf-8-*-
'''
    Author:kevinelstri
    Datetime:2017.2.16
'''
# -----------------------
# An introduction to machine learning with scikit-learn
# http://scikit-learn.org/stable/tutorial/basic/tutorial.html
# -----------------------

'''
    經過使用sklearn,簡要介紹機器學習,並給出一個簡單的例子
'''

'''
    Machine learning: the problem setting
'''

'''
    機器學習:
        就是對數據的一系列樣本進行分析,來預測數據的未知結果。
    監督學習:
        數據的預測來自於對已有的數據進行分析,進而對新增的數據進行預測。
        監督學習能夠劃分爲兩類:分類和迴歸
    非監督學習:
        訓練數據由一系列沒有標籤的數據構成,目的就是發現這組數據中的類似性,也稱做聚類。
        或者來發現數據的分佈狀況,稱爲密度估計。

    訓練數據集、測試數據集:
        機器學習就是經過學習一組數據,來將結果應用於一組新的數據中。
        將一組數據劃分爲兩個集合,一個稱爲訓練集,一個稱爲測試集。
'''

'''
    Loading an example dataset
'''

'''
    sklearn 有一些標準的數據集,iris,digits 數據集用於分類,boston house prices 數據集用於迴歸
'''
from sklearn import datasets

iris = datasets.load_iris()  # 加載iris數據集
digits = datasets.load_digits()  # 加載digits數據集
# print iris
# print digits
'''
    數據集是一個相似字典的對象,它保存全部的數據和一些有關數據的元數據。
    數據存儲在.data中,這是一個(n_sample, n_features)數組。
    在監督學習問題中,多個變量存儲在.target中。
    data:數據
    target:標籤

    n_sample:樣本數量
    n_features:預測結果的數量
'''

print 'digits.data:', digits.data  # 用來分類樣本的特徵
print 'digits.target:', digits.target  # 給出了digits數據集的真實值,就是每一個數字圖案對應的想預測的真實數字

print 'iris.data:', iris.data
print 'iris.target:', iris.target

print digits.images[0]
print digits.images

'''
    Recognizing hand-written digits
'''
import matplotlib.pyplot as plt
from sklearn import datasets, svm, metrics

digits = datasets.load_digits()  # 加載數據集
'''
    digits數據集中每個數據都是一個8*8的矩陣
'''
images_and_labels = list(zip(digits.images, digits.target))  # 每一個數據集都與標籤對應,使用zip()函數構成字典
for index, (image, label) in enumerate(images_and_labels[:4]):
    plt.subplot(2, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Training:%i' % label)

n_samples = len(digits.images)  # 樣本的數量
print n_samples
data = digits.images.reshape((n_samples, -1))

classifier = svm.SVC(gamma=0.001)  # svm預測器

classifier.fit(data[:n_samples / 2], digits.target[:n_samples / 2])  # 使用數據集的一半進行訓練數據

expected = digits.target[n_samples / 2:]
predicted = classifier.predict(data[n_samples / 2:])  # 預測剩餘的數據

print("Classification report for classifier %s:\n%s\n"
      % (classifier, metrics.classification_report(expected, predicted)))
print("Confusion matrix:\n%s" % metrics.confusion_matrix(expected, predicted))

images_and_predictions = list(zip(digits.images[n_samples / 2:], predicted))  # 圖片與預測結果按照字典方式對應
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
    plt.subplot(2, 4, index + 5)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')  # 展現圖片
    plt.title('Prediction: %i' % prediction)  # 標題

# plt.show()

'''
    Learning and predicting
'''
'''
    digits數據集,就是給定一個圖案,預測其表示的數字是什麼。
    樣本共有10個可能的分類(0-9),經過匹配(fit)預測器(estimator)來預測(predict)未知樣本所屬的分類。
    sklearn中,分類的預測器就是爲了實現fit(X,y)和predict(T)兩個方法(匹配和預測)。

    fit(X,y):訓練數據
    predict(T):預測數據

    預測器sklearn.svm.SVC,就是爲了實現支持向量機分類
'''
from sklearn import svm

clf = svm.SVC(gamma=0.001, C=100.)
'''
    預測器的名字是clf,這是一個分類器,它必須進行模型匹配(fit),也就是說,必須從模型中學習。
    從模型學習的過程,模型匹配的過程,是經過將訓練集傳遞給fit方法來實現的。

    本次實驗中將除了最後一個樣本的數據所有做爲訓練集[:-1]
'''
print clf.fit(digits.data[:-1], digits.target[:-1])  # 對前面全部的數據進行訓練
print clf.predict(digits.data[-1:])  # 對最後一個數據進行預測

'''
    Model persistence
        使用pickle保存訓練過的模型
'''
from sklearn import svm
from sklearn import datasets

clf = svm.SVC()  # 構造預測器
iris = datasets.load_iris()  # 加載數據集
X, y = iris.data, iris.target  # 數據的樣本數和結果數
clf.fit(X, y)  # 訓練數據

import pickle

s = pickle.dumps(clf)  # 保存訓練模型
clf2 = pickle.loads(s)  # 加載訓練模型
print clf2.predict(X[0:1])  # 應用訓練模型

# 在scikit下,能夠使用joblib's(joblib.dump, joblib.load)來代替pickle
from sklearn.externals import joblib

joblib.dump(clf, 'filename.pkl')  # 保存訓練模型
clf = joblib.load('filename.pkl')  # 加載數據模型
print clf.predict(X[0:1])  # 應用訓練模型

'''
    Conventions
'''
from sklearn import datasets
from sklearn.svm import SVC

iris = datasets.load_iris()
clf = SVC()
clf.fit(iris.data, iris.target)
print list(clf.predict(iris.data[:3]))  # output:[0,0,0]
# 因爲iris.target是整型數組,因此這裏的predict()返回的也是整型數組

clf.fit(iris.data, iris.target_names[iris.target])
print list(clf.predict(iris.data[:3]))  # output:['setosa', 'setosa', 'setosa']
# 這裏iris.target_names是字符串名字,因此predict()返回的也是字符串

'''
    Refitting and updating parameters
'''
import numpy as np
from sklearn.svm import SVC

rng = np.random.RandomState(0)
X = rng.rand(100, 10)
y = rng.binomial(1, 0.5, 100)
X_test = rng.rand(5, 10)
clf = SVC()
clf.set_params(kernel='linear').fit(X, y)
print clf.predict(X_test)  # output:[1, 0, 1, 1, 0]

clf.set_params(kernel='rbf').fit(X, y)
print clf.predict(X_test)  # output:[0, 0, 0, 1, 0]


'''
    Multiclass vs. multilabel fitting
'''
from sklearn.svm import SVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import LabelBinarizer
X = [[1, 2], [2, 4], [4, 5], [3, 2], [3, 1]]
y = [0, 0, 1, 1, 2]

classif = OneVsRestClassifier(estimator=SVC(random_state=0))
print classif.fit(X, y).predict(X)  # output:[0 0 1 1 2]

y = LabelBinarizer().fit_transform(y)
print classif.fit(X, y).predict(X)  # output:[[1 0 0][1 0 0][0 1 0][0 0 0][0 0 0]]

from sklearn.preprocessing import MultiLabelBinarizer
y = [[0, 1], [0, 2], [1, 3], [0, 2, 3], [2, 4]]
y = MultiLabelBinarizer().fit_transform(y)
print classif.fit(X, y).predict(X)  # output:[[1 1 0 0 0][1 0 1 0 0][0 1 0 1 0][1 0 1 0 0][1 0 1 0 0]]

運行結果:python

digits.data: [[ 0.  0.  5. ...  0.  0.  0.]
 [ 0.  0.  0. ... 10.  0.  0.]
 [ 0.  0.  0. ... 16.  9.  0.]
 ...
 [ 0.  0.  1. ...  6.  0.  0.]
 [ 0.  0.  2. ... 12.  0.  0.]
 [ 0.  0. 10. ... 12.  1.  0.]]
digits.target: [0 1 2 ... 8 9 8]
iris.data: [[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]
 [5.4 3.9 1.7 0.4]
 [4.6 3.4 1.4 0.3]
 [5.  3.4 1.5 0.2]
 [4.4 2.9 1.4 0.2]
 [4.9 3.1 1.5 0.1]
 [5.4 3.7 1.5 0.2]
 [4.8 3.4 1.6 0.2]
 [4.8 3.  1.4 0.1]
 [4.3 3.  1.1 0.1]
 [5.8 4.  1.2 0.2]
 [5.7 4.4 1.5 0.4]
 [5.4 3.9 1.3 0.4]
 [5.1 3.5 1.4 0.3]
 [5.7 3.8 1.7 0.3]
 [5.1 3.8 1.5 0.3]
 [5.4 3.4 1.7 0.2]
 [5.1 3.7 1.5 0.4]
 [4.6 3.6 1.  0.2]
 [5.1 3.3 1.7 0.5]
 [4.8 3.4 1.9 0.2]
 [5.  3.  1.6 0.2]
 [5.  3.4 1.6 0.4]
 [5.2 3.5 1.5 0.2]
 [5.2 3.4 1.4 0.2]
 [4.7 3.2 1.6 0.2]
 [4.8 3.1 1.6 0.2]
 [5.4 3.4 1.5 0.4]
 [5.2 4.1 1.5 0.1]
 [5.5 4.2 1.4 0.2]
 [4.9 3.1 1.5 0.1]
 [5.  3.2 1.2 0.2]
 [5.5 3.5 1.3 0.2]
 [4.9 3.1 1.5 0.1]
 [4.4 3.  1.3 0.2]
 [5.1 3.4 1.5 0.2]
 [5.  3.5 1.3 0.3]
 [4.5 2.3 1.3 0.3]
 [4.4 3.2 1.3 0.2]
 [5.  3.5 1.6 0.6]
 [5.1 3.8 1.9 0.4]
 [4.8 3.  1.4 0.3]
 [5.1 3.8 1.6 0.2]
 [4.6 3.2 1.4 0.2]
 [5.3 3.7 1.5 0.2]
 [5.  3.3 1.4 0.2]
 [7.  3.2 4.7 1.4]
 [6.4 3.2 4.5 1.5]
 [6.9 3.1 4.9 1.5]
 [5.5 2.3 4.  1.3]
 [6.5 2.8 4.6 1.5]
 [5.7 2.8 4.5 1.3]
 [6.3 3.3 4.7 1.6]
 [4.9 2.4 3.3 1. ]
 [6.6 2.9 4.6 1.3]
 [5.2 2.7 3.9 1.4]
 [5.  2.  3.5 1. ]
 [5.9 3.  4.2 1.5]
 [6.  2.2 4.  1. ]
 [6.1 2.9 4.7 1.4]
 [5.6 2.9 3.6 1.3]
 [6.7 3.1 4.4 1.4]
 [5.6 3.  4.5 1.5]
 [5.8 2.7 4.1 1. ]
 [6.2 2.2 4.5 1.5]
 [5.6 2.5 3.9 1.1]
 [5.9 3.2 4.8 1.8]
 [6.1 2.8 4.  1.3]
 [6.3 2.5 4.9 1.5]
 [6.1 2.8 4.7 1.2]
 [6.4 2.9 4.3 1.3]
 [6.6 3.  4.4 1.4]
 [6.8 2.8 4.8 1.4]
 [6.7 3.  5.  1.7]
 [6.  2.9 4.5 1.5]
 [5.7 2.6 3.5 1. ]
 [5.5 2.4 3.8 1.1]
 [5.5 2.4 3.7 1. ]
 [5.8 2.7 3.9 1.2]
 [6.  2.7 5.1 1.6]
 [5.4 3.  4.5 1.5]
 [6.  3.4 4.5 1.6]
 [6.7 3.1 4.7 1.5]
 [6.3 2.3 4.4 1.3]
 [5.6 3.  4.1 1.3]
 [5.5 2.5 4.  1.3]
 [5.5 2.6 4.4 1.2]
 [6.1 3.  4.6 1.4]
 [5.8 2.6 4.  1.2]
 [5.  2.3 3.3 1. ]
 [5.6 2.7 4.2 1.3]
 [5.7 3.  4.2 1.2]
 [5.7 2.9 4.2 1.3]
 [6.2 2.9 4.3 1.3]
 [5.1 2.5 3.  1.1]
 [5.7 2.8 4.1 1.3]
 [6.3 3.3 6.  2.5]
 [5.8 2.7 5.1 1.9]
 [7.1 3.  5.9 2.1]
 [6.3 2.9 5.6 1.8]
 [6.5 3.  5.8 2.2]
 [7.6 3.  6.6 2.1]
 [4.9 2.5 4.5 1.7]
 [7.3 2.9 6.3 1.8]
 [6.7 2.5 5.8 1.8]
 [7.2 3.6 6.1 2.5]
 [6.5 3.2 5.1 2. ]
 [6.4 2.7 5.3 1.9]
 [6.8 3.  5.5 2.1]
 [5.7 2.5 5.  2. ]
 [5.8 2.8 5.1 2.4]
 [6.4 3.2 5.3 2.3]
 [6.5 3.  5.5 1.8]
 [7.7 3.8 6.7 2.2]
 [7.7 2.6 6.9 2.3]
 [6.  2.2 5.  1.5]
 [6.9 3.2 5.7 2.3]
 [5.6 2.8 4.9 2. ]
 [7.7 2.8 6.7 2. ]
 [6.3 2.7 4.9 1.8]
 [6.7 3.3 5.7 2.1]
 [7.2 3.2 6.  1.8]
 [6.2 2.8 4.8 1.8]
 [6.1 3.  4.9 1.8]
 [6.4 2.8 5.6 2.1]
 [7.2 3.  5.8 1.6]
 [7.4 2.8 6.1 1.9]
 [7.9 3.8 6.4 2. ]
 [6.4 2.8 5.6 2.2]
 [6.3 2.8 5.1 1.5]
 [6.1 2.6 5.6 1.4]
 [7.7 3.  6.1 2.3]
 [6.3 3.4 5.6 2.4]
 [6.4 3.1 5.5 1.8]
 [6.  3.  4.8 1.8]
 [6.9 3.1 5.4 2.1]
 [6.7 3.1 5.6 2.4]
 [6.9 3.1 5.1 2.3]
 [5.8 2.7 5.1 1.9]
 [6.8 3.2 5.9 2.3]
 [6.7 3.3 5.7 2.5]
 [6.7 3.  5.2 2.3]
 [6.3 2.5 5.  1.9]
 [6.5 3.  5.2 2. ]
 [6.2 3.4 5.4 2.3]
 [5.9 3.  5.1 1.8]]
iris.target: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
[[ 0.  0.  5. 13.  9.  1.  0.  0.]
 [ 0.  0. 13. 15. 10. 15.  5.  0.]
 [ 0.  3. 15.  2.  0. 11.  8.  0.]
 [ 0.  4. 12.  0.  0.  8.  8.  0.]
 [ 0.  5.  8.  0.  0.  9.  8.  0.]
 [ 0.  4. 11.  0.  1. 12.  7.  0.]
 [ 0.  2. 14.  5. 10. 12.  0.  0.]
 [ 0.  0.  6. 13. 10.  0.  0.  0.]]
[[[ 0.  0.  5. ...  1.  0.  0.]
  [ 0.  0. 13. ... 15.  5.  0.]
  [ 0.  3. 15. ... 11.  8.  0.]
  ...
  [ 0.  4. 11. ... 12.  7.  0.]
  [ 0.  2. 14. ... 12.  0.  0.]
  [ 0.  0.  6. ...  0.  0.  0.]]
 [[ 0.  0.  0. ...  5.  0.  0.]
  [ 0.  0.  0. ...  9.  0.  0.]
  [ 0.  0.  3. ...  6.  0.  0.]
  ...
  [ 0.  0.  1. ...  6.  0.  0.]
  [ 0.  0.  1. ...  6.  0.  0.]
  [ 0.  0.  0. ... 10.  0.  0.]]
 [[ 0.  0.  0. ... 12.  0.  0.]
  [ 0.  0.  3. ... 14.  0.  0.]
  [ 0.  0.  8. ... 16.  0.  0.]
  ...
  [ 0.  9. 16. ...  0.  0.  0.]
  [ 0.  3. 13. ... 11.  5.  0.]
  [ 0.  0.  0. ... 16.  9.  0.]]
 ...
 [[ 0.  0.  1. ...  1.  0.  0.]
  [ 0.  0. 13. ...  2.  1.  0.]
  [ 0.  0. 16. ... 16.  5.  0.]
  ...
  [ 0.  0. 16. ... 15.  0.  0.]
  [ 0.  0. 15. ... 16.  0.  0.]
  [ 0.  0.  2. ...  6.  0.  0.]]
 [[ 0.  0.  2. ...  0.  0.  0.]
  [ 0.  0. 14. ... 15.  1.  0.]
  [ 0.  4. 16. ... 16.  7.  0.]
  ...
  [ 0.  0.  0. ... 16.  2.  0.]
  [ 0.  0.  4. ... 16.  2.  0.]
  [ 0.  0.  5. ... 12.  0.  0.]]
 [[ 0.  0. 10. ...  1.  0.  0.]
  [ 0.  2. 16. ...  1.  0.  0.]
  [ 0.  0. 15. ... 15.  0.  0.]
  ...
  [ 0.  4. 16. ... 16.  6.  0.]
  [ 0.  8. 16. ... 16.  8.  0.]
  [ 0.  1.  8. ... 12.  1.  0.]]]
1797
Classification report for classifier SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False):
             precision    recall  f1-score   support
          0       1.00      0.99      0.99        88
          1       0.99      0.97      0.98        91
          2       0.99      0.99      0.99        86
          3       0.98      0.87      0.92        91
          4       0.99      0.96      0.97        92
          5       0.95      0.97      0.96        91
          6       0.99      0.99      0.99        91
          7       0.96      0.99      0.97        89
          8       0.94      1.00      0.97        88
          9       0.93      0.98      0.95        92
avg / total       0.97      0.97      0.97       899
Confusion matrix:
[[87  0  0  0  1  0  0  0  0  0]
 [ 0 88  1  0  0  0  0  0  1  1]
 [ 0  0 85  1  0  0  0  0  0  0]
 [ 0  0  0 79  0  3  0  4  5  0]
 [ 0  0  0  0 88  0  0  0  0  4]
 [ 0  0  0  0  0 88  1  0  0  2]
 [ 0  1  0  0  0  0 90  0  0  0]
 [ 0  0  0  0  0  1  0 88  0  0]
 [ 0  0  0  0  0  0  0  0 88  0]
 [ 0  0  0  1  0  1  0  0  0 90]]
SVC(C=100.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
[8]
[0]
[0]
[0, 0, 0]
['setosa', 'setosa', 'setosa']
[1 0 1 1 0]
[0 0 0 1 0]
[0 0 1 1 2]
[[1 0 0]
 [1 0 0]
 [0 1 0]
 [0 0 0]
 [0 0 0]]
[[1 1 0 0 0]
 [1 0 1 0 0]
 [0 1 0 1 0]
 [1 0 1 0 0]
 [1 0 1 0 0]]
PyDev console: starting.
Python 2.7.12 (default, Dec  4 2017, 14:50:18) 
[GCC 5.4.0 20160609] on linux2
相關文章
相關標籤/搜索