這是Tensorflow SavedModel模型系列文章的第三篇,也是終章。在《Tensorflow SavedModel模型的保存與加載》中,咱們談到了Tensorflow模型如何保存爲SavedModel格式,以及如何加載之。在《如何查看tensorflow SavedModel格式模型的信息》中,咱們演示瞭如何查看模型的signature和計算圖結構。在本文中,咱們將探討如何合併兩個模型,簡單的說,就是將第一個模型的輸出,做爲第二個模型的輸入,串聯起來造成一個新模型。python
爲何須要合併兩個模型?git
咱們仍是以《Tensorflow SavedModel模型的保存與加載》中的代碼爲例,這個手寫數字識別模型接收的輸入是shape爲[?, 784],這裏?表明能夠批量接收輸入,能夠先忽略,就把它固定爲1吧。784是28 x 28進行展開的結果,也就是28 x 28灰度圖像展開的結果。github
問題是,咱們送給模型的一般是圖片,可能來自文件、可能來自攝像頭。讓問題變得複雜的是,若是咱們經過HTTP來調用部署到服務器端的模型,二進制數據其實是不方便HTTP傳輸的,這時咱們一般須要對圖像數據進行base64編碼。這樣服務器端接收到的數據是一個base64字符串,可模型接受的是二進制向量。web
很天然的,咱們能夠想到兩種解決方法:小程序
從新訓練模型一個接收base64字符串的模型。微信小程序
這種解決方法的問題在於:從新訓練模型很費時,甚至不可行。本文示例由於比較簡單,從新訓練也沒啥。若是是那種很深的卷積神經網絡,訓練一次可能須要好幾天,從新訓練代價很大。更廣泛的狀況是,咱們使用的是別人訓練好的模型,好比圖像識別中廣泛使用的Mobilenet、InceptionV3等等,都是Google、微軟這樣的公司,耗費大量的資源訓練出來的,咱們沒有那個條件從新訓練。數組
在服務器端增長base64到二進制數據的轉換bash
這種解決方法實現起來不復雜,但若是咱們使用的是Tensorflow model server之類的方案部署的呢?固然咱們也能夠再開啓一個server,來接受客戶端的base64圖像數據,處理完畢以後再轉發給Tensorflow model server,但這無疑增長了服務端的工做量,增長了服務端的複雜性。服務器
在本文,咱們將給出第三種方案:編寫一個Tensorflow模型,接收base64的圖像數據,輸出二進制向量,而後將第一個模型的輸出做爲第二個模型的輸入,串接起來,保存爲一個新的模型,最後部署新的模型。微信
Tensorflow包含了大量圖像處理和數組處理的方法,因此實現這個模型比較簡單,模型包含了base64解碼、解碼PNG圖像、縮放到28 * 2八、最後展開爲(1, 784)的數組輸出,符合手寫數字識別模型的輸入,代碼以下:
with tf.Graph().as_default() as g1:
base64_str = tf.placeholder(tf.string, name='input_string')
input_str = tf.decode_base64(base64_str)
decoded_image = tf.image.decode_png(input_str, channels=1)
# Convert from full range of uint8 to range [0,1] of float32.
decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,
tf.float32)
decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
resize_shape = tf.stack([28, 28])
resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
resized_image = tf.image.resize_bilinear(decoded_image_4d,
resize_shape_as_int)
# 展開爲1維數組
resized_image_1d = tf.reshape(resized_image, (-1, 28 * 28))
print(resized_image_1d.shape)
tf.identity(resized_image_1d, name="DecodeJPGOutput")
g1def = g1.as_graph_def()
複製代碼
在該模型中,並不存在變量,都是一些固定的操做,因此無需進行訓練。
手寫識別模型參考《Tensorflow SavedModel模型的保存與加載》一文,模型保存在 "./model" 下,加載代碼以下:
with tf.Graph().as_default() as g2:
with tf.Session(graph=g2) as sess:
input_graph_def = saved_model_utils.get_meta_graph_def(
"./model", tag_constants.SERVING).graph_def
tf.saved_model.loader.load(sess, ["serve"], "./model")
g2def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
["myOutput"],
variable_names_whitelist=None,
variable_names_blacklist=None)
複製代碼
這裏使用了g2定義了另一個graph,和前面的模型的graph區分開來。注意這裏調用了graph_util.convert_variables_to_constants將模型中的變量轉化爲常量,也就是所謂的凍結圖(freeze graph)操做。
在研究如何鏈接兩個模型時,我在這個問題上卡了好久。先的想法是合併模型以後,再加載變量值進來,可是嘗試以後,怎麼也不成功。後來的想法是遍歷手寫識別模型的變量,獲取其變量值,將變量值複製到合併的模型的變量,但這樣操做,使用模型時,老是提示有變量未初始化。
最後從Tensorflow模型到Tensorflow lite模型轉換中得到了靈感,將模型中的變量固定下來,這樣就不存在變量的加載問題,也不會出現模型變量未初始化的問題。
執行convert_variables_to_constants後,能夠看到有兩個變量轉化爲了常量操做,也就是手寫數字識別模型中的w和b:
Converted 2 variables to const ops.
複製代碼
利用tf.import_graph_def方法,咱們能夠導入圖到現有圖中,注意第二個import_graph_def,其input是第一個graph_def的輸出,經過這樣的操做,就將兩個計算圖鏈接起來,最後保存起來。代碼以下:
with tf.Graph().as_default() as g_combined:
with tf.Session(graph=g_combined) as sess:
x = tf.placeholder(tf.string, name="base64_input")
y, = tf.import_graph_def(g1def, input_map={"input_string:0": x}, return_elements=["DecodeJPGOutput:0"])
z, = tf.import_graph_def(g2def, input_map={"myInput:0": y}, return_elements=["myOutput:0"])
tf.identity(z, "myOutput")
tf.saved_model.simple_save(sess,
"./modelbase64",
inputs={"base64_input": x},
outputs={"myOutput": z})
複製代碼
由於第一個模型不包含變量,第二個模型的變量轉化爲了常量操做,因此最後保存的模型文件並不包含變量:
modelbase64/
├── saved_model.pb
└── variables
1 directory, 1 file
複製代碼
咱們寫一段測試代碼,測試一下合併以後模型是否管用,代碼以下:
with tf.Session(graph=tf.Graph()) as sess:
sess.run(tf.global_variables_initializer())
tf.saved_model.loader.load(sess, ["serve"], "./modelbase64")
graph = tf.get_default_graph()
with open("./5.png", "rb") as image_file:
encoded_string = str(base64.urlsafe_b64encode(image_file.read()), "utf-8")
x = sess.graph.get_tensor_by_name('base64_input:0')
y = sess.graph.get_tensor_by_name('myOutput:0')
scores = sess.run(y,
feed_dict={x: encoded_string})
print("predict: %d, actual: %d" % (np.argmax(scores, 1), 5))
複製代碼
這裏模型的輸入爲base64_input,輸出仍然是myOutput,使用兩個圖片測試,均工做正常。
最近三篇文章其實都是在研究個人微信小程序時總結的,爲了更好的說明問題,我使用了一個很是簡單的模型來講明問題,但一樣適用於複雜的模型。
本文的完整代碼請參考:github.com/mogoweb/aie…
但願這篇文章對您有幫助,感謝閱讀!同時敬請關注個人微信公衆號:雲水木石。