Original post: https://laboo.top/2018/11/21/tfjs-dr/
Demo: digit-recognizer
https://github-laziji.github.io/digit-recognizer/
When the demo starts it needs to load roughly 100 MB of training data, so please wait a moment.
You can adjust the size of the training set and observe how the accuracy of the test results changes.
The data comes from the digit-recognizer competition on https://www.kaggle.com.
The competition provides 42000 training examples (images with labels) and 28000 test examples (images only), and asks you to label the test images with the digits [0,1,2,...,9] as accurately as possible.
The site hosts many other machine-learning problems and datasets, and is a good place to practice.
Here we use TensorFlow.js to implement this project.
The first layer of the convolutional network plays two roles: it is both the input layer and a computing layer, receiving black-and-white pixels of size IMAGE_H * IMAGE_W.
The last layer is the output layer with 10 units, representing a probability distribution over the ten digits 0-9. For example, with Label=2 the output might be [0.02, 0.01, 0.9, ..., 0.01].
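To make the output layer's behavior concrete, here is a plain-JavaScript sketch (independent of TensorFlow.js) of the softmax function that produces such a probability distribution:

```javascript
// Softmax: turns 10 raw scores into probabilities that sum to 1.
function softmax(logits) {
  const max = Math.max(...logits);                 // subtract max for numerical stability
  const exps = logits.map(v => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(v => v / sum);
}

// A strong raw score at index 2 yields a distribution peaked at digit 2,
// much like the [0.02, 0.01, 0.9, ..., 0.01] example above.
const probs = softmax([0, 0, 5, 0, 0, 0, 0, 0, 0, 0]);
```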
```javascript
function createConvModel() {
  const model = tf.sequential();
  // Input layer: IMAGE_H x IMAGE_W grayscale image, 16 3x3 convolution filters
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_H, IMAGE_W, 1],
    kernelSize: 3,
    filters: 16,
    activation: 'relu'
  }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
  // Flatten the feature maps and classify with two dense layers
  model.add(tf.layers.flatten({}));
  model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
  // Output layer: softmax over the 10 digit classes
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
  return model;
}
```
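As a side note, the feature-map sizes flowing through these layers can be traced by hand. The sketch below assumes IMAGE_H = IMAGE_W = 28 (the MNIST image size) and TensorFlow.js's default 'valid' padding:

```javascript
// Size of a feature map after a 'valid' convolution with a square kernel
function convOut(size, kernelSize) { return size - kernelSize + 1; }
// Size after a max-pooling layer with poolSize === strides
function poolOut(size, poolSize) { return Math.floor(size / poolSize); }

let s = 28;
s = convOut(s, 3);  // conv2d, kernelSize 3 -> 26
s = poolOut(s, 2);  // maxPooling2d        -> 13
s = convOut(s, 3);  // conv2d              -> 11
s = poolOut(s, 2);  // maxPooling2d        -> 5
s = convOut(s, 3);  // conv2d              -> 3
const flatUnits = s * s * 32;  // flatten: 3 * 3 * 32 = 288 inputs to the dense layers
```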
We choose an appropriate optimizer and loss function, and compile the model.
```javascript
async function train() {
  ui.trainLog('Create model...');
  model = createConvModel();
  ui.trainLog('Compile model...');
  const optimizer = 'rmsprop';
  model.compile({
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  const trainData = Data.getTrainData(ui.getTrainNum());
  ui.trainLog('Training model...');
  await model.fit(trainData.xs, trainData.labels, {});
  ui.trainLog('Completed!');
  ui.trainCompleted();
}
```
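Note that categoricalCrossentropy expects labels as one-hot vectors, which is presumably how the Data helper encodes trainData.labels. A minimal sketch of that encoding (a hypothetical helper, not part of the original code):

```javascript
// One-hot encode a digit label, e.g. 2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
function oneHot(label, numClasses = 10) {
  const vec = new Array(numClasses).fill(0);
  vec[label] = 1;
  return vec;
}

const label2 = oneHot(2);
```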
Here we test a single example and return its label, i.e. the index of the output unit with the highest probability.
```javascript
function testOne(xs) {
  if (!model) {
    ui.viewLog('Need to train the model first');
    return;
  }
  ui.viewLog('Testing...');
  let output = model.predict(xs);
  ui.viewLog('Completed!');
  output.print();
  // Take the index of the highest-probability class as the prediction
  const axis = 1;
  const predictions = output.argMax(axis).dataSync();
  return predictions[0];
}
```
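For clarity, output.argMax(axis) along axis 1 simply picks, for each row, the index of the largest value. In plain JavaScript for a single example:

```javascript
// Index of the largest element, i.e. the predicted digit for one example
function argMax(arr) {
  let best = 0;
  for (let i = 1; i < arr.length; i++) {
    if (arr[i] > arr[best]) best = i;
  }
  return best;
}

const prediction = argMax([0.02, 0.01, 0.9, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]);
```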