導出pb模型以後測試的python代碼

連接:https://blog.csdn.net/thriving_fcl/article/details/75213361git

 

saved_model模塊主要用於TensorFlow Serving。TF Serving是一個將訓練好的模型部署至生產環境的系統,主要的優勢在於能夠保持Server端與API不變的狀況下,部署新的算法或進行試驗,同時還有很高的性能。

保持Server端與API不變有什麼好處呢?有不少好處,我只從我體會的一個方面舉例子說明一下,好比咱們須要部署一個文本分類模型,那麼輸入和輸出是能夠肯定的,輸入文本,輸出各種別的機率或類別標籤。爲了獲得較好的效果,咱們可能想嘗試不少不一樣的模型,CNN,RNN,RCNN等,這些模型訓練好保存下來之後,在inference階段須要從新載入這些模型,咱們但願的是inference的代碼有一份就好,也就是使用新模型的時候不須要針對新模型來修改inference的代碼。這應該如何實現呢?

在TensorFlow 模型保存/載入的兩種方法中總結過。
1. 僅用Saver來保存/載入變量。這個方法顯然不行,僅保存變量就必須在inference的時候從新定義Graph(定義模型),這樣不一樣的模型代碼確定要修改。即便同一種模型,參數變化了,也須要在代碼中有所體現,至少須要一個配置文件來同步,這樣就很繁瑣了。
2. 使用tf.train.import_meta_graph導入graph信息並建立Saver, 再使用Saver restore變量。相比第一種,不須要從新定義模型,可是爲了從graph中找到輸入輸出的tensor,仍是得用graph.get_tensor_by_name()來獲取,也就是還須要知道在定義模型階段所賦予這些tensor的名字。若是建立各模型的代碼都是同一我的完成的,還相對好控制,強制這些輸入輸出的命名都一致便可。若是是不一樣的開發者,要在建立模型階段就強制tensor的命名一致就比較困難了。這樣就不得再也不維護一個配置文件,將須要獲取的tensor名稱寫入,而後從配置文件中讀取該參數。

通過上面的分析發現,要實現inference的代碼統一,使用原來的方法也是能夠的,只不過TensorFlow官方提供了更好的方法,而且這個方法不只僅是解決這個問題,因此仍是得學習使用saved_model這個模塊。
saved_model 保存/載入模型

先列出會用到的API

class tf.saved_model.builder.SavedModelBuilder

# 初始化方法
__init__(export_dir)

# 導入graph與變量信息
add_meta_graph_and_variables(
    sess,
    tags,
    signature_def_map=None,
    assets_collection=None,
    legacy_init_op=None,
    clear_devices=False,
    main_op=None
)

# 載入保存好的模型
tf.saved_model.loader.load(
    sess,
    tags,
    export_dir,
    **saver_kwargs
)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23

(1) 最簡單的場景,只是保存/載入模型
保存

要保存一個已經訓練好的模型,使用下面三行代碼就能夠了。

builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
builder.add_meta_graph_and_variables(sess, ['tag_string'])
builder.save()

    1
    2
    3

首先構造SavedModelBuilder對象,初始化方法只須要傳入用於保存模型的目錄名,目錄不用預先建立。

add_meta_graph_and_variables方法導入graph的信息以及變量,這個方法假設變量都已經初始化好了,對於每一個SavedModelBuilder這個方法必定要執行一次用於導入第一個meta graph。

第一個參數傳入當前的session,包含了graph的結構與全部變量。

第二個參數是給當前須要保存的meta graph一個標籤,標籤名能夠自定義,在以後載入模型的時候,須要根據這個標籤名去查找對應的MetaGraphDef,找不到就會報如RuntimeError: MetaGraphDef associated with tags 'foo' could not be found in SavedModel這樣的錯。標籤也能夠選用系統定義好的參數,如tf.saved_model.tag_constants.SERVING與tf.saved_model.tag_constants.TRAINING。

save方法就是將模型序列化到指定目錄底下。

保存好之後到saved_model_dir目錄下,會有一個saved_model.pb文件以及variables文件夾。顧名思義,variables保存全部變量,saved_model.pb用於保存模型結構等信息。
載入

