如何查看tensorflow SavedModel格式模型的信息

在《Tensorflow SavedModel模型的保存與加載》一文中,咱們談到SavedModel格式的優勢是與語言無關、容易部署和加載。那問題來了,若是別人發佈了一個SavedModel模型,咱們該如何去了解這個模型,如何去加載和使用這個模型呢?python

理想的狀態是模型發佈者編寫出完備的文檔,給出示例代碼。但在不少狀況下,咱們只是獲得了訓練好的模型,而沒有齊全的文檔,這個時候咱們可否從模型自己上得到一些信息呢?好比模型的輸入輸出、模型的結構等等。git

答案是能夠的。github

查看模型的Signature簽名

這裏的簽名,並不是是爲了保證模型不被修改的那種電子簽名。個人理解是相似於編程語言中模塊的輸入輸出信息,好比函數名,輸入參數類型,輸出參數類型等等。咱們以《Tensorflow SavedModel模型的保存與加載》裏的代碼爲例,從語句:web

signature = predict_signature_def(inputs={'myInput': x},
                                  outputs={'myOutput': y})
複製代碼

咱們能夠看到模型的輸入名爲myInput,輸出名爲myOutput。若是咱們沒有源碼呢?編程

Tensorflow提供了一個工具,若是你下載了Tensorflow的源碼,能夠找到這樣一個文件,./tensorflow/python/tools/saved_model_cli.py,你能夠加上-h參數查看該腳本的幫助信息:瀏覽器

usage: saved_model_cli.py [-h] [-v] {show,run,scan} ...

saved_model_cli: Command-line interface for SavedModel

optional arguments:
  -h, --help       show this help message and exit
  -v, --version    show program's version number and exit commands: valid commands {show,run,scan} additional help 複製代碼

指定SavedModel模所在的位置,咱們就能夠顯示SavedModel的模型信息:bash

python $TENSORFLOW_DIR/tensorflow/python/tools/saved_model_cli.py show --dir ./model/ --all
複製代碼

結果爲:編程語言

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['predict']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['myInput'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 784)
        name: myInput:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['myOutput'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 10)
        name: Softmax:0
  Method name is: tensorflow/serving/predict
複製代碼

從這裏咱們能夠清楚的看到模型的輸入/輸出的名稱、數據類型、shape以及方法名稱。有了這些信息,咱們就能夠很容易寫出推斷方法。函數

查看模型的計算圖

瞭解tensflow的人可能知道TensorBoard是一個很是強大的工具,可以顯示不少模型信息,其中包括計算圖。問題是,TensorBoard須要模型訓練時的log,若是這個SavedModel模型是別人訓練好的呢?辦法也不是沒有,咱們能夠寫一段代碼,加載這個模型,而後輸出summary info,代碼以下:工具

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

with tf.Session() as sess:
  model_filename ='./model/saved_model.pb'
  with gfile.FastGFile(model_filename, 'rb') as f:

    data = compat.as_bytes(f.read())
    sm = saved_model_pb2.SavedModel()
    sm.ParseFromString(data)

    if 1 != len(sm.meta_graphs):
      print('More than one graph found. Not sure which to write')
      sys.exit(1)

    g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
LOGDIR='./logdir'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
train_writer.flush()
train_writer.close()
複製代碼

代碼中,將彙總信息輸出到logdir,接着啓動TensorBoard,加上上面的logdir:

tensorboard --logdir ./logdir
複製代碼

在瀏覽器中輸入地址: http://127.0.0.1:6006/ ,就能夠看到以下的計算圖:

小結

按照前面兩種方法,咱們能夠對Tensorflow SavedModel格式的模型有比較全面的瞭解,即便模型訓練者並無給出文檔。有了這些模型信息,相信你寫出使用模型進行推斷更加容易。

本文的完整代碼請參考:github.com/mogoweb/aie…

但願這篇文章對您有幫助,感謝閱讀!

image
相關文章
相關標籤/搜索