[譯] 如何在安卓應用中使用 TensorFlow Mobile

TensorFlow 是當今最流行的機器學習框架之一,您利用它能夠輕鬆建立和訓練深度模型 —— 一般也稱爲深度前饋神經網絡,這些模型能夠解決各類複雜問題,如圖像分類、目標檢測和天然語言理解。TensorFlow Mobile 是一個旨在幫助您在移動應用中利用這些模型的庫。html

在本教程中,我將向您展現如何在 Android Studio 項目中使用 TensorFlow Mobile。前端

前期準備

爲了可以跟上教程,您須要作的是:python

  • Android Studio 3.0 或更高版本
  • TensorFlow 1.5.0 或更高版本
  • 一臺可以運行 API level 21 或更高的安卓設備
  • 以及對 TensorFlow 框架的基本瞭解

一、建立模型

在咱們開始使用 TensorFlow Mobile 以前,咱們須要一個已經訓練好的 TensorFlow 模型。咱們如今建立一個。android

咱們的模型將很是基礎,相似於異或門,接受兩個輸入,它們能夠是零或一,而後有一個輸出。若是兩個輸入相同,則輸出爲零。此外,由於它將是一個深度模型,它將有兩個隱藏層,一個有四個神經元,另外一個有三個神經元。您能夠自由改變隱藏層的數量以及它們包含的神經元的數量。ios

爲了保持本教程的簡潔,咱們將使用 TFLearn,這是一個很受歡迎的 TensorFlow 封裝框架,它提供更加直接而簡潔的 API,而不是直接使用低級別的 TensorFlow API。若是您還沒安裝它,請使用如下命令將其安裝在 TensorFlow 虛擬環境中:git

pip install tflearn
複製代碼

要開始建立模型,最好在空目錄中先新建一個名爲 create_model.py 的 Python 腳本,而後使用您最喜歡的文本編輯器打開它。github

在文件裏,咱們須要作的第一件事是導入 TFLearn API。後端

import tflearn
複製代碼

接下來,咱們必須建立訓練數據。對於咱們的簡單模型,只有四種可能的輸入和輸出,相似於異或門真值表的內容。數組

X = [
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1]
]
 
Y = [
    [0],  # Desired output for inputs 0, 0
    [1],  # Desired output for inputs 0, 1
    [1],  # Desired output for inputs 1, 0
    [0]   # Desired output for inputs 1, 1
]
複製代碼

爲隱藏層中的全部神經元分配初始權重時,最好的作法一般是使用從均勻分佈中產生的隨機數。可使用 uniform() 方法生成這些值。bash

weights = tflearn.initializations.uniform(minval = -1, maxval = 1)
複製代碼

此時,咱們能夠開始構建神經網絡層。要建立輸入層,咱們必須使用 input_data() 方法,它容許咱們指定網絡能夠接受的輸入數量。一旦輸入層準備就緒,咱們能夠屢次調用 fully_connected() 方法來向網絡添加更多層。

# 輸入層
net = tflearn.input_data(
        shape = [None, 2],
        name = 'my_input'
)
 
# 隱藏層
net = tflearn.fully_connected(net, 4,
        activation = 'sigmoid',
        weights_init = weights
)
net = tflearn.fully_connected(net, 3,
        activation = 'sigmoid',
        weights_init = weights
)
 
# 輸出層
net = tflearn.fully_connected(net, 1,
        activation = 'sigmoid', 
        weights_init = weights,
        name = 'my_output'
)
複製代碼

注意,在上面的代碼中,咱們賦予了輸入層和輸出層有意義的名稱。這麼作很重要,由於咱們在使用安卓應用中的網絡時須要它們。還要注意隱藏層和輸出層使用了 sigmoid 激活函數。您能夠試試其餘激活函數,例如 softmaxtanhrelu

做爲咱們網絡的最後一層,咱們必須使用 regression() 函數建立一個迴歸層,該函數須要一些超參數做爲其參數,例如網絡的學習率以及它應該使用的優化器和損失函數。如下代碼向您展現瞭如何使用隨機梯度降低(簡稱 SGD)做爲優化器函數,均方偏差做爲損失函數:

