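TensorFlow's slim framework lets you define network architectures with code as concise as Keras (although Keras has since been integrated into tf.contrib as well). In addition, models/slim provides an image classification interface similar to the object detection API discussed earlier, which makes it easy to fine-tune a model on your own dataset.

The official documentation describes the whole pipeline in fair detail: data preparation, the model zoo of pretrained models, fine-tuning and freezing the model. It does not cover inference. Fortunately, every TensorFlow model is loaded the same way, so a slim model is called just like any other .pb model.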
According to the TF developers, saving and loading a model in TensorFlow generally follows these steps: Build your graph --> write your graph --> import from written graph --> run compute etc.
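As a minimal sketch of these four steps on a toy graph (not the slim model; the node names `input` and `output`, the helper functions and the temporary file are hypothetical), written with the `tf.compat.v1` API so it also runs under TensorFlow 2. The toy graph holds only constants, so no checkpoint or freezing is needed:

```python
import os
import tempfile

import tensorflow as tf


def build_and_write(pb_path):
    """Steps 1-2: build a small graph and write its GraphDef to a binary .pb file."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.float32, shape=[None], name="input")
        # only constants here, so the written graph needs no checkpoint
        tf.add(tf.multiply(x, 2.0), 1.0, name="output")
    with open(pb_path, "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())


def import_and_run(pb_path, values):
    """Steps 3-4: import the written graph into a fresh Graph and run it."""
    graph_def = tf.compat.v1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(graph_def, name="")
    with tf.compat.v1.Session(graph=graph) as sess:
        input_tensor = graph.get_tensor_by_name("input:0")
        output_tensor = graph.get_tensor_by_name("output:0")
        return sess.run(output_tensor, feed_dict={input_tensor: values})


pb_path = os.path.join(tempfile.mkdtemp(), "toy_graph.pb")
build_and_write(pb_path)
result = import_and_run(pb_path, [0.0, 1.0, 2.0])
print(result)  # [1. 3. 5.]
```

The real model follows the same pattern, with an extra freezing step in between because its weights live in a checkpoint rather than in the graph.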
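Below we use the inception-resnet-v2 network provided by slim as an example. First, build the graph, restore the checkpoint and write the graph definition to a .pb file: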
```python
import tensorflow as tf
import nets.inception_resnet_v2 as net  # from the models/research/slim directory

slim = tf.contrib.slim

# ckpt file obtained during model training or fine-tuning
checkpoint_path = "/your/path/to/inception_resnet_v2.ckpt"

# set up the session
sess = tf.Session()
arg_scope = net.inception_resnet_v2_arg_scope()

# placeholder for the model input; its name "input" is used again at inference time
input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3], name="input")
with slim.arg_scope(arg_scope):
    # is_training=False puts batch norm and dropout into inference mode
    logits, end_points = net.inception_resnet_v2(inputs=input_tensor, is_training=False)

# restore the trained weights
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)

# save the graph definition to a binary pb file (the with-block closes the file)
with tf.gfile.GFile('/your/path/to/model_graph.pb', 'wb') as f:
    f.write(sess.graph_def.SerializeToString())
```
Next, freeze the graph with the freeze_graph tool in tensorflow/python/tools, which merges the checkpoint weights into the graph as constants:
```shell
# build the freeze_graph tool from the TensorFlow source tree
$ bazel build tensorflow/python/tools:freeze_graph
# --input_graph is the model_graph.pb written above; the output node name
# is the one defined in the inception resnet v2 network
$ bazel-bin/tensorflow/python/tools/freeze_graph \
    --input_graph=/your/path/to/model_graph.pb \
    --input_checkpoint=/your/path/to/inception_resnet_v2.ckpt \
    --input_binary=true \
    --output_graph=/your/path/to/frozen_graph.pb \
    --output_node_names=InceptionResnetV2/Logits/Predictions
```
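The core of what freeze_graph does can be sketched in Python with `tf.compat.v1.graph_util.convert_variables_to_constants`, which replaces every variable feeding the requested output nodes with a constant holding its current value. In the toy graph below (hypothetical names `input`, `w`, `output`) the variable is initialized directly where the real workflow would call `saver.restore`; this illustrates the idea and does not reproduce the tool's full option set:

```python
import tensorflow as tf

# Toy graph with one variable, standing in for the restored network weights.
graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None], name="input")
    w = tf.Variable(3.0, name="w")
    tf.multiply(x, w, name="output")

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(w.initializer)  # the real workflow would call saver.restore(sess, ckpt)
    frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), ["output"])

# After freezing, the weights are stored as Const nodes, so no checkpoint is needed.
op_types = {node.op for node in frozen_graph_def.node}
print(sorted(op_types))

frozen = tf.Graph()
with frozen.as_default():
    tf.compat.v1.import_graph_def(frozen_graph_def, name="")
with tf.compat.v1.Session(graph=frozen) as sess:
    result = sess.run("output:0", feed_dict={"input:0": [1.0, 2.0]})
print(result)
```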
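To check the result, load the frozen graph and write it to a log directory for TensorBoard:

```python
import tensorflow as tf

LOG_DIR = '/tmp/graphdeflogdir'
model_filename = '/your/path/to/frozen_graph.pb'

with tf.Session() as sess:
    with tf.gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # import the frozen graph into the session's graph (import_graph_def returns None)
    tf.import_graph_def(graph_def, name='')
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    writer.close()
```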
Then run `tensorboard --logdir=/tmp/graphdeflogdir` and open the graph view to inspect the structure of the frozen network.
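When only the node names are needed (for example to confirm the value passed to `--output_node_names`), they can also be listed straight from a .pb file without TensorBoard. A small sketch, with a hypothetical helper `list_nodes` and a toy graph standing in for frozen_graph.pb:

```python
import os
import tempfile

import tensorflow as tf


def list_nodes(pb_path):
    """Return (name, op type) pairs for every node in a binary GraphDef file."""
    graph_def = tf.compat.v1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return [(node.name, node.op) for node in graph_def.node]


# Toy graph; for the real model pass '/your/path/to/frozen_graph.pb' and look
# for an entry such as InceptionResnetV2/Logits/Predictions.
g = tf.Graph()
with g.as_default():
    x = tf.compat.v1.placeholder(tf.float32, [None], name="input")
    tf.nn.softmax(x, name="Predictions")
pb_path = os.path.join(tempfile.mkdtemp(), "toy.pb")
with open(pb_path, "wb") as f:
    f.write(g.as_graph_def().SerializeToString())

nodes = list_nodes(pb_path)
for name, op in nodes:
    print(name, op)
```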
Finally, load the frozen graph and run inference on a test image:

```python
import cv2
import numpy as np
import tensorflow as tf


def preprocess_inception(image_np, central_fraction=0.875):
    """Central crop and scale pixel values to [-1, 1], as the inception preprocessing does."""
    image_height, image_width, _ = image_np.shape
    if central_fraction:
        bbox_start_h = int(image_height * (1 - central_fraction) / 2)
        bbox_end_h = image_height - bbox_start_h
        bbox_start_w = int(image_width * (1 - central_fraction) / 2)
        bbox_end_w = image_width - bbox_start_w
        image_np = image_np[bbox_start_h:bbox_end_h, bbox_start_w:bbox_end_w]
    # normalize from [0, 255] to [-1, 1]
    image_np = 2 * (image_np / 255.) - 1
    return image_np.astype(np.float32)


# OpenCV loads images as BGR, while the model expects RGB
image_np = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)
# preprocess the image as inception resnet v2 does
image_np = preprocess_inception(image_np)
# resize to the model input size
image_np = cv2.resize(image_np, (299, 299))
# expand dims to shape [1, 299, 299, 3]
image_np = np.expand_dims(image_np, 0)

# load the frozen model into its own graph
with tf.gfile.GFile('/your/path/to/frozen_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    # the placeholder must have been created with name="input" when exporting the graph
    input_tensor = sess.graph.get_tensor_by_name("input:0")
    output_tensor = sess.graph.get_tensor_by_name("InceptionResnetV2/Logits/Predictions:0")
    # Predictions is the softmax output, shape [1, num_classes]
    probabilities = sess.run(output_tensor, feed_dict={input_tensor: image_np})
    print("Prediction label index:", np.argmax(probabilities[0]))
    print("Top 3 prediction label indices:", np.argsort(probabilities[0])[::-1][:3])
```
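Reading the top predictions from the softmax output can be sketched without TensorFlow (hypothetical helper `top_k`, made-up probabilities). Note that slim's ImageNet checkpoints with 1001 classes reserve index 0 for a "background" class, so index i corresponds to entry i - 1 of the usual 1000-class ImageNet label list:

```python
import numpy as np


def top_k(probabilities, k=3):
    """Return the indices and scores of the k largest entries, highest first."""
    probabilities = np.asarray(probabilities)
    indices = np.argsort(probabilities)[::-1][:k]
    return indices, probabilities[indices]


# made-up 1001-way softmax output; index 0 would be the background class
probs = np.zeros(1001, dtype=np.float32)
probs[[1, 281, 282]] = [0.1, 0.6, 0.3]

indices, scores = top_k(probs, k=3)
print(indices.tolist())  # [281, 282, 1]
```

Sorting the full vector is fine for 1001 classes; `np.argpartition` would avoid the full sort for much larger outputs.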