Android+TensorFlow+CNN+MNIST 手寫數字識別實現

Android+TensorFlow+CNN+MNIST 手寫數字識別實現

SkySeraph 2018html

Email:skyseraph00#163.comjava

更多精彩請直接訪問SkySeraph我的站點www.skyseraph.com 

Overview

本文系「SkySeraph AI 實踐到理論系列」第一篇,咱以AI界的HelloWord 經典MNIST數據集爲基礎,在Android平臺,基於TensorFlow,實現CNN的手寫數字識別。
Code here~python


Practice

Environment

  • TensorFlow: 1.2.0
  • Python: 3.6
  • Python IDE: PyCharm 2017.2
  • Android IDE: Android Studio 3.0

Train & Evaluate(Python+TensorFlow)

訓練和評估部分主要目的是生成用於測試用的pb文件,其保存了利用TensorFlow python API構建訓練後的網絡拓撲結構和參數信息,實現方式有不少種,除了cnn外還可使用rnn,fcnn等。
其中基於cnn的函數也有兩套,分別爲tf.layers.conv2d和tf.nn.conv2d, tf.layers.conv2d使用tf.nn.conv2d做爲後端處理,參數上filters是整數,filter是4維張量。原型以下:
convolutional.py文件
def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding=’valid’, data_format=’channels_last’,
dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(), kernel_regularizer=None, bias_regularizer=None,
activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True, name=None,
reuse=None)android

gen_nn_ops.py 文件git

def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", name=None)

官方Demo實例中使用的是layers module,結構以下:github

  • Convolutional Layer #1:32個5×5的filter,使用ReLU激活函數
  • Pooling Layer #1:2×2的filter作max pooling,步長爲2
  • Convolutional Layer #2:64個5×5的filter,使用ReLU激活函數
  • Pooling Layer #2:2×2的filter作max pooling,步長爲2
  • Dense Layer #1:1024個神經元,使用ReLU激活函數,dropout率0.4 (爲了不過擬合,在訓練的時候,40%的神經元會被隨機去掉)
  • Dense Layer #2 (Logits Layer):10個神經元,每一個神經元對應一個類別(0-9)

核心代碼在cnn_model_fn(features, labels, mode)函數中,完成卷積結構的完整定義,核心代碼以下.
算法

也能夠採用傳統的tf.nn.conv2d函數, 核心代碼以下。
數據庫