net = tflearn.regression(net,
        learning_rate = 2,
        optimizer = 'sgd',
        loss = 'mean_square'
)
複製代碼

接下來,爲了讓 TFLearn 框架知道咱們的網絡模型其實是一個深度神經網絡模型,咱們需要調用 DNN() 函數。

model = tflearn.DNN(net)
複製代碼

模型如今已經準備好了。咱們如今要作的就是使用咱們以前建立的訓練數據進行訓練。所以,調用模型的 fit() 方法,並指定訓練數據與訓練週期。因爲訓練數據很是小,咱們的模型將須要數千次迭代才能達到合理的精度。

model.fit(X, Y, 5000)
複製代碼

一旦訓練完成,咱們能夠調用模型的 predict() 方法來檢查它是否生成指望的輸出。如下代碼展現瞭如何檢查全部有效輸入的輸出:

print("1 XOR 0 = %f" % model.predict([[1,0]]).item(0))
print("1 XOR 1 = %f" % model.predict([[1,1]]).item(0))
print("0 XOR 1 = %f" % model.predict([[0,1]]).item(0))
print("0 XOR 0 = %f" % model.predict([[0,0]]).item(0))
複製代碼

若是如今運行 Python 腳本,您應該看到以下所示的輸出:

訓練後的預測結果

請注意,輸出不會徹底是 0 或 1。而是接近 0 或 1 的浮點數。所以,在使用輸出時,可能須要使用 Python 的 round() 函數。

除非咱們在訓練後明確保存模型,不然只要程序結束,咱們就會失去模型。幸運的是,對於 TFLearn,只需調用 save() 方法便可保存模型。可是,爲了可以在 TensorFlow Mobile 中使用保存的模型,在保存以前,咱們必須確保移除全部訓練相關的操做。這些操做都在 tf.GraphKeys.TRAIN_OPS 集合中。如下代碼展現了怎麼去移除相關操做:

# 移除訓練相關的操做
with net.graph.as_default():
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
 
# 保存模型
model.save('xor.tflearn')
複製代碼

若是您再次運行該腳本,您會發現它會生成檢查點文件、元數據文件、索引文件和數據文件,全部這些文件一塊兒使用時能夠快速重建咱們訓練好的模型。

二、固化模型

除了保存模型外,咱們還必須先固化模型,而後才能將其與 TensorFlow Mobile 配合使用。正如您可能已經猜到的那樣,固化模型的過程涉及將其全部變量轉換爲常量。此外,固化模型必須是符合 Google Protocol Buffers 序列化格式的單個二進制文件。

新建一個名爲 freeze_model.py 的 Python 腳本,並使用文本編輯器打開它。咱們將在這個文件中編寫固化的模型代碼來。

因爲 TFLearn 沒有任何固化模型的功能,咱們如今必須直接使用 TensorFlow API。經過將如下行添加到文件來導入它們:

import tensorflow as tf
複製代碼

整個腳本里面,咱們將使用單個 TensorFlow 會話。咱們使用 Session 類的構造函數建立會話。

with tf.Session() as session:
    # 代碼的其餘部分在這
複製代碼

此時,咱們必須經過調用 import_meta_graph() 函數並將模型的元數據文件的名稱傳遞給它來建立 Saver 對象,除了返回 Saver 對象外,import_meta_graph() 函數還會自動將模型的圖定義添加到會話的圖定義中。

一旦建立了保存器(saver),咱們能夠經過調用 restore() 方法來初始化圖定義中存在的全部變量,該方法須要包含模型最新檢查點文件的目錄路徑。

my_saver = tf.train.import_meta_graph('xor.tflearn.meta')
my_saver.restore(session, tf.train.latest_checkpoint('.'))
複製代碼

此時,咱們能夠調用 convert_variables_to_constants() 函數來建立一個固化的圖定義,其中模型的全部變量都替換成常量。做爲其輸入,函數須要當前會話、當前會話的圖定義以及包含模型輸出層名稱的列表。

frozen_graph = tf.graph_util.convert_variables_to_constants(
    session,
    session.graph_def,
    ['my_output/Sigmoid']
)
複製代碼

