帶你入門機器學習 - 卷積神經網絡識別手寫數字

將來的程序員分爲兩種, 會機器學習和不會的git

原博地址laboo.top/2018/11/21/…程序員

源碼

digit-recognizergithub

demo

github-laziji.github.io/digit-recog… 演示開始時須要加載大概100M的訓練數據, 稍等片刻網絡

調整訓練集的大小, 觀察測試結果的準確性機器學習

數據來源

數據來源與 www.kaggle.com 中的一道題目 digit-recognizer 題目給出42000條訓練數據(包含圖片和標籤)以及28000條測試數據(只包含圖片) 要求給這些測試數據打上標籤[0,1,2,3....,9] 要儘量的準確async

網站中還有許多其餘的機器學習的題目以及數據, 是個很好的練手的地方ide

實現

這裏咱們使用TensorFlow.js來實現這個項目函數

建立模型

卷積神經網絡的第一層有兩種做用, 它既是輸入層也是執行層, 接收IMAGE_H * IMAGE_W大小的黑白像素 最後一層是輸出層, 有10個輸出單元, 表明着0-9這十個值的機率分佈, 例如 Label=2 , 輸出爲[0.02,0.01,0.9,...,0.01]學習

function createConvModel() {
  const model = tf.sequential();

  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_H, IMAGE_W, 1],
    kernelSize: 3,
    filters: 16,
    activation: 'relu'
  }));

  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
  model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
  model.add(tf.layers.flatten({}));

  model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));

  return model;
}
複製代碼

訓練模型

咱們選擇適當的優化器和損失函數, 來編譯模型測試

async function train() {

  ui.trainLog('Create model...');
  model = createConvModel();
  
  ui.trainLog('Compile model...');
  const optimizer = 'rmsprop';
  model.compile({
    optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  const trainData = Data.getTrainData(ui.getTrainNum());
  
  ui.trainLog('Training model...');
  await model.fit(trainData.xs, trainData.labels, {});

  ui.trainLog('Completed!');
  ui.trainCompleted();
}
複製代碼

測試

這裏測試一組測試數據, 返回對應的標籤, 即十個輸出單元中機率最高的下標

function testOne(xs){
  if(!model){
    ui.viewLog('Need to train the model first');
    return;
  }
  ui.viewLog('Testing...');
  let output = model.predict(xs);
  ui.viewLog('Completed!');
  output.print();
  const axis = 1;
  const predictions = output.argMax(axis).dataSync();
  return predictions[0];
}
複製代碼

歡迎關注個人博客公衆號

2018_11_16_0048241709.png
相關文章
相關標籤/搜索