將TensorFlow訓練的模型移植到Android手機

前言

本文中出現的TF皆爲TensorFlow的簡稱。java

先說兩句題外話吧,TensorFlow 前兩天熱熱鬧鬧的發佈了正式版r1.0,可感受本身才剛剛上手 r0.12,這個時代發展的太快,腳步是一刻也不能停啊~node

可是不得不吐槽 TensorFlow的向下兼容作的實在不太友好,每次更新完版本,之前的代碼就跑不動,各類提示您使用的函數已經不存在。。。python

代碼積攢的愈來愈多,所有針對新版本翻改一遍,工程真是浩大。可是喜新厭舊,手賤如我,每次都忍不住點了更新。不過此次忍的還算不錯,到目前還沒更新,繼續忍住android

在以前的文章中,我介紹瞭如何實現 TensorFlow官網的Mobile教程: 
【將Tensorflow移植到安卓手機,實現物體識別、行人檢測和圖像風格遷移】。 
但在那個教程中,TensorFlow提供了完整的、已經構建好的Android項目,咱們須要作的總結下來只有3步:一、搭建環境;二、編譯;三、安裝到手機git

這固然還不夠,咱們的最終目的固然是要爲我所用,因此怎樣才能移植本身訓練好的TF模型到安卓手機呢?換句話說,怎樣將訓練好的模型放入Android項目中並進行成功編譯?又或者怎樣建立本身的Android Tensorflow項目?github

PS: 
以前沒有安卓開發的經驗,純粹是爲了實現將TF模型移植到手機纔開始上手,目前屬於入門級小白,若有錯誤之處,歡迎批評指正!網絡

手機調用TF模型的過程簡介:

一、 保存訓練完畢的TF模型 
二、 在Android項目中導入TF模型、導入Android平臺調用TF模型須要的jar包和so文件 (它們負責TF模型的解析和運算) 
三、定義變量、存儲數據,經過jar包提供的接口進行模型的調用session

環境

TensorFlow版本: r0.12 
Python 版本:2.7 
Python IDE: Spyder 
Android IDE : Android Studioapp

移植過程

咱們以mnist數據集上本身訓練的一個圖像識別模型爲例,進行講解函數

1、 在使用python代碼編寫的TF模型定義中爲模型的輸入層和輸出層Tensor Variable分別指定名字(經過形參 ‘name’)

X = tf.placeholder(tf.float32, shape = […], name=‘input’)  //網絡的輸入
Y = tf.nn.softmax(tf.matmul(f, out_weights) + out_biases, name=’output’)  //網絡的輸出
  •  

名字能夠隨便起,以方便好記爲主,後面還會反覆用到。我起的是input和output。

2、 將使用TensorFlow訓練好的模型保存爲.pb文件

在模型訓練結束後的代碼位置,添加下述兩句代碼,可將模型保存爲.pb文件

output_graph_def = tf.graph_until.convert_variables_to_constants(session, session.graph_def, output_node_names=[‘output’])
//形參output_node_names用於指定輸出的節點名稱
  •  

貼一個說明文檔,幫助你們進一步瞭解這個函數

這裏寫圖片描述

with tf.gfile.FastGFile(model\mnist.pb, mode = ’wb’) as f:
    f.write(output_graph_def.SerializeToString())
  •  

第一個參數用於指定輸出的文件存放路徑、文件名及格式。我把它放在與代碼同級目錄的model文件下,取名爲mnist.pb

第二個參數 mode用於指定文件操做的模式,’wb’中w表明寫文件,b表明將數據以二進制方式寫入文件。

若是不指明‘b’,則默認會以文本txt方式寫入文件。如今TF還不支持對文本格式.pb文件的解析,在調用時會出現報錯。

注: 
1)、不能使用 tf.train.write_graph()保存模型,由於它只是保存了模型的結構,並不保存訓練完畢的參數值 
2)、不能使用 tf.train.saver()保存模型,由於它只是保存了網絡中的參數值,並不保存模型的結構。 
很顯然,咱們須要的是既保存模型的結構,又保存模型中每一個參數的值。以上二者皆不符合。

3、生成在Android平臺上調用tensorflow 模型須要的jar包和so文件 
1) 從github下載TensorFlow的項目源碼

2) 安裝Bazel 
Bazel的安裝過程,我在另外一篇文章中有介紹,歡迎參閱 
Ubuntu14.04 源代碼安裝 TensorFlow r0.12 詳細教程

3) 參考以下圖的官方教程,生成Android上調用TF模型須要的so文件和jar包 
這裏寫圖片描述

4、安裝Android Studio,建立Android 項目

Android Studio安裝完畢後,還須要搭建環境。搭建過程可參考個人另外一篇文章:

Ubuntu 使用 Android Studio 編譯 TensorFlow android demo

5、添加資源到項目

1) 將(二)步生成的.pb文件放入項目中 
打開 Project view ,app/src/main/assets。 
若不存在assets目錄,右鍵main->new->folder->Assets Folder

2) 添加(三)步生成的jar包 
打開Project view,將jar包拷貝到app->libs下 
選中jar文件,右鍵 add as library

3) 添加(三)生成的so文件 
打開 Project view,將.so文件拷貝到 app/src/main/jniLibs下(jniLibs文件夾若沒有則新建)

若是我講的不太明白的話,可自行谷歌搜索「如何在 Android studio中添加引用 jar文件和so文件」

6、建立接口,實現調用

1) 導入jar包和so文件 
在須要調用模型的.java文件中,導入jar包:

import org.tensorflow.contrib.android.TensorFlowInferenceInterface
  •  

在該java類定義的首行,導入so文件:

{
    System.loadLibrary(「tensorflow_inference」)
}
  •  

2)定義變量及對象

private static final String MODEL_FILE = 「file:///android_asset/mnist.pb」   //模型存放路徑
private static final String INPUT_NODE = 「input」;       //模型中輸入變量的名稱
private static final String INPUT_NODE = 「output」;  //模型中輸出變量的名稱
private static final int NUM_CLASSES = 10;  //樣本集的類別數量,mnist數據集對應10

private static final int HEIGHT = 24;       //輸入圖片的像素高
private static final int WIDTH = 24;        //輸入圖片的像素寬
private static final int CHANNEL = 3;    //輸入圖片的通道數:RGB

private floats inputs = new float[HEIGHT*WIDTH*CHANNEL];    //用於存儲的模型輸入數據
private floats outputs = new float[NUM_CLASSES];    //用於存儲模型的輸出數據

2)Tensorflow 接口初始化

private TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface();   //接口定義
inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);  //接口初始化
  •  

在完成上述兩步以後,就能夠反覆調用模型。 
在每次調用前,先將待輸入的數據按順序存放進 inputs 變量中,而後執行下述三個語句。

3)TF模型的調用

inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH, CHANNEL}, inputs);  //送入輸入數據
inferenceInterface.runInference(new String[]{OUTPUT_NODE});     //進行模型的推理
inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs); //獲取輸出數據
  •  

而後接下來的主要工做就是安卓項目的編譯以及將編譯完的apk文件安裝到手機,這部份內容與通常的安卓項目並沒有區別。這些內容在個人另外一篇文章中也有所說起:

Ubuntu 使用 Android Studio 編譯 TensorFlow android demo

爲了便於你們理解,我寫的代碼比較面向過程。固然放在java環境下,仍是要多多從面向對象的角度出發,合理的封裝,提升代碼的複用性。

相關文章
相關標籤/搜索