調用固化圖定義的 SerializeToString() 方法爲咱們提供了模型的二進制 protobuf 表示。經過使用 Python 基本的文件 I/O,我建議您把它保存爲一個名爲 frozen_model.pb 的文件。

with open('frozen_model.pb', 'wb') as f:
    f.write(frozen_graph.SerializeToString())
複製代碼

如今能夠運行腳原本生成固化模型。

咱們如今擁有開始使用 TensorFlow Mobile 所需的一切。

三、Android Studio 項目設置

TensorFlow Mobile 庫可在 JCenter 上使用,因此咱們能夠直接將它添加爲 app 模塊 build.gradle 文件中的 implementation 依賴項。

implementation 'org.tensorflow:tensorflow-android:1.7.0'
複製代碼

要把固化的模型添加到項目中,請將 frozen_model.pb 文件放置到項目的 assets 文件夾中。

四、初始化 TensorFlow 接口

TensorFlow Mobile 提供了一個簡單的接口,咱們可使用它與咱們的固化模型進行交互。要建立接口,請使用 TensorFlowInferenceInterface 類的構造函數,該類須要一個 AssetManager 實例和固化模型的文件名。

thread {
    val tfInterface = TensorFlowInferenceInterface(assets,
                                        "frozen_model.pb")
     
    // More code here
}
複製代碼

在上面的代碼中,您能夠看到咱們正在產生一個新的線程。這是爲了確保應用的 UI 保持響應,雖然沒必要要,但建議這樣作。

爲了保證 TensorFlow Mobile 可以正確讀取咱們模型的文件,如今讓咱們嘗試打印模型圖中全部操做的名稱。爲了獲得對圖的引用,咱們可使用接口的 graph() 方法,並獲取全部操做,即圖的 operations() 方法。如下代碼告訴您該怎麼作:

val graph = tfInterface.graph()
graph.operations().forEach {
    println(it.name())
}
複製代碼

若是如今運行該應用,則應該可以看到在 Android Studio 的 Logcat 窗口中打印的十幾個操做名稱。若是固化模型時沒有出錯,咱們能夠在這些名稱中找到輸入和輸出層的名稱:my_input/Xmy_output/Sigmoid

Logcat 窗口展現了操做列表

五、使用模型

爲了用模型進行預測,咱們將數據輸入到輸入層,在輸出層獲得數據。將數據輸入到輸入層須要使用接口的 feed() 方法,該方法須要輸入層的名稱、含有輸入數據的數組以及數組的維數。如下代碼展現如何將數字 01 輸入到輸入層:

tfInterface.feed("my_input/X",
            floatArrayOf(0f, 1f), 1, 2)
複製代碼

數據加載到輸入層後,咱們必須使用 run() 方法進行推斷操做,該方法須要輸出層的名稱。一旦操做完成,輸出層將包含模型的預測。爲了將預測結果加載到 Kotlin 數組中,咱們可使用 fetch() 方法。如下代碼顯示瞭如何執行此操做:

tfInterface.run(arrayOf("my_output/Sigmoid"))
 
val output = floatArrayOf(-1f)
tfInterface.fetch("my_output/Sigmoid", output)
複製代碼

您如今能夠運行該應用來查看模型的預測是否正確。

Logcat window displaying the prediction

能夠更改輸入到輸入層的數字,以確認模型的預測始終正確。

總結

您如今知道如何建立一個簡單的 TensorFlow 模型以及在安卓應用上經過 TensorFlow Mobile 去使用該模型。不過沒必要拘泥於本身的模型,用您今天學到的東西,使用更大的模型對您來講應該沒有任何問題。例如 MobileNet 以及 Inception,這些均可以在 TensorFlow 的 模型園 裏找到。可是請注意,這些模型會使 APK 更大,從而給使用低端設備的用戶形成問題。

要了解有關 TensorFlow Mobile 的更多信息,請參閱 官方文檔.


掘金翻譯計劃 是一個翻譯優質互聯網技術文章的社區,文章來源爲 掘金 上的英文分享文章。內容覆蓋 AndroidiOS前端後端區塊鏈產品設計人工智能等領域,想要查看更多優質譯文請持續關注 掘金翻譯計劃官方微博知乎專欄

相關文章
相關標籤/搜索