Test(Android+TensorFlow)

  • 核心是使用API接口: TensorFlowInferenceInterface.java
  • 配置gradle 或者 自編譯TensorFlow源碼導入jar和so
    compile ‘org.tensorflow:tensorflow-android:1.2.0’
  • 導入pb文件.pb文件放assets目錄,而後讀取後端

    String actualFilename = labelFilename.split(「file:///android_asset/「)[1];
    Log.i(TAG, 「Reading labels from: 「 + actualFilename);
    BufferedReader br = null;
    br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
    String line;
    while ((line = br.readLine()) != null) {
    c.labels.add(line);
    }
    br.close();網絡

  • TensorFlow接口使用

  • 最終效果:

Theory

MNIST

MNIST,最經典的機器學習模型之一,包含0~9的數字,28*28大小的單色灰度手寫數字圖片數據庫,其中共60,000 training examples和10,000 test examples。
文件目錄以下,主要包括4個二進制文件,分別爲訓練和測試圖片及Label。

以下爲訓練圖片的二進制結構,在真實數據前(pixel),有部分描述字段(魔數,圖片個數,圖片行數和列數),真實數據的存儲採用大端規則。
(大端規則,就是數據的高字節保存在低內存地址中,低字節保存在高內存地址中)

在具體實驗使用,須要提取真實數據,可採用專門用於處理字節的庫struct中的unpack_from方法,核心方法以下:
struct.unpack_from(self._fourBytes2, buf, index)

MNIST做爲AI的Hello World入門實例數據,TensorFlow封裝對其封裝好了函數,可直接使用
mnist = input_data.read_data_sets(‘MNIST’, one_hot=True)

CNN(Convolutional Neural Network)

CNN Keys

  • CNN,Convolutional Neural Network,中文全稱卷積神經網絡,即所謂的卷積網(ConvNets)。
  • 卷積(Convolution)可謂是現代深度學習中最最重要的概念了,它是一種數學運算,讀者能夠從下面連接[23]中卷積相關數學機理,包括分別從傅里葉變換和狄拉克δ函數中推到卷積定義,咱們能夠從字面上宏觀粗魯的理解成將因子翻轉相乘捲起來。
  • 卷積動畫。演示以下圖[26],更多動畫演示可參考[27]
  • 神經網絡。一個由大量神經元(neurons)組成的系統,以下圖所示[21]

    其中x表示輸入向量,w爲權重,b爲偏值bias,f爲激活函數。

  • Activation Function 激活函數: 經常使用的非線性激活函數有Sigmoid、tanh、ReLU等等,公式以下如所示。

    • Sigmoid缺點
      • 函數飽和使梯度消失(神經元在值爲 0 或 1 的時候接近飽和,這些區域,梯度幾乎爲 0)
      • sigmoid 函數不是關於原點中心對稱的(無0中心化)
    • tanh: 存在飽和問題,但它的輸出是零中心的,所以實際中 tanh 比 sigmoid 更受歡迎。
    • ReLU
      • 優勢1:ReLU 對於 SGD 的收斂有巨大的加速做用
      • 優勢2:只須要一個閾值就能夠獲得激活值,而不用去算一大堆複雜的(指數)運算
      • 缺點:須要合理設置學習率(learning rate),防止訓練時dead,還可使用Leaky ReLU/PReLU/Maxout等代替
  • Pooling池化。通常分爲平均池化mean pooling和最大池化max pooling,以下圖所示[21]爲max pooling,除此以外,還有重疊池化(OverlappingPooling)[24],空金字塔池化(Spatial Pyramid Pooling)[25]
    • 平均池化:計算圖像區域的平均值做爲該區域池化後的值。
    • 最大池化:選圖像區域的最大值做爲該區域池化後的值。

CNN Architecture

  • 三層神經網絡。分別爲輸入層(Input layer),輸出層(Output layer),隱藏層(Hidden layer),以下圖所示[21]
  • CNN層級結構。 斯坦福cs231n中闡述了一種[INPUT-CONV-RELU-POOL-FC],以下圖所示[21],分別爲輸入層,卷積層,激勵層,池化層,全鏈接層。
  • CNN通用架構分爲以下三層結構:
    • Convolutional layers 卷積層
    • Pooling layers 匯聚層
    • Dense (fully connected) layers 全鏈接層
  • 動畫演示。參考[22]。

Regression + Softmax

機器學習有監督學習(supervised learning)中兩大算法分別是分類算法和迴歸算法,分類算法用於離散型分佈預測,迴歸算法用於連續型分佈預測。
迴歸的目的就是創建一個迴歸方程用來預測目標值,迴歸的求解就是求這個迴歸方程的迴歸係數。
其中迴歸(Regression)算法包括Linear Regression,Logistic Regression等, Softmax Regression是其中一種用於解決多分類(multi-class classification)問題的Logistic迴歸算法的推廣,經典實例就是在MNIST手寫數字分類上的應用。

Linear Regression

Linear Regression是機器學習中最基礎的模型,其目標是用預測結果儘量地擬合目標label

  • 多元線性迴歸模型定義
  • 多元線性迴歸求解
  • Mean Square Error (MSE)
    • Gradient Descent(梯度降低法)
    • Normal Equation(普通最小二乘法)
    • 局部加權線性迴歸(LocallyWeightedLinearRegression, LWLR ):針對線性迴歸中模型欠擬合現象,在估計中引入一些誤差以便下降預測的均方偏差。
    • 嶺迴歸(ridge regression)和縮減方法
  • 選擇: Normal Equation相比Gradient Descent,計算量大(需計算X的轉置與逆矩陣),只適用於特徵個數小於100000時使用;當特徵數量大於100000時使用梯度法。當X不可逆時可替代方法爲嶺迴歸算法。LWLR方法增長了計算量,由於它對每一個點作預測時都必須使用整個數據集,而不是計算出迴歸係數獲得迴歸方程後代入計算便可,通常不選擇。
  • 調優: 平衡預測誤差和模型方差(高誤差就是欠擬合,高方差就是過擬合)
    • 獲取更多的訓練樣本 - 解決高方差
    • 嘗試使用更少的特徵的集合 - 解決高方差
    • 嘗試得到其餘特徵 - 解決高誤差
    • 嘗試添加多項組合特徵 - 解決高誤差
    • 嘗試減少 λ - 解決高誤差
    • 嘗試增長 λ -解決高方差

Softmax Regression

  • Softmax Regression估值函數(hypothesis)
  • Softmax Regression代價函數(cost function)
  • 理解:
  • Softmax Regression & Logistic Regression:
    • 多分類 & 二分類。Logistic Regression爲K=2時的Softmax Regression
    • 針對K類問題,當類別之間互斥時可採用Softmax Regression,當非斥時,可採用K個獨立的Logistic Regression
  • 總結: Softmax Regression適用於類別數量大於2的分類,本例中用於判斷每張圖屬於每一個數字的機率。

References & Recommends

MNIST

Softmax

CNN

TensorFlow+CNN / TensorFlow+Android



By SkySeraph-2018

SkySeraph cnBlogs

本文首發於skyseraph.com「Android+TensorFlow+CNN+MNIST 手寫數字識別實現」

相關文章
相關標籤/搜索