【TensorFlow系列】【三】凍結模型文件並作inference

本文基於mnist與lenet,講述以下兩個問題:node

1.如何將訓練好的網絡模型凍結,造成net.pb文件?python

2.如何將net.pb文件部署到TensorFlow中作inference?git

pb文件保存的步驟
1.須要給input與最終的預測值取個名字,便於部署時輸入數據並輸出數據
2.利用graph_util.convert_variables_to_constants將網絡中模型參數變量轉換爲常量
3.利用tf.gfile.FastGFile將模型參數序列化後的數據寫入文件。網絡

pb文件部署步驟:
1.利用tf.gfile.FastGFile讀取pb文件,並將文件中存儲的graph導入到TensorFlow中。
2.從graph中獲取input與output變量,傳入圖片數據,作inferencesession

 

【基於mnist與lenet,保存pb文件】dom

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util

mnist = input_data.read_data_sets(train_dir=r"E:\mnist_data",one_hot=True)


#定義輸入數據mnist圖片大小28*28*1=784,None表示batch_size
x = tf.placeholder(dtype=tf.float32,shape=[None,28*28],name="input")
#定義標籤數據,mnist共10類
y_ = tf.placeholder(dtype=tf.float32,shape=[None,10],name="y_")
#將數據調整爲二維數據,w*H*c---> 28*28*1,-1表示N張
image = tf.reshape(x,shape=[-1,28,28,1])

#第一層,卷積核={5*5*1*32},池化核={2*2*1,1*2*2*1}
w1 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,1,32],stddev=0.1,dtype=tf.float32,name="w1"))
b1= tf.Variable(initial_value=tf.zeros(shape=[32]))
conv1 = tf.nn.conv2d(input=image,filter=w1,strides=[1,1,1,1],padding="SAME",name="conv1")
relu1 = tf.nn.relu(tf.nn.bias_add(conv1,b1),name="relu1")
pool1 = tf.nn.max_pool(value=relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
#shape={None,14,14,32}
#第二層,卷積核={5*5*32*64},池化核={2*2*1,1*2*2*1}
w2 = tf.Variable(initial_value=tf.random_normal(shape=[5,5,32,64],stddev=0.1,dtype=tf.float32,name="w2"))
b2 = tf.Variable(initial_value=tf.zeros(shape=[64]))
conv2 = tf.nn.conv2d(input=pool1,filter=w2,strides=[1,1,1,1],padding="SAME")
relu2 = tf.nn.relu(tf.nn.bias_add(conv2,b2),name="relu2")
pool2 = tf.nn.max_pool(value=relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME",name="pool2")
#shape={None,7,7,64}
#FC1
w3 = tf.Variable(initial_value=tf.random_normal(shape=[7*7*64,1024],stddev=0.1,dtype=tf.float32,name="w3"))
b3 = tf.Variable(initial_value=tf.zeros(shape=[1024]))
#關鍵,進行reshape
input3 = tf.reshape(pool2,shape=[-1,7*7*64],name="input3")
fc1 = tf.nn.relu(tf.nn.bias_add(value=tf.matmul(input3,w3),bias=b3),name="fc1")
#shape={None,1024}
#FC2
w4 = tf.Variable(initial_value=tf.random_normal(shape=[1024,10],stddev=0.1,dtype=tf.float32,name="w4"))
b4 = tf.Variable(initial_value=tf.zeros(shape=[10]))
fc2 = tf.nn.bias_add(value=tf.matmul(fc1,w4),bias=b4)
#shape={None,10}
#定義交叉熵損失
# 使用softmax將NN計算輸出值表示爲機率
y = tf.nn.softmax(fc2,name="out")

# 定義交叉熵損失函數
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=fc2,labels=y_)
loss = tf.reduce_mean(cross_entropy)
#定義solver
train = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss=loss)

#定義正確值,判斷兩者下標index是否相等
correct_predict = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#定義如何計算準確率
accuracy = tf.reduce_mean(tf.cast(correct_predict,dtype=tf.float32),name="accuracy")
#定義初始化op
init = tf.global_variables_initializer()

#訓練NN
with tf.Session() as session:
    session.run(fetches=init)
    for i in range(0,1000):
        xs, ys = mnist.train.next_batch(100)
        session.run(fetches=train,feed_dict={x:xs,y_:ys})
        if i%100 == 0:
            train_accuracy = session.run(fetches=accuracy,feed_dict={x:xs,y_:ys})
            print(i,"accuracy=",train_accuracy)
    #訓練完成後,將網絡中的權值轉化爲常量,造成常量graph
    constant_graph = graph_util.convert_variables_to_constants(sess=session,
                                                            input_graph_def=session.graph_def,
                                                            output_node_names=['out'])
    #將帶權值的graph序列化,寫成pb文件存儲起來
    with tf.gfile.FastGFile("lenet.pb", mode='wb') as f:
        f.write(constant_graph.SerializeToString())

【將pb文件部署到TensorFlow中並作inference】ide

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

mnist = input_data.read_data_sets(train_dir=r"E:\mnist_data",one_hot=True)
pb_path = r"lenet.pb"
#導入pb文件到graph中
with tf.gfile.FastGFile(pb_path,'rb') as f:
    # 複製定義好的計算圖到新的圖中,先建立一個空的圖.
    graph_def = tf.GraphDef()
    # 加載proto-buf中的模型
    graph_def.ParseFromString(f.read())
    # 最後複製pre-def圖的到默認圖中.
    _ = tf.import_graph_def(graph_def, name='')
with tf.Session() as session:
    #獲取輸入tensor
    input = tf.get_default_graph().get_tensor_by_name("input:0")
    #獲取預測tensor
    output = tf.get_default_graph().get_tensor_by_name("out:0")
    #取第100張圖片測試
    one_image = np.reshape(mnist.test.images[100], [-1, 784])
    #將測試圖片傳入nn中,作inference
    out = session.run(output,feed_dict={input:one_image})
    pre_label = np.argmax(out,1)
    print("pre_label=",pre_label)
    print('true label:', np.argmax(mnist.test.labels[100],0))

測試結果以下圖:函數

相關文章
相關標籤/搜索