import tensorflow as tf def store_model_ckpt(ckpt_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') #模型的保存必須有變量 c = tf.Variable(1, name='c') a = tf.add(x, y, name='op') result = tf.add(a, c) with tf.Session() as sess: init = tf.global_variables_initializer() sess.run(init) saver = tf.train.Saver() #若是隻保存其中一部分變量,則使用下面代碼,用列表或者字典均可以 #saver = tf.train.Saver([x, y]) #這裏面有參數global_step=50,當訓練50步便保存模型 saver.save(sess, ckpt_file_path) # test feed_dict = {x: 2, y: 3} print(sess.run(result, feed_dict)) def main(): ckpt_file_path = "./ckpt/model.ckpt" store_model_ckpt(ckpt_file_path) if __name__ == '__main__': main()
結果:6node
程序生成並保存四個文件python
針對上面的模型保存例子,還原模型的過程以下:git
import tensorflow as tf def restore_model_ckpt(): with tf.Session() as sess: #step1:加載模型結構 saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') #step2:只須要指定目錄就能夠恢復全部變量信息 saver.restore(sess,tf.train.latest_checkpoint('./ckpt')) #直接獲取保存的變量 print(sess.run('c:0')) #獲取placeholder變量,經過get_tensor_by_name x = sess.graph.get_tensor_by_name('x:0') y = sess.graph.get_tensor_by_name('y:0') #獲取須要進行計算的op算子,此op爲加法 op = sess.graph.get_tensor_by_name('op:0') #加入新的op操做,新的op爲乘法 new_op = tf.multiply(op, 2) #test feed_dict = {x:2, y:3} result = sess.run(new_op,feed_dict) print(result) def main(): restore_model_ckpt() if __name__ == '__main__': main()
結果:10瀏覽器
1. 首先還原模型結構網絡
2. 而後還原變量(參數)信息架構
3. 最後咱們就能夠得到已訓練的模型中的各類信息了(保存的變量、placeholder變量、operator等),同時能夠對獲取的變量添加各類新的操做(見以上代碼註釋)。
而且,咱們也能夠加載部分模型,在此基礎上加入其它操做,具體能夠參考官方文檔和demo。dom
針對ckpt模型文件的保存與還原,stackoverflow上有一個回答解釋比較清晰,能夠參考。函數
同時cv-tricks.com上面的TensorFlow模型保存與恢復的教程也很是好,能夠參考。源碼分析
import tensorflow as tf from tensorflow.python.framework import graph_util def store_model_pb(pb_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') a = tf.add(x, y) #該op算子應該加上name op = tf.add(a, b, name='op') with tf.Session() as sess: init = tf.initialize_all_variables() sess.run(init) #導出當前計算圖的GraphDef部分,只須要這一部分就能夠完成從輸入層到輸出層的計算 graph_def = tf.get_default_graph().as_graph_def() #將圖中的變量及其取值轉化爲常量,同時將圖中的沒必要要的節點去掉 output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['op']) with tf.gfile.FastGFile(pb_file_path, mode='wb') as f: f.write(output_graph_def.SerializeToString()) #test feed_dict = {x: 2, y: 3} print(sess.run(op, feed_dict)) def main(): pb_file_path = "model.pb" store_model_pb(pb_file_path) if __name__ == '__main__': main()
結果:6 測試
在當前文件下面生成model.pb文件
import tensorflow as tf from tensorflow.python.platform import gfile def restore_model_pb(pb_file_path): with tf.Session() as sess: with gfile.FastGFile(pb_file_path, 'rb') as f: graph_def = tf.GraphDef() #轉換成字符串形式 graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') #獲取placeholder的變量 x = sess.graph.get_tensor_by_name('x:0') y = sess.graph.get_tensor_by_name('y:0') #獲取op算子 op = sess.graph.get_tensor_by_name('op:0') feed_dict = {x: 2, y:3} result = sess.run(op,feed_dict) print(result) def main(): pb_file_path = "model.pb" restore_model_pb(pb_file_path) if __name__ == '__main__': main()
結果:5
但不少時候,咱們須要將TensorFlow的模型導出爲單個文件(同時包含模型結構的定義與權重),方便在其餘地方使用(如在Android中部署網絡)。利用tf.train.write_graph()默認狀況下只導出了網絡的定義(沒有權重),而利用tf.train.Saver().save()導出的文件graph_def與權重是分離的,所以須要採用別的方法。 咱們知道,graph_def文件中沒有包含網絡中的Variable值(一般狀況存儲了權重),可是卻包含了constant值,因此若是咱們能把Variable轉換爲constant,便可達到使用一個文件同時存儲網絡架構與權重的目標。
TensoFlow爲咱們提供了convert_variables_to_constants()方法,該方法能夠固化模型結構,將計算圖中的變量取值以常量的形式保存,並且保存的模型能夠移植到Android平臺。
將CKPT 轉換成 PB格式的文件的過程可簡述以下:
1. 經過傳入 CKPT 模型的路徑獲得模型的圖和變量數據
2. 經過 import_meta_graph 導入模型中的圖
3. 經過 saver.restore 從模型中恢復圖中各個變量的數據
4. 經過 graph_util.convert_variables_to_constants 將模型持久化
Code:freeze_graph.py
import tensorflow as tf from tensorflow.python.framework import graph_util def freeze_graph(ckpt_file_path, pb_file_path): #「input:0」是張量的名稱,而"input"表示的是節點的名稱。 #此處輸入的應該是節點的名稱 output_node_names = "op" #首先恢復圖結構 saver = tf.train.import_meta_graph(ckpt_file_path+'.meta',clear_devices=True) graph = tf.get_default_graph() input_graph_def = graph.as_graph_def() with tf.Session() as sess: #恢復圖並獲得數據 saver.restore(sess,ckpt_file_path) output_graph_def = graph_util.convert_variables_to_constants( sess=sess, input_graph_def=input_graph_def, #若是有多個輸出節點 output_node_names=output_node_names.split(",")) with tf.gfile.GFile(pb_file_path,"wb") as f: f.write(output_graph_def.SerializeToString()) print("%d ops in the final graph." % len(output_graph_def.node)) def main(): # 輸入ckpt模型路徑 model_folder = "D:\AI\Ckpt\TestCkpt\ckpt" #檢查目錄下ckpt文件狀態是否可用 checkpoint = tf.train.get_checkpoint_state(model_folder) #得ckpt文件路徑 ckpt_file_path = checkpoint.model_checkpoint_path # 輸出pb模型的路徑 pb_file_path="frozen_model.pb" # 調用freeze_graph將ckpt轉爲pb freeze_graph(ckpt_file_path,pb_file_path) if __name__ == '__main__': main()
結果:生成 frozen_model.pb文件,能夠採用上面pb模型加載的方法測試該pb文件
說明:
一、函數freeze_graph中,最重要的就是要肯定「指定輸出的節點名稱」,這個節點名稱必須是原模型中存在的節點,對於freeze操做,咱們須要定義輸出結點的名字。由於網絡實際上是比較複雜的,定義了輸出結點的名字,那麼freeze的時候就只把輸出該結點所須要的子圖都固化下來,其餘無關的就捨棄掉。由於咱們freeze模型的目的是接下來作預測。因此,output_node_names通常是網絡模型最後一層輸出的節點名稱,或者說就是咱們預測的目標。
二、在保存的時候,經過convert_variables_to_constants函數來指定須要固化的節點名稱,對於鄙人的代碼,須要固化的節點只有一個:output_node_names。注意節點名稱與張量的名稱的區別,例如:「input:0」是張量的名稱,而"input"表示的是節點的名稱。
三、源碼中經過graph = tf.get_default_graph()得到默認的圖,這個圖就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢復的圖,所以必須先執行tf.train.import_meta_graph,再執行tf.get_default_graph() 。
四、上面以及說明:在保存的時候,經過convert_variables_to_constants函數來指定須要固化的節點名稱,對於鄙人的代碼,須要固化的節點只有一個:output_node_names。所以,其餘網絡模型,也能夠經過簡單的修改輸出的節點名稱output_node_names,將ckpt轉爲pb文件 。
PS:注意節點名稱,應包含name_scope 和 variable_scope命名空間,並用「/」隔開,如"InceptionV3/Logits/SpatialSqueeze"
# -*- coding: utf-8 -*- """ Created on Sat Dec 22 09:49:04 2018 @author: weilong """ import tensorflow as tf #定義簡單的計算圖,實現向量加法的操做 with tf.name_scope("imput1"): input1 = tf.constant([1.0, 2.0, 3.0], name="input1") with tf.name_scope("input2"): input2 = tf.Variable(tf.random_uniform([3]), name="input2") output = tf.add_n([input1, input2], name="add") #生成寫日誌的writer,並將當前的tensorflow計算圖寫入日誌 writer = tf.summary.FileWriter("./log", tf.get_default_graph()) writer.close()
import tensorflow as tf model = 'model.pb' #請將這裏的pb文件路徑改成本身的 graph = tf.get_default_graph() graph_def = graph.as_graph_def() graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read()) tf.import_graph_def(graph_def, name='graph') summaryWriter = tf.summary.FileWriter('log/', graph)
執行以上代碼就會生成文件在log/events.out.tfevents.1535079670.DESKTOP-5IRM000。
在tensorboard中加載:
tensorboard --logdir=\path\to\log
在瀏覽器中
拷貝網站連接在瀏覽器中便可。
參考:https://blog.csdn.net/guyuealian/article/details/82218092