TF的模型文件

TF的模型文件

標籤(空格分隔): TensorFlowgit


Saver

tensorflow模型保存函數爲:dom

tf.train.Saver()

固然,除了上面最簡單的保存方式,也能夠指定保存的步數,多長時間保存一次,磁盤上最多保有幾個模型(將前面的刪除以保持固定個數),以下:函數

建立saver時指定參數:測試

saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)

其中:rest

  • savable_variables指定待保存的變量,好比指定爲tf.global_variables()保存全部global變量;指定爲[v1, v2]保存v1和v2兩個變量;若是省略,則保存全部;
  • max_to_keep指定磁盤上最多保有幾個模型;
  • keep_checkpoint_every_n_hours指定多少小時保存一次。

保存模型時指定參數:code

saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)

如上,其中能夠指定模型文件名,步數,write_meta_graph則用來指定是否保存meta文件記錄graph等等。orm

示例:ci

import tensorflow as tf
​
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    saver.save(sess,"checkpoint/model.ckpt",global_step=1)

運行後,保存模型保存,獲得四個文件:get

  • checkpoint
  • model.ckpt-1.data-00000-of-00001
  • model.ckpt-1.index
  • model.ckpt-1.meta

checkpoint中記錄了已存儲(部分)和最近存儲的模型:input

model_checkpoint_path: "model.ckpt-1"
all_model_checkpoint_paths: "model.ckpt-1"
...

meta file保存了graph結構,包括 GraphDef,SaverDef等,當存在meta file,咱們能夠不在文件中定義模型,也能夠運行,而若是沒有meta file,咱們須要定義好模型,再加載data file,獲得變量值。

index file爲一個string-string table,table的key值爲tensor名,value爲serialized BundleEntryProto。每一個BundleEntryProto表述了tensor的metadata,好比那個data文件包含tensor、文件中的偏移量、一些輔助數據等。

data file保存了模型的全部變量的值,TensorBundle集合。

Restore

Restore模型的過程能夠分爲兩個部分,首先是建立模型,能夠手動建立,也能夠從meta文件里加載graph進行建立。

模型加載爲:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/xx/model.ckpt.meta')
    saver.restore(sess, "/xx/model.ckpt")

.meta文件中保存了圖的結構信息,所以須要在導入checkpoint以前導入它。不然,程序不知道checkpoint中的變量對應的變量。另外也能夠:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/xx/model.ckpt")
    #saver.restore(sess, tf.train.latest_checkpoint('./'))

PS:不存在model.ckpt文件,saver.py中:Users only need to interact with the user-specified prefix... instead of any physical pathname.

固然,還有一點須要注意,並不是全部的TensorFlow模型都能將graph輸出到meta文件中或者從meta文件中加載進來,若是模型有部分不能序列化的部分,則此種方法可能會無效。

使用Restore的模型

查看模型的參數

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
  saver.restore(sess, tf.train.latest_checkpoint('./'))
  tvs = [v for v in tf.trainable_variables()]
  for v in tvs:
    print(v.name)
    print(sess.run(v))

如名所言,以上是查看模型中的trainable variables;或者咱們也能夠查看模型中的全部tensor或者operations,以下:

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
  saver.restore(sess, tf.train.latest_checkpoint('./'))
  gv = [v for v in tf.global_variables()]
  for v in gv:
    print(v.name)

上面經過global_variables()得到的與前trainable_variables相似,只是多了一些非trainable的變量,好比定義時指定爲trainable=False的變量,或Optimizer相關的變量。

下面則能夠得到幾乎全部的operations相關的tensor:

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
  saver.restore(sess, tf.train.latest_checkpoint('./'))
  ops = [o for o in sess.graph.get_operations()]
  for o in ops:
    print(o.name)

首先,上面的sess.graph.get_operations()能夠換爲tf.get_default_graph().get_operations(),兩者區別無非是graph明確的時候能夠直接使用前者,不然須要使用後者。

此種方法得到的tensor比較齊全,能夠從中一窺模型全貌。不過,最方便的方法仍是推薦使用tensorboard來查看,固然這須要你提早將sess.graph輸出。

直接使用原始模型進行訓練或測試

這種操做比較簡單,無非是找到原始模型的輸入、輸出便可。
只要搞清楚輸入輸出的tensor名字,便可直接使用TensorFlow中graph的get_tensor_by_name函數,創建輸入輸出的tensor:

with tf.get_default_graph() as graph:
  data = graph.get_tensor_by_name('data:0')
  output = graph.get_tensor_by_name('output:0')

從模型中找到了輸入輸出以後,便可直接使用其繼續train整個模型,或者將輸入數據feed到模型裏,並前傳獲得test輸出了。

須要說明的是,有時候從一個graph裏找到輸入和輸出tensor的名字並不容易,因此,在定義graph時,最好能給相應的tensor取上一個明顯的名字,好比:

data = tf.placeholder(tf.float32, shape=shape, name='input_data')
preds = tf.nn.softmax(logits, name='output')

諸如此類。這樣,就能夠直接使用tf.get_tensor_by_name(‘input_data:0’)之類的來找到輸入輸出了。

擴展原始模型

除了直接使用原始模型,還能夠在原始模型上進行擴展,好比對1中的output繼續進行處理,添加新的操做,能夠完成對原始模型的擴展,如:

with tf.get_default_graph() as graph:
  data = graph.get_tensor_by_name('data:0')
  output = graph.get_tensor_by_name('output:0')
  logits = tf.nn.softmax(output)

使用原始模型的某部分

有時候,咱們有對某模型的一部分進行fine-tune的需求,好比使用一個VGG的前面提取特徵的部分,而微調其全連層,或者將其全連層更換爲使用convolution來完成,等等。TensorFlow也提供了這種支持,可使用TensorFlow的stop_gradient函數,將模型的一部分進行凍結。

with tf.get_default_graph() as graph:
  graph.get_tensor_by_name('fc1:0')
  fc1 = tf.stop_gradient(fc1)
  # add new procedure on fc1
相關文章
相關標籤/搜索