基於 keras-js 快速實現瀏覽器內的 CNN 手寫數字識別

在這篇文章中,我會快速地介紹如何使用 keras 訓練一個簡單的識別 MNIST(一個手寫數字數據集)的 CNN(卷積神經網絡),而且把訓練好的網絡應用到 web 瀏覽器內。html

DEMO 地址:starkwang.github.io/keras-js-de…vue


零、準備工做

首先須要給你的電腦安裝 keras,具體安裝的步驟請參考 keras 官方文檔python


1、快速入門

首先十分推薦閱讀 tensorflow 官方文檔中的 MNIST For ML Beginners,這裏是極客學院的中文翻譯webpack

MNIST 是一個很流行的入門級機器學習/計算機視覺數據集,它包含 0 - 9 的各類手寫數字圖片:git

每張圖片的尺寸均爲 28 * 28,用一個 28 * 28 的二維數組來表示,換句話說,每張圖片都是由 784 個像素點組成,每一個像素點的值在 0 - 255 之間。github

好比下面就是一個 "3" 的數據:web

000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 038 043 105 255 253 253 253 253 253 174 006 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 043 139 224 226 252 253 252 252 252 252 252 252 158 014 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 178 252 252 252 252 253 252 252 252 252 252 252 252 059 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 109 252 252 230 132 133 132 132 189 252 252 252 252 059 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 004 029 029 024 000 000 000 000 014 226 252 252 172 007 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 085 243 252 252 144 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 088 189 252 252 252 014 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 091 212 247 252 252 252 204 009 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 032 125 193 193 193 253 252 252 252 238 102 028 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 045 222 252 252 252 252 253 252 252 252 177 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 045 223 253 253 253 253 255 253 253 253 253 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 031 123 052 044 044 044 044 143 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 015 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 086 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 005 075 009 000 000 000 000 000 000 098 242 252 252 074 000 000 000 000 000 000 000 000 
000 000 000 000 000 061 183 252 029 000 000 000 000 018 092 239 252 252 243 065 000 000 000 000 000 000 000 000 
000 000 000 000 000 208 252 252 147 134 134 134 134 203 253 252 252 188 083 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 208 252 252 252 252 252 252 252 252 253 230 153 008 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 049 157 252 252 252 252 252 217 207 146 045 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 007 103 235 252 172 103 024 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 
000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000
複製代碼

使用 keras,能夠很方便地導入 MNIST 數據集:npm

from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
複製代碼

整體來講,咱們的想要獲得的網絡模型,是有一個固定的輸入輸出的:canvas

  • 輸入爲一個 28 * 28 的二維整數數組
  • 輸出是一個長度爲 10 的數組,依次表示 0-9 的可能性(例如若是有一張圖片 80% 機率爲 1, 20% 機率爲 7的話,那麼這個數組就是 [0, 0.8, 0, 0, 0, 0, 0, 0.2, 0, 0]

2、使用 keras 訓練網絡

咱們想要訓練的模型,由如下幾層網絡組成:數組

  1. 32 個 3x3 卷積核的卷積層
  2. 64 個 3x3 卷積核的卷積層
  3. 採樣因子爲 (2, 2) 的池化層
  4. Dropout 層
  5. Flatten 層
  6. ReLu 全鏈接層
  7. Dropout 層
  8. Softmax 全鏈接層

用 keras 訓練一個識別 MNIST 的 CNN 網絡很是方便,下面是一個官方給出的例子(源碼在此):

from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

batch_size = 128
num_classes = 10
epochs = 12

# input image dimensions
img_rows, img_cols = 28, 28

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# Save model
model.save('myMnistCNN.h5')
複製代碼

若是已經安裝好了 keras,直接運行便可:

python mnist_cnn.py
複製代碼

3、轉換輸出模型

得到訓練好的 .h5 文件以後,模型還不能直接使用,由於咱們須要對它進行轉編碼,keras-js 提供了一個 python 腳本來自動執行:

python ./python/encoder.py -q myMnistCNN.h5
複製代碼

這個腳本會把 .h5 文件轉編碼爲 keras-js 可讀的格式,裏面包含了訓練好的神經網絡的全部模型和參數。

4、使用 keras-js 導入模型

首先須要引入 keras-js,能夠經過 script 標籤直接引入:

<script src="https://unpkg.com/keras-js"></script>
複製代碼

也能夠經過 npm 安裝後使用 webpack 構建引入,參考這裏

接下來就能夠直接建立一個 Model,keras-js 會自動加載對應的 bin 文件:

const model = new KerasJS.Model({
    filepath: '/path/to/mnist_cnn.bin',
    gpu: true,
    transferLayerOutputs: true
})
複製代碼

初始化完畢以後,就能夠用於 MNIST 識別了,輸入是一個長度爲 784 的數組(包含 28*28 各個像素點的灰度值),輸出是一個長度爲 10 的數組(0-9的機率):

(可使用上文中給的那個 "3" 的數據範例)

model
  .ready()
  .then(() => {
    // data 是一個長度爲 784 的數組,每一項都介於 0 - 255 之間
    // 這裏咱們須要把數組轉換爲 Float32 類型
    const inputData = new Float32Array(data)
    // 識別
    return model.predict(inputData)
  })
  .then(outputData => {
    // 輸出爲 0-9 的機率,例如:
    // { output: [0, 0, 0, 0.8, 0, 0, 0.2, 0, 0, 0] }
  })
  .catch(err => {
    // ...
  })
複製代碼

5、Canvas 實現一個手寫板

最後一步就是實現一個手寫板,具體的代碼就不放上來了,主要就是經過 mousedownmousemovemouseup 事件來繪製圖形。

繪製完畢以後,調用 ctx.getImageData,就能夠獲得 canvas 內的像素數據,每一個像素對應四個數值,依次是每一個點的 rgba 值,處理以後就能夠獲得長度爲 784 的灰度數組了。而後使用上文提到的 model.predict 便可。

相關文章
相關標籤/搜索