訓練和評估部分主要目的是生成用於測試用的pb文件,其保存了利用TensorFlow python API構建訓練後的網絡拓撲結構和參數信息,實現方式有不少種,除了cnn外還能夠使用rnn,fcnn等。
其中基於cnn的函數也有兩套,分別爲tf.layers.conv2d和tf.nn.conv2d, tf.layers.conv2d使用tf.nn.conv2d做爲後端處理,參數上filters是整數,filter是4維張量。原型以下:java
def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding=’valid’, data_format=’channels_last’,
dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(), kernel_regularizer=None, bias_regularizer=None,
activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True, name=None,
reuse=None)python
def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", name=None)
官方Demo實例中使用的是layers module,結構以下:android
核心代碼在cnn_model_fn(features, labels, mode)函數中,完成卷積結構的完整定義,核心代碼以下:
git
也能夠採用傳統的tf.nn.conv2d函數, 核心代碼以下:
後端
String actualFilename = labelFilename.split(「file:///android_asset/「)[1]; Log.i(TAG, 「Reading labels from: 「 + actualFilename); BufferedReader br = null; br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename))); String line; while ((line = br.readLine()) != null) { c.labels.add(line); } br.close();