Android+TensorFlow+CNN+MNIST實現手寫數字識別

開發環境

  • TensorFlow: 1.2.0
  • Python: 3.6
  • Python IDE: PyCharm 2017.2
  • Android IDE: Android Studio 3.0

訓練與評估

訓練和評估部分主要目的是生成用於測試用的pb文件,其保存了利用TensorFlow python API構建訓練後的網絡拓撲結構和參數信息,實現方式有不少種,除了cnn外還能夠使用rnn,fcnn等。
其中基於cnn的函數也有兩套,分別爲tf.layers.conv2d和tf.nn.conv2d, tf.layers.conv2d使用tf.nn.conv2d做爲後端處理,參數上filters是整數,filter是4維張量。原型以下:java

convolutional.py文件

def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding=’valid’, data_format=’channels_last’,
dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(), kernel_regularizer=None, bias_regularizer=None,
activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True, name=None,
reuse=None)python

gen_nn_ops.py 文件

def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", name=None)
官方Demo實例中使用的是layers module,結構以下:android

  • Convolutional Layer #1:32個5×5的filter,使用ReLU激活函數
  • Pooling Layer #1:2×2的filter作max pooling,步長爲2
  • Convolutional Layer #2:64個5×5的filter,使用ReLU激活函數
  • Pooling Layer #2:2×2的filter作max pooling,步長爲2
  • Dense Layer #1:1024個神經元,使用ReLU激活函數,dropout率0.4
    (爲了不過擬合,在訓練的時候,40%的神經元會被隨機去掉)
  • Dense Layer #2 (Logits Layer):10個神經元,每一個神經元對應一個類別(0-9)

核心代碼在cnn_model_fn(features, labels, mode)函數中,完成卷積結構的完整定義,核心代碼以下:
這裏寫圖片描述git

也能夠採用傳統的tf.nn.conv2d函數, 核心代碼以下:
這裏寫圖片描述後端

測試

  • 核心是使用API接口: TensorFlowInferenceInterface.java
  • 配置gradle 或者 自編譯TensorFlow源碼導入jar和so compile
    ‘org.tensorflow:tensorflow-android:1.2.0’
  • 導入pb文件.pb文件放assets目錄,而後讀取
String actualFilename = labelFilename.split(「file:///android_asset/「)[1];
Log.i(TAG, 「Reading labels from: 「 + actualFilename);
BufferedReader br = null;
br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
String line;
while ((line = br.readLine()) != null) {
c.labels.add(line);
}
br.close();
  • TensorFlow接口使用以下:
相關文章
相關標籤/搜索