如何在Tensorflow.js中處理MNIST圖像數據

選自freeCodeCampgit

做者:Kevin Scottgithub

機器之心編譯canvas

參與:李詩萌、路api

數據清理是數據科學和機器學習中的重要組成部分,本文介紹瞭如何在 Tensorflow.js(0.11.1)中處理 MNIST 圖像數據,並逐行解釋代碼。


有人開玩笑說有 80% 的數據科學家在清理數據,剩下的 20% 在抱怨清理數據……在數據科學工做中,清理數據所佔比例比外人想象的要多得多。通常而言,訓練模型一般只佔機器學習或數據科學家工做的一小部分(少於 10%)。
——Kaggle CEO Antony Goldbloom

對任何一個機器學習問題而言,數據處理都是很重要的一步。本文將採用 Tensorflow.js(0.11.1)的 MNIST 樣例(github.com/tensorflow/…),逐行運行數據處理的代碼。跨域

MNIST 樣例數組

18 import * as tf from '@tensorflow/tfjs';
19
20 const IMAGE_SIZE = 784;
21 const NUM_CLASSES = 10;
22 const NUM_DATASET_ELEMENTS = 65000;
23
24 const NUM_TRAIN_ELEMENTS = 55000;
25 const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
26
27 const MNIST_IMAGES_SPRITE_PATH =
28 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
29 const MNIST_LABELS_PATH =
30 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';`
複製代碼

首先,導入 TensorFlow(確保你在轉譯代碼)並創建一些常量,包括:promise

  • IMAGE_SIZE:圖像尺寸(28*28=784)
  • NUM_CLASSES:標籤類別的數量(這個數字能夠是 0~9,因此這裏有 10 類)
  • NUM_DATASET_ELEMENTS:圖像總數量(65000)
  • NUM_TRAIN_ELEMENTS:訓練集中圖像的數量(55000)
  • NUM_TEST_ELEMENTS:測試集中圖像的數量(10000,亦稱餘數)
  • MNIST_IMAGES_SPRITE_PATH&MNIST_LABELS_PATH:圖像和標籤的路徑

將這些圖像級聯爲一個巨大的圖像,以下圖所示:瀏覽器

MNISTDatabash

接下來,從第 38 行開始是 MnistData,該類別使用如下函數:機器學習

  • load:負責異步加載圖像和標註數據;
  • nextTrainBatch:加載下一個訓練批;
  • nextTestBatch:加載下一個測試批;
  • nextBatch:返回下一個批的通用函數,該函數的使用取決因而在訓練集仍是測試集。

本文屬於入門文章,所以只採用 load 函數。

load

