With the release of the TensorFlow 2.0 alpha, TensorFlow.js has been updated to its first stable release, 1.0, and the official TensorFlow site now carries TensorFlow.js documentation, which shows that TensorFlow.js is no longer an experiment. As a browser engine developer, I naturally took an interest in TensorFlow.js.
JavaScript has been conquering territory everywhere in recent years: Node.js on the server, a boom in mobile front-end development, even desktop applications such as the recently popular Visual Studio Code, and now it is making inroads into artificial intelligence. One has to marvel that a scripting language designed in a hurry, and much criticized ever since, has proven this resilient.
Enough chatter; let's see what it is like to train a model in the browser.
I previously wrote a series, "Improving Handwritten Digit Recognition Step by Step (1)(2)(3)". Handwritten digit recognition is an excellent starter project, so I will use it here to show how to train a model in the browser. Rather than starting from the simplest linear regression model, we will go straight to a convolutional neural network.
Just as when training a model in Python, training with TensorFlow.js in the browser involves four main steps:

1. Load the data
2. Define the model
3. Train the model and monitor its progress
4. Evaluate the trained model
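All of the snippets below assume that the tf and tfvis globals are available on the page, typically by including the @tensorflow/tfjs and @tensorflow/tfjs-vis bundles via script tags. If you use a bundler instead, the equivalent imports are:

// Equivalent module imports when building with a bundler
// instead of loading the CDN bundles via script tags.
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';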
Anyone with some machine-learning background should be familiar with the MNIST dataset: 28x28 grayscale images of handwritten digits, with 55,000 training samples, 10,000 test samples, and another 5,000 cross-validation samples. TensorFlow's Python API provides a wrapper class that loads MNIST directly; in TensorFlow.js we have to write the loading code ourselves:
const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH = 'mnist_images.png';
const MNIST_LABELS_PATH = 'mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        // Decode the sprite in chunks to keep the working canvas small.
        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
The code loads a single image, mnist_images.png, a sprite stitched together from every image in the MNIST dataset (the file is large, about 10 MB), plus a binary file, mnist_labels_uint8, containing the one-hot encoded labels for the whole dataset.
Note that this is only one way to load MNIST; you could also use a version of the dataset with one image file per handwritten digit and load many image files in batches, as sketched below.
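As a rough sketch of that alternative, a helper along these lines could decode one digit image into the same normalized Float32Array format used above (loadDigit and the URL layout are hypothetical, purely to illustrate the idea):

// Hypothetical helper: load a single 28x28 grayscale digit image
// into a Float32Array of 784 pixel values normalized to [0, 1].
function loadDigit(url) {
  return new Promise((resolve, reject) => {
    const img = new Image();
    img.crossOrigin = '';
    img.onload = () => {
      const canvas = document.createElement('canvas');
      canvas.width = 28;
      canvas.height = 28;
      const ctx = canvas.getContext('2d');
      ctx.drawImage(img, 0, 0, 28, 28);
      const {data} = ctx.getImageData(0, 0, 28, 28);
      const pixels = new Float32Array(28 * 28);
      for (let i = 0; i < pixels.length; i++) {
        // Grayscale image: all channels are equal, so read the red channel.
        pixels[i] = data[i * 4] / 255;
      }
      resolve(pixels);
    };
    img.onerror = reject;
    img.src = url;
  });
}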
The code above implements an MnistData class which, besides load(), has two public methods:

- nextTrainBatch(batchSize): returns a random batch of batchSize images and their labels from the training set
- nextTestBatch(batchSize): returns a random batch of batchSize images and their labels from the test set
To check that the code works, we can write a short routine that displays the loaded data:
async function showExamples(data) {
  // Create a container in the visor
  const surface =
      tfvis.visor().surface({name: 'Input Data Examples', tab: 'Input Data'});

  // Get the examples
  const examples = data.nextTestBatch(20);
  const numExamples = examples.xs.shape[0];

  // Create a canvas element to render each example
  for (let i = 0; i < numExamples; i++) {
    const imageTensor = tf.tidy(() => {
      // Reshape the image to 28x28 px
      return examples.xs
          .slice([i, 0], [1, examples.xs.shape[1]])
          .reshape([28, 28, 1]);
    });

    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style = 'margin: 4px;';
    await tf.browser.toPixels(imageTensor, canvas);
    surface.drawArea.appendChild(canvas);

    imageTensor.dispose();
  }
}

async function run() {
  const data = new MnistData();
  await data.load();
  await showExamples(data);
}

document.addEventListener('DOMContentLoaded', run);
For background on convolutional neural networks, see "Improving Handwritten Digit Recognition Step by Step (3)". The network defined here is:

CONV -> MAXPOOLING -> CONV -> MAXPOOLING -> FC -> SOFTMAX

Each convolutional layer uses the ReLU activation function. The code is as follows:
function getModel() {
  const model = tf.sequential();

  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const IMAGE_CHANNELS = 1;

  // In the first layer of our convolutional neural network we have
  // to specify the input shape. Then we specify some parameters for
  // the convolution operation that takes place in this layer.
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));

  // The MaxPooling layer acts as a sort of downsampling using max values
  // in a region instead of averaging.
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

  // Repeat another conv2d + maxPooling stack.
  // Note that we have more filters in the convolution.
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

  // Now we flatten the output from the 2D filters into a 1D vector to prepare
  // it for input into our last layer. This is common practice when feeding
  // higher dimensional data to a final classification output layer.
  model.add(tf.layers.flatten());

  // Our last layer is a dense layer which has 10 output units, one for each
  // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9).
  const NUM_OUTPUT_CLASSES = 10;
  model.add(tf.layers.dense({
    units: NUM_OUTPUT_CLASSES,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax'
  }));

  // Choose an optimizer, loss function and accuracy metric,
  // then compile and return the model
  const optimizer = tf.train.adam();
  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return model;
}
If you have written TensorFlow code in Python before, the code above should be easy to follow.
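If you want to sanity-check the layer shapes before training, tf.Sequential models expose a summary() method, much like Keras's model.summary() in Python; a quick way to use it:

const model = getModel();
// Logs each layer's output shape and parameter count to the browser console.
model.summary();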
Training in the browser also takes the image data in batches, and you can specify the batch size and the number of epochs:
async function train(model, data) {
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = {
    name: 'Model Training', styles: {height: '1000px'}
  };
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;

  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  return model.fit(trainXs, trainYs, {
    batchSize: BATCH_SIZE,
    validationData: [testXs, testYs],
    epochs: 10,
    shuffle: true,
    callbacks: fitCallbacks
  });
}
Compared with the Python API, fit takes an extra callbacks parameter. Note that training takes a while, and we must not block the browser's main thread, which is why most of this code is asynchronous. The callbacks let us notify the main thread of progress; here we use the tfvis library to visualize the training process (much like TensorBoard), except the charts are rendered right on the page.
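If you would rather not depend on tfvis, fit also accepts a plain object of lifecycle hooks. A minimal sketch of my own (not from the original code) that just logs the loss and accuracy after each epoch:

// A hand-rolled alternative to tfvis.show.fitCallbacks:
// fit accepts lifecycle hooks such as onEpochEnd.
const loggingCallbacks = {
  onEpochEnd: async (epoch, logs) => {
    // logs holds the metrics configured in model.compile,
    // e.g. logs.loss and logs.acc in this example.
    console.log(`epoch ${epoch}: loss=${logs.loss.toFixed(4)}, acc=${logs.acc.toFixed(4)}`);
    // Yield to the browser so the UI stays responsive.
    await tf.nextFrame();
  }
};

// Usage: pass it in place of fitCallbacks, e.g.
// model.fit(trainXs, trainYs, {..., callbacks: loggingCallbacks});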
At evaluation time we feed in the test set; the code is again similar to the Python version:
function doPrediction(model, data, testDataSize = 500) {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const testData = data.nextTestBatch(testDataSize);
  const testxs = testData.xs.reshape([testDataSize, IMAGE_WIDTH, IMAGE_HEIGHT, 1]);
  const labels = testData.labels.argMax([-1]);
  const preds = model.predict(testxs).argMax([-1]);

  testxs.dispose();
  return [preds, labels];
}
If we want a more intuitive view of the per-class accuracy and the misclassifications, the tfvis library helps here too:
// Display names for the ten digit classes, used by both charts below.
const classNames = ['Zero', 'One', 'Two', 'Three', 'Four',
                    'Five', 'Six', 'Seven', 'Eight', 'Nine'];

async function showAccuracy(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
  const container = {name: 'Accuracy', tab: 'Evaluation'};
  tfvis.show.perClassAccuracy(container, classAccuracy, classNames);

  labels.dispose();
}

async function showConfusion(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
  const container = {name: 'Confusion Matrix', tab: 'Evaluation'};
  tfvis.render.confusionMatrix(
      container, {values: confusionMatrix}, classNames);

  labels.dispose();
}
The evaluation results are shown in the figure below.
TensorFlow.js uses WebGL to accelerate training. If the browser does not support WebGL, nothing breaks; TensorFlow.js simply falls back to the CPU path, which is of course much slower.
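You can check which backend was picked, or force the CPU path to compare speeds, straight from the console:

// Reports the active backend, e.g. 'webgl' or 'cpu'.
console.log(tf.getBackend());

// Force the (much slower) CPU path, e.g. for comparison.
tf.setBackend('cpu');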
Even though WebGL puts the GPU to work, training a large-scale deep learning model in the browser is still unrealistic. In that case we can train the model on a server, convert it to a format TensorFlow.js can consume, then load it in the browser and run inference there. I will cover that topic in a follow-up article.
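As a preview, the rough workflow looks like the sketch below; the Keras model file and the web_model/ paths are placeholders of my own, not from this article's code:

// On the server (shell), convert a saved Keras model to TF.js format:
//   tensorflowjs_converter --input_format=keras my_model.h5 web_model/
// ('my_model.h5' and 'web_model/' are hypothetical paths.)

// In the browser, load the converted model and run inference:
async function loadAndPredict(imageTensor) {
  // 'web_model/model.json' is wherever you serve the converted files from.
  const model = await tf.loadLayersModel('web_model/model.json');
  // imageTensor: a [1, 28, 28, 1] float tensor, as in the training code above.
  const prediction = model.predict(imageTensor).argMax(-1);
  return prediction;
}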
The complete code for the examples above is available; click "Read the original" to jump to the sample code repository I set up on GitHub. You can also open ilego.club/ai/index.ht… directly in your browser to try machine learning in the browser for yourself.
References: