這是當微信小程序趕上TensorFlow系列文章的第四篇文章,閱讀本文,你將瞭解到:node
若是你想要了解更多關於本項目,能夠參考這個系列的前三篇文章:python
關於Tensorflow SavedModel格式模型的處理,能夠參考前面的文章:git
截至到目前爲止,咱們實現了一個簡單的微信小程序,使用開源的Simple TensorFlow Serving部署了服務端。但這種實現方案還存在一個重大問題:小程序和服務端通訊傳遞的圖像數據是(299, 299, 3)二進制數組的JSON化表示,這種二進制數據JSON化的最大缺點是數據量太大,一個簡單的299 x 299的圖像,這樣表示大約有3 ~ 4 M。其實HTTP傳輸二進制數據經常使用的方案是對二進制數據進行base64編碼,通過base64編碼,雖然數據量比二進制也會大一些,但相比JSON化的表示,仍是小不少。github
因此如今的問題是,如何讓服務器端接收base64編碼的圖像數據?web
爲了解決這一問題,咱們仍是先看看模型的輸入輸出,看看其簽名是怎樣的?這裏的簽名,並不是是爲了保證模型不被修改的那種電子簽名。個人理解是相似於編程語言中模塊的輸入輸出信息,好比函數名,輸入參數類型,輸出參數類型等等。藉助於Tensorflow提供的saved_model_cli.py工具,咱們能夠清楚的查看模型的簽名:編程
python ./tensorflow/python/tools/saved_model_cli.py show --dir /data/ai/workspace/aiexamples/AIDog/serving/models/inception_v3/ --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['image'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 299, 299, 3)
name: Placeholder:0
The given SavedModel SignatureDef contains the following output(s):
outputs['prediction'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 120)
name: final_result:0
Method name is: tensorflow/serving/predict
複製代碼
從中咱們能夠看出模型的輸入參數名爲image,其shape爲(-1, 299, 299, 3),這裏-1表明能夠批量輸入,一般咱們只輸入一張圖像,因此這個維度一般是1。輸出參數名爲prediction,其shape爲(-1, 120),-1和輸入是對應的,120表明120組狗類別的機率。json
如今的問題是,咱們可否在模型的輸入前面增長一層,進行base64及解碼處理呢?小程序
也許你認爲能夠在服務器端編寫一段代碼,進行base64字符串解碼,而後再轉交給Simple Tensorflow Serving進行處理,或者修改Simple TensorFlow Serving的處理邏輯,但這種修改方案增長了服務器端的工做量,使得服務器部署方案再也不通用,放棄!微信小程序
其實在上一篇文章《如何合併兩個TensorFlow模型》中咱們已經講到了如何鏈接兩個模型,這裏再稍微重複一下,首先是編寫一個base64解碼、png解碼、圖像縮放的模型:api
base64_str = tf.placeholder(tf.string, name='input_string')
input_str = tf.decode_base64(base64_str)
decoded_image = tf.image.decode_png(input_str, channels=input_depth)
# Convert from full range of uint8 to range [0,1] of float32.
decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,
tf.float32)
decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
resize_shape = tf.stack([input_height, input_width])
resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
resized_image = tf.image.resize_bilinear(decoded_image_4d,
resize_shape_as_int)
tf.identity(resized_image, name="DecodePNGOutput")
複製代碼
接下來加載retrain模型:
with tf.Graph().as_default() as g2:
with tf.Session(graph=g2) as sess:
input_graph_def = saved_model_utils.get_meta_graph_def(
FLAGS.origin_model_dir, tag_constants.SERVING).graph_def
tf.saved_model.loader.load(sess, [tag_constants.SERVING], FLAGS.origin_model_dir)
g2def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
["final_result"],
variable_names_whitelist=None,
variable_names_blacklist=None)
複製代碼
這裏調用了graph_util.convert_variables_to_constants將模型中的變量轉化爲常量,也就是所謂的凍結圖(freeze graph)操做。
利用tf.import_graph_def方法,咱們能夠導入圖到現有圖中,注意第二個import_graph_def,其input是第一個graph_def的輸出,經過這樣的操做,就將兩個計算圖鏈接起來,最後保存起來。代碼以下:
with tf.Graph().as_default() as g_combined:
with tf.Session(graph=g_combined) as sess:
x = tf.placeholder(tf.string, name="base64_string")
y, = tf.import_graph_def(g1def, input_map={"input_string:0": x}, return_elements=["DecodePNGOutput:0"])
z, = tf.import_graph_def(g2def, input_map={"Placeholder:0": y}, return_elements=["final_result:0"])
tf.identity(z, "myOutput")
tf.saved_model.simple_save(sess,
FLAGS.model_dir,
inputs={"image": x},
outputs={"prediction": z})
複製代碼
若是你不知道retrain出來的模型的input節點是啥(注意不能使用模型部署的signature信息)?可使用以下代碼遍歷graph的節點名稱:
for n in g2def.node:
print(n.name)
複製代碼
注意,咱們能夠將鏈接以後的模型保存在./models/inception_v3/2/目錄下,原來的./models/inception_v3/1/也不用刪除,這樣兩個版本的模型能夠同時提供服務,方便從V1模型平滑過渡到V2版本模型。
咱們修改一下原來的test_client.py代碼,增長一個model_version參數,這樣就能夠決定與哪一個版本的模型進行通訊:
with open(file_name, "rb") as image_file:
encoded_string = str(base64.urlsafe_b64encode(image_file.read()), "utf-8")
if enable_ssl :
endpoint = "https://127.0.0.1:8500"
else:
endpoint = "http://127.0.0.1:8500"
json_data = {"model_name": model_name,
"model_version": model_version,
"data": {"image": encoded_string}
}
result = requests.post(endpoint, json=json_data)
複製代碼
通過一個多星期的研究和反覆嘗試,終於解決了圖像數據的base64編碼通訊問題。難點在於雖然模型是編寫retrain腳本從新訓練的,但這段代碼不是那麼好懂,想要在retrain時增長輸入層也是嘗試失敗。最後從Tensorflow模型轉Tensorflow Lite模型時的freezing graph獲得靈感,將圖中的變量固化爲常量,才解決了合併模型變量加載的問題。雖然網上提供了一些恢復變量的方法,但實際用起來並無論用,多是Tensorflow發展太快,之前的一些方法已通過時了。
本文的完整代碼請參閱:github.com/mogoweb/aie…
點擊閱讀原文能夠直達在github上的項目。
到目前爲止,關鍵的問題已經都解決,接下來就須要繼續完善微信小程序的展示,以及如何提供識別率,敬請關注個人微信公衆號:雲水木石,獲取最新動態。