MNIST 數據集
包含60 000 張訓練圖像和10 000 張測試圖像,由美國國家標準與技術研究院(National Institute of Standards and Technology,即MNIST 中
的NIST)在20 世紀80 年代收集獲得。
類和標籤
在機器學習中,分類問題中的某個類別叫做類(class)。數據點叫做樣本(sample)。某
個樣本對應的類叫做標籤(label)。
MNIST 數據集預先加載在Keras 庫中,其中包括4 個Numpy 數組。
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images 和train_labels 組成了訓練集(training set),模型將從這些數據中進行
學習。而後在測試集(test set,即test_images 和test_labels)上對模型進行測試。
圖像被編碼爲Numpy 數組,而標籤是數字數組,取值範圍爲0~9。圖像和標籤一一對應。
咱們來看一下訓練數據:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
測試數據:
>>> test_images.shape
(10000, 28, 28)
>>> len(test_labels)
10000
>>> test_labels
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)
神經網絡架構
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))
本例中的網絡包含2 個Dense 層,它們是密集鏈接(也叫全鏈接)的神經層。第二層(也
是最後一層)是一個10 路softmax 層,它將返回一個由10 個機率值(總和爲1)組成的數組。
每一個機率值表示當前數字圖像屬於10 個數字類別中某一個的機率。
要想訓練網絡,咱們還須要選擇編譯(compile)步驟的三個參數。
損失函數(loss function):網絡如何衡量在訓練數據上的性能,即網絡如何朝着正確的
方向前進。
優化器(optimizer):基於訓練數據和損失函數來更新網絡的機制。
在訓練和測試過程當中須要監控的指標(metric):本例只關心精度,即正確分類的圖像所
佔的比例。
編譯步驟
network.compile(optimizer='rmsprop',loss='categorical_crossentropy', metrics=['accuracy'])
在開始訓練以前,咱們將對數據進行預處理,將其變換爲網絡要求的形狀,並縮放到所
有值都在[0, 1] 區間。好比,以前訓練圖像保存在一個uint8 類型的數組中,其形狀爲
(60000, 28, 28),取值區間爲[0, 255]。咱們須要將其變換爲一個float32 數組,其形
狀爲(60000, 28 * 28),取值範圍爲0~1。
準備圖像數據
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255
準備標籤
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
開始訓練網絡
>>> network.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
60000/60000 [=============================] - 9s - loss: 0.2524 - acc: 0.9273
Epoch 2/5
51328/60000 [=======================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692
檢查模型在測試集上的性能
>>> test_loss, test_acc = network.evaluate(test_images, test_labels)
>>> print('test_acc:', test_acc)
test_acc: 0.9785