async load() {
 // Make a request for the MNIST sprited image.
 const img = new Image();
 const canvas = document.createElement('canvas');
 const ctx = canvas.getContext('2d');
複製代碼

異步函數(async)是 Javascript 中相對較新的語言功能,所以你須要一個轉譯器。

Image 對象是表示內存中圖像的本地 DOM 函數,在圖像加載時提供可訪問圖像屬性的回調。canvas 是 DOM 的另外一個元素,該元素能夠提供訪問像素數組的簡單方式,還能夠經過上下文對其進行處理。

由於這兩個都是 DOM 元素,因此若是用 Node.js(或 Web Worker)則無需訪問這些元素。有關其餘可替代的方法,請參見下文。

imgRequest

const imgRequest = new Promise((resolve, reject) => {
 img.crossOrigin = '';
 img.onload = () => {
 img.width = img.naturalWidth;
 img.height = img.naturalHeight;
複製代碼

該代碼初始化了一個 new promise,圖像加載成功後該 promise 結束。該示例沒有明確處理偏差狀態。

crossOrigin 是一個容許跨域加載圖像並能夠在與 DOM 交互時解決 CORS(跨源資源共享,cross-origin resource sharing)問題的圖像屬性。naturalWidth 和 naturalHeight 指加載圖像的原始維度,在計算時能夠強制校訂圖像尺寸。

const datasetBytesBuffer =
 new ArrayBuffer(NUMDATASETELEMENTS * IMAGESIZE * 4);
57
58 const chunkSize = 5000;
59 canvas.width = img.width;
60 canvas.height = chunkSize;
複製代碼

該代碼初始化了一個新的 buffer,包含每一張圖的每個像素。它將圖像總數和每張圖像的尺寸和通道數量相乘。

我認爲 chunkSize 的用處在於防止 UI 一次將太多數據加載到內存中,但並不能 100% 肯定。

62 for (let i = 0; i < NUMDATASETELEMENTS / chunkSize; i++) {
63 const datasetBytesView = new Float32Array(
64 datasetBytesBuffer, i * IMAGESIZE * chunkSize * 4,
 IMAGESIZE * chunkSize);
66 ctx.drawImage(
67 img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
68 chunkSize);
69
70 const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
複製代碼

該代碼遍歷了每一張 sprite 圖像,併爲該迭代初始化了一個新的 TypedArray。接下來,上下文圖像獲取了一個繪製出來的圖像塊。最終,使用上下文的 getImageData 函數將繪製出來的圖像轉換爲圖像數據,返回的是一個表示底層像素數據的對象。

72 for (let j = 0; j < imageData.data.length / 4; j++) {
73 // All channels hold an equal value since the image is grayscale, so
74 // just read the red channel.
75 datasetBytesView[j] = imageData.data[j * 4] / 255;
76 }
77 }
複製代碼

咱們遍歷了這些像素併除以 255(像素的可能最大值),以將值限制在 0 到 1 之間。只有紅色的通道是必要的,由於它是灰度圖像。

78 this.datasetImages = new Float32Array(datasetBytesBuffer);
79
80 resolve();
81 };
82 img.src = MNISTIMAGESSPRITEPATH;
);
複製代碼

這一行建立了 buffer,將其映射到保存了咱們像素數據的新 TypedArray 中,而後結束了該 promise。事實上最後一行(設置 src 屬性)才真正啓動函數並加載圖像。

起初困擾個人一件事是 TypedArray 的行爲與其底層數據 buffer 相關。你可能注意到了,在循環中設置了 datasetBytesView,但它永遠都不會返回。

datasetBytesView 引用了緩衝區的 datasetBytesBuffer(初始化使用)。當代碼更新像素數據時,它會間接編輯緩衝區的值,而後將其轉換爲 78 行的 new Float32Array。

獲取 DOM 外的圖像數據

若是你在 DOM 中,使用 DOM 便可,瀏覽器(經過 canvas)負責肯定圖像的格式以及將緩衝區數據轉換爲像素。可是若是你在 DOM 外工做的話(也就是說用的是 Node.js 或 Web Worker),那就須要一種替代方法。

fetch 提供了一種稱爲 response.arrayBuffer 的機制,這種機制使你能夠訪問文件的底層緩衝。咱們能夠用這種方法在徹底避免 DOM 的狀況下手動讀取字節。這裏有一種編寫上述代碼的替代方法(這種方法須要 fetch,能夠用 isomorphic-fetch 等方法在 Node 中進行多邊填充):

const imgRequest = fetch(MNISTIMAGESSPRITE_PATH).then(resp => resp.arrayBuffer()).then(buffer => {
 return new Promise(resolve => {
 const reader = new PNGReader(buffer);
 return reader.parse((err, png) => {
 const pixels = Float32Array.from(png.pixels).map(pixel => {
 return pixel / 255;
 });
 this.datasetImages = pixels;
 resolve();
 });
 });
});
複製代碼

這爲特定圖像返回了一個緩衝數組。在寫這篇文章時,我第一次試着解析傳入的緩衝,但我不建議這樣作。若是須要的話,我推薦使用 pngjs 進行 png 的解析。當處理其餘格式的圖像時,則須要本身寫解析函數。

有待深刻

理解數據操做是用 JavaScript 進行機器學習的重要部分。經過理解本文所述用例與需求,咱們能夠根據需求在僅使用幾個關鍵函數的狀況下對數據進行格式化。

TensorFlow.js 團隊一直在改進 TensorFlow.js 的底層數據 API,這有助於更多地知足需求。這也意味着,隨着 TensorFlow.js 的不斷改進和發展,API 也會繼續前進,跟上發展的步伐。

相關文章
相關標籤/搜索