標籤(空格分隔): TensorFlowgit
tensorflow模型保存函數爲:dom
tf.train.Saver()
固然,除了上面最簡單的保存方式,也能夠指定保存的步數,多長時間保存一次,磁盤上最多保有幾個模型(將前面的刪除以保持固定個數),以下:函數
建立saver時指定參數:測試
saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)
其中:rest
保存模型時指定參數:code
saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)
如上,其中能夠指定模型文件名,步數,write_meta_graph則用來指定是否保存meta文件記錄graph等等。orm
示例:ci
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") v3= tf.Variable(tf.zeros([100]), name="v3") saver = tf.train.Saver() with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) saver.save(sess,"checkpoint/model.ckpt",global_step=1)
運行後,保存模型保存,獲得四個文件:get
checkpoint中記錄了已存儲(部分)和最近存儲的模型:input
model_checkpoint_path: "model.ckpt-1" all_model_checkpoint_paths: "model.ckpt-1" ...
meta file保存了graph結構,包括 GraphDef,SaverDef等,當存在meta file,咱們能夠不在文件中定義模型,也能夠運行,而若是沒有meta file,咱們須要定義好模型,再加載data file,獲得變量值。
index file爲一個string-string table,table的key值爲tensor名,value爲serialized BundleEntryProto。每一個BundleEntryProto表述了tensor的metadata,好比那個data文件包含tensor、文件中的偏移量、一些輔助數據等。
data file保存了模型的全部變量的值,TensorBundle集合。
Restore模型的過程能夠分爲兩個部分,首先是建立模型,能夠手動建立,也能夠從meta文件里加載graph進行建立。
模型加載爲:
with tf.Session() as sess: saver = tf.train.import_meta_graph('/xx/model.ckpt.meta') saver.restore(sess, "/xx/model.ckpt")
.meta文件中保存了圖的結構信息,所以須要在導入checkpoint以前導入它。不然,程序不知道checkpoint中的變量對應的變量。另外也能夠:
# Recreate the EXACT SAME variables v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Now load the checkpoint variable values with tf.Session() as sess: saver = tf.train.Saver() saver.restore(sess, "/xx/model.ckpt") #saver.restore(sess, tf.train.latest_checkpoint('./'))
PS:不存在model.ckpt文件,saver.py中:Users only need to interact with the user-specified prefix... instead of any physical pathname.
固然,還有一點須要注意,並不是全部的TensorFlow模型都能將graph輸出到meta文件中或者從meta文件中加載進來,若是模型有部分不能序列化的部分,則此種方法可能會無效。
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) tvs = [v for v in tf.trainable_variables()] for v in tvs: print(v.name) print(sess.run(v))
如名所言,以上是查看模型中的trainable variables;或者咱們也能夠查看模型中的全部tensor或者operations,以下:
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) gv = [v for v in tf.global_variables()] for v in gv: print(v.name)
上面經過global_variables()得到的與前trainable_variables相似,只是多了一些非trainable的變量,好比定義時指定爲trainable=False的變量,或Optimizer相關的變量。
下面則能夠得到幾乎全部的operations相關的tensor:
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) ops = [o for o in sess.graph.get_operations()] for o in ops: print(o.name)
首先,上面的sess.graph.get_operations()能夠換爲tf.get_default_graph().get_operations(),兩者區別無非是graph明確的時候能夠直接使用前者,不然須要使用後者。
此種方法得到的tensor比較齊全,能夠從中一窺模型全貌。不過,最方便的方法仍是推薦使用tensorboard來查看,固然這須要你提早將sess.graph輸出。
這種操做比較簡單,無非是找到原始模型的輸入、輸出便可。
只要搞清楚輸入輸出的tensor名字,便可直接使用TensorFlow中graph的get_tensor_by_name函數,創建輸入輸出的tensor:
with tf.get_default_graph() as graph: data = graph.get_tensor_by_name('data:0') output = graph.get_tensor_by_name('output:0')
從模型中找到了輸入輸出以後,便可直接使用其繼續train整個模型,或者將輸入數據feed到模型裏,並前傳獲得test輸出了。
須要說明的是,有時候從一個graph裏找到輸入和輸出tensor的名字並不容易,因此,在定義graph時,最好能給相應的tensor取上一個明顯的名字,好比:
data = tf.placeholder(tf.float32, shape=shape, name='input_data') preds = tf.nn.softmax(logits, name='output')
諸如此類。這樣,就能夠直接使用tf.get_tensor_by_name(‘input_data:0’)之類的來找到輸入輸出了。
除了直接使用原始模型,還能夠在原始模型上進行擴展,好比對1中的output繼續進行處理,添加新的操做,能夠完成對原始模型的擴展,如:
with tf.get_default_graph() as graph: data = graph.get_tensor_by_name('data:0') output = graph.get_tensor_by_name('output:0') logits = tf.nn.softmax(output)
有時候,咱們有對某模型的一部分進行fine-tune的需求,好比使用一個VGG的前面提取特徵的部分,而微調其全連層,或者將其全連層更換爲使用convolution來完成,等等。TensorFlow也提供了這種支持,可使用TensorFlow的stop_gradient函數,將模型的一部分進行凍結。
with tf.get_default_graph() as graph: graph.get_tensor_by_name('fc1:0') fc1 = tf.stop_gradient(fc1) # add new procedure on fc1