本文中出現的TF皆爲TensorFlow的簡稱。java
先說兩句題外話吧,TensorFlow 前兩天熱熱鬧鬧的發佈了正式版r1.0,可感受本身才剛剛上手 r0.12,這個時代發展的太快,腳步是一刻也不能停啊~node
可是不得不吐槽 TensorFlow的向下兼容作的實在不太友好,每次更新完版本,之前的代碼就跑不動,各類提示您使用的函數已經不存在。。。python
代碼積攢的愈來愈多,所有針對新版本翻改一遍,工程真是浩大。可是喜新厭舊,手賤如我,每次都忍不住點了更新。不過此次忍的還算不錯,到目前還沒更新,繼續忍住android
在以前的文章中,我介紹瞭如何實現 TensorFlow官網的Mobile教程:
【將Tensorflow移植到安卓手機,實現物體識別、行人檢測和圖像風格遷移】。
但在那個教程中,TensorFlow提供了完整的、已經構建好的Android項目,咱們須要作的總結下來只有3步:一、搭建環境;二、編譯;三、安裝到手機git
這固然還不夠,咱們的最終目的固然是要爲我所用,因此怎樣才能移植本身訓練好的TF模型到安卓手機呢?換句話說,怎樣將訓練好的模型放入Android項目中並進行成功編譯?又或者怎樣建立本身的Android Tensorflow項目?github
PS:
以前沒有安卓開發的經驗,純粹是爲了實現將TF模型移植到手機纔開始上手,目前屬於入門級小白,若有錯誤之處,歡迎批評指正!網絡
一、 保存訓練完畢的TF模型
二、 在Android項目中導入TF模型、導入Android平臺調用TF模型須要的jar包和so文件 (它們負責TF模型的解析和運算)
三、定義變量、存儲數據,經過jar包提供的接口進行模型的調用session
TensorFlow版本: r0.12
Python 版本:2.7
Python IDE: Spyder
Android IDE : Android Studioapp
咱們以mnist數據集上本身訓練的一個圖像識別模型爲例,進行講解函數
1、 在使用python代碼編寫的TF模型定義中爲模型的輸入層和輸出層Tensor Variable分別指定名字(經過形參 ‘name’)
X = tf.placeholder(tf.float32, shape = […], name=‘input’) //網絡的輸入 Y = tf.nn.softmax(tf.matmul(f, out_weights) + out_biases, name=’output’) //網絡的輸出
名字能夠隨便起,以方便好記爲主,後面還會反覆用到。我起的是input和output。
2、 將使用TensorFlow訓練好的模型保存爲.pb文件
在模型訓練結束後的代碼位置,添加下述兩句代碼,可將模型保存爲.pb文件
output_graph_def = tf.graph_until.convert_variables_to_constants(session, session.graph_def, output_node_names=[‘output’]) //形參output_node_names用於指定輸出的節點名稱
貼一個說明文檔,幫助你們進一步瞭解這個函數
with tf.gfile.FastGFile(model\mnist.pb, mode = ’wb’) as f: f.write(output_graph_def.SerializeToString())
第一個參數用於指定輸出的文件存放路徑、文件名及格式。我把它放在與代碼同級目錄的model文件下,取名爲mnist.pb
第二個參數 mode用於指定文件操做的模式,’wb’中w表明寫文件,b表明將數據以二進制方式寫入文件。
若是不指明‘b’,則默認會以文本txt方式寫入文件。如今TF還不支持對文本格式.pb文件的解析,在調用時會出現報錯。
注:
1)、不能使用 tf.train.write_graph()保存模型,由於它只是保存了模型的結構,並不保存訓練完畢的參數值
2)、不能使用 tf.train.saver()保存模型,由於它只是保存了網絡中的參數值,並不保存模型的結構。
很顯然,咱們須要的是既保存模型的結構,又保存模型中每一個參數的值。以上二者皆不符合。
3、生成在Android平臺上調用tensorflow 模型須要的jar包和so文件
1) 從github下載TensorFlow的項目源碼
2) 安裝Bazel
Bazel的安裝過程,我在另外一篇文章中有介紹,歡迎參閱
Ubuntu14.04 源代碼安裝 TensorFlow r0.12 詳細教程
3) 參考以下圖的官方教程,生成Android上調用TF模型須要的so文件和jar包
4、安裝Android Studio,建立Android 項目
Android Studio安裝完畢後,還須要搭建環境。搭建過程可參考個人另外一篇文章:
Ubuntu 使用 Android Studio 編譯 TensorFlow android demo
5、添加資源到項目
1) 將(二)步生成的.pb文件放入項目中
打開 Project view ,app/src/main/assets。
若不存在assets目錄,右鍵main->new->folder->Assets Folder
2) 添加(三)步生成的jar包
打開Project view,將jar包拷貝到app->libs下
選中jar文件,右鍵 add as library
3) 添加(三)生成的so文件
打開 Project view,將.so文件拷貝到 app/src/main/jniLibs下(jniLibs文件夾若沒有則新建)
若是我講的不太明白的話,可自行谷歌搜索「如何在 Android studio中添加引用 jar文件和so文件」
6、建立接口,實現調用
1) 導入jar包和so文件
在須要調用模型的.java文件中,導入jar包:
import org.tensorflow.contrib.android.TensorFlowInferenceInterface
在該java類定義的首行,導入so文件:
{ System.loadLibrary(「tensorflow_inference」) }
2)定義變量及對象
private static final String MODEL_FILE = 「file:///android_asset/mnist.pb」 //模型存放路徑 private static final String INPUT_NODE = 「input」; //模型中輸入變量的名稱 private static final String INPUT_NODE = 「output」; //模型中輸出變量的名稱 private static final int NUM_CLASSES = 10; //樣本集的類別數量,mnist數據集對應10 private static final int HEIGHT = 24; //輸入圖片的像素高 private static final int WIDTH = 24; //輸入圖片的像素寬 private static final int CHANNEL = 3; //輸入圖片的通道數:RGB private floats inputs = new float[HEIGHT*WIDTH*CHANNEL]; //用於存儲的模型輸入數據 private floats outputs = new float[NUM_CLASSES]; //用於存儲模型的輸出數據
2)Tensorflow 接口初始化
private TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(); //接口定義 inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE); //接口初始化
在完成上述兩步以後,就能夠反覆調用模型。
在每次調用前,先將待輸入的數據按順序存放進 inputs 變量中,而後執行下述三個語句。
3)TF模型的調用
inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH, CHANNEL}, inputs); //送入輸入數據 inferenceInterface.runInference(new String[]{OUTPUT_NODE}); //進行模型的推理 inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs); //獲取輸出數據
而後接下來的主要工做就是安卓項目的編譯以及將編譯完的apk文件安裝到手機,這部份內容與通常的安卓項目並沒有區別。這些內容在個人另外一篇文章中也有所說起:
Ubuntu 使用 Android Studio 編譯 TensorFlow android demo
爲了便於你們理解,我寫的代碼比較面向過程。固然放在java環境下,仍是要多多從面向對象的角度出發,合理的封裝,提升代碼的複用性。