使用tf.saved_model.loader.load方法就能夠載入模型。如

meta_graph_def = tf.saved_model.loader.load(sess, ['tag_string'], saved_model_dir)

    1

第一個參數就是當前的session,第二個參數是在保存的時候定義的meta graph的標籤,標籤一致才能找到對應的meta graph。第三個參數就是模型保存的目錄。

load完之後,也是從sess對應的graph中獲取須要的tensor來inference。如

x = sess.graph.get_tensor_by_name('input_x:0')
y = sess.graph.get_tensor_by_name('predict_y:0')

# 實際的待inference的樣本
_x = ...
sess.run(y, feed_dict={x: _x})

    1
    2
    3
    4
    5
    6

這樣和以前的第二種方法同樣,也是要知道tensor的name。那麼如何能夠在不知道tensor name的狀況下使用呢? 那就須要給add_meta_graph_and_variables方法傳入第三個參數,signature_def_map。
(2) 使用SignatureDef

關於SignatureDef個人理解是,它定義了一些協議,對咱們所需的信息進行封裝,咱們根據這套協議來獲取信息,從而實現建立與使用模型的解耦。SignatureDef的結構以及相關詳細的文檔在:https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md

相關API

# 構建signature
tf.saved_model.signature_def_utils.build_signature_def(
    inputs=None,
    outputs=None,
    method_name=None
)

# 構建tensor info
tf.saved_model.utils.build_tensor_info(tensor)

    1
    2
    3
    4
    5
    6
    7
    8
    9

SignatureDef,將輸入輸出tensor的信息都進行了封裝,而且給他們一個自定義的別名,因此在構建模型的階段,能夠隨便給tensor命名,只要在保存訓練好的模型的時候,在SignatureDef中給出統一的別名便可。

TensorFlow的關於這部分的例子中用到了很多signature_constants,這些constants的用處主要是提供了一個方便統一的命名。在咱們本身理解SignatureDef的做用的時候,能夠先不用管這些,遇到須要命名的時候,想怎麼寫怎麼寫。
保存

假設定義模型輸入的別名爲「input_x」,輸出的別名爲「output」 ,使用SignatureDef的代碼以下

builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
# x 爲輸入tensor, keep_prob爲dropout的prob tensor
inputs = {'input_x': tf.saved_model.utils.build_tensor_info(x),
            'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob)}

# y 爲最終須要的輸出結果tensor
outputs = {'output' : tf.saved_model.utils.build_tensor_info(y)}

signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')

builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature':signature})
builder.save()

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12

上述inputs增長一個keep_prob是爲了說明inputs能夠有多個, build_tensor_info方法將tensor相關的信息序列化爲TensorInfo protocol buffer。

inputs,outputs都是dict,key是咱們約定的輸入輸出別名,value就是對具體tensor包裝獲得的TensorInfo。

而後使用build_signature_def方法構建SignatureDef,第三個參數method_name暫時先隨便給一個。

建立好的SignatureDef是用在add_meta_graph_and_variables的第三個參數signature_def_map中,但不是直接傳入SignatureDef對象。事實上signature_def_map接收的是一個dict,key是咱們本身命名的signature名稱,value是SignatureDef對象。
載入

載入與使用的代碼以下


## 略去構建sess的代碼

signature_key = 'test_signature'
input_key = 'input_x'
output_key = 'output'

meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], saved_model_dir)
# 從meta_graph_def中取出SignatureDef對象
signature = meta_graph_def.signature_def

# 從signature中找出具體輸入輸出的tensor name
x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

# 獲取tensor 並inference
x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

# _x 實際輸入待inference的data
sess.run(y, feed_dict={x:_x})

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21

從上面兩段代碼能夠知道,咱們只須要約定好輸入輸出的別名,在保存模型的時候使用這些別名建立signature,輸入輸出tensor的具體名稱已經徹底隱藏,這就實現建立模型與使用模型的解耦。
---------------------
做者:thriving_fcl
來源:CSDN
原文:https://blog.csdn.net/thriving_fcl/article/details/75213361?utm_source=copy
版權聲明:本文爲博主原創文章,轉載請附上博文連接!github

相關文章
相關標籤/搜索