TensorFlow保存和恢復模型的方法總結

使用TensorFlow訓練模型的過程當中,須要適時對模型進行保存,以及對保存的模型進行restore,以方便後續對模型進行處理。好比進行測試,或者部署;好比拿別的模型進行fine-tune,等等。固然,直接的保存和restore比較簡單,無需多言,可是保存和restore中還牽涉到其餘問題,以及針對各類需求的各類參數等,可能不便一下都記好。所以,有必要對此進行一個總結。本文就是對使用TensorFlow保存和restore模型的相關內容進行一下總結,以便備忘。git

保存模型

保存模型是整個內容的第一步,固然也十分簡單。無非是建立一個saver,並在一個Session裏完成保存。好比:函數

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.save(sess, model_name)

以上代碼在0.11如下版本的TensorFlow裏會保存與下面相似的3個文件:測試

checkpointspa

model.ckpt-1000.metarest

model.ckpt-1000.ckptcode

在0.11及以上版本的TensorFlow裏則會保存與下相似的4個文件:部署

checkpointget

model.ckpt-1000.indexinput

model.ckpt-1000.data-00000-of-00001it

model.ckpt-1000.meta

其中checkpoint列出保存的全部模型以及最近的模型;meta文件是模型定義的內容;ckpt(或data和index)文件是保存的模型數據;內裏細節無需過多關注,若是想了解,stackOverflow上有一個解釋的回答。

固然,除了上面最簡單的保存方式,也能夠指定保存的步數,多長時間保存一次,磁盤上最多保有幾個模型(將前面的刪除以保持固定個數),以下:

建立saver時指定參數:

saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)

其中savable_variables指定待保存的變量,好比指定爲tf.global_variables()保存全部global變量;指定爲[v1, v2]保存v1和v2兩個變量;若是省略,則保存全部;

max_to_keep指定磁盤上最多保有幾個模型;keep_checkpoint_every_n_hours指定多少小時保存一次。

保存模型時指定參數:

saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)

如上,其中能夠指定模型文件名,步數,write_meta_graph則用來指定是否保存meta文件記錄graph等等。

Restore模型

具體來講,Restore模型的過程能夠分爲兩個部分,首先是建立模型,能夠手動建立,也能夠從meta文件里加載graph進行建立。

建立模型與訓練模型時建立模型的代碼相同,能夠直接複製過來使用。

從meta文件裏進行加載,能夠直接在Session裏進行以下操做:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model.ckpt-1000.meta')

後面的參數直接使用meta文件的路徑便可。如此,即將模型定義的graph加載進來了。

固然,還有一點須要注意,並不是全部的TensorFlow模型都能將graph輸出到meta文件中或者從meta文件中加載進來,若是模型有部分不能序列化的部分,則此種方法可能會無效。

而後就是爲模型加載數據,可使用下面兩種方法:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

此方法加載指定文件夾下最近保存的一個模型的數據;或者

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
    saver.restore(sess, os.path.join(path, 'model.ckpt-1000'))

此方法能夠指定具體某個數據,須要注意的是,指定的文件不要包含後綴。

使用Restore的模型

將模型數據加載進來以後,下一步就是利用加載的模型進行下一步的操做了。這能夠根據不一樣須要以以下幾種方式進行操做。

1.查看模型參數

能夠直接查看Restore進來的模型的參數,以下:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    tvs = [v for v in tf.trainable_variables()]
    for v in tvs:
        print(v.name)
        print(sess.run(v))

如名所言,以上是查看模型中的trainable variables;或者咱們也能夠查看模型中的全部tensor或者operations,以下:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    gv = [v for v in tf.global_variables()]
    for v in gv:
        print(v.name)

上面經過global_variables()得到的與前trainable_variables相似,只是多了一些非trainable的變量,好比定義時指定爲trainable=False的變量,或Optimizer相關的變量。

下面則能夠得到幾乎全部的operations相關的tensor:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = [o for o in sess.graph.get_operations()]
    for o in ops:
        print(o.name)

首先,上面的sess.graph.get_operations()能夠換爲tf.get_default_graph().get_operations(),兩者區別無非是graph明確的時候能夠直接使用前者,不然須要使用後者。

此種方法得到的tensor比較齊全,能夠從中一窺模型全貌。不過,最方便的方法仍是推薦使用tensorboard來查看,固然這須要你提早將sess.graph輸出。

2.直接使用原始模型進行訓練或測試(前傳)

這種操做比較簡單,無非是找到原始模型的輸入、輸出便可。

只要搞清楚輸入輸出的tensor名字,便可直接使用TensorFlow中graph的get_tensor_by_name函數,創建輸入輸出的tensor:

with tf.get_default_graph() as graph:
    data = graph.get_tensor_by_name('data:0')
    output = graph.get_tensor_by_name('output:0')

如上,須要特別注意,get_tensor_by_name後面傳入的參數,若是沒有重複,須要在後面加上「:0」。

從模型中找到了輸入輸出以後,便可直接使用其繼續train整個模型,或者將輸入數據feed到模型裏,並前傳獲得test輸出了。

須要說明的是,有時候從一個graph裏找到輸入和輸出tensor的名字並不容易,因此,在定義graph時,最好能給相應的tensor取上一個明顯的名字,好比:

data = tf.placeholder(tf.float32, shape=shape, name='input_data')

preds = tf.nn.softmax(logits, name='output')

諸如此類。這樣,就能夠直接使用tf.get_tensor_by_name(‘input_data:0’)之類的來找到輸入輸出了。

3.擴展原始模型

除了直接使用原始模型,還能夠在原始模型上進行擴展,好比對1中的output繼續進行處理,添加新的操做,能夠完成對原始模型的擴展,如:

with tf.get_default_graph() as graph:
    data = graph.get_tensor_by_name('data:0')
    output = graph.get_tensor_by_name('output:0')
    logits = tf.nn.softmax(output)

4.使用原始模型的某部分

有時候,咱們有對某模型的一部分進行fine-tune的需求,好比使用一個VGG的前面提取特徵的部分,而微調其全連層,或者將其全連層更換爲使用convolution來完成,等等。TensorFlow也提供了這種支持,可使用TensorFlow的stop_gradient函數,將模型的一部分進行凍結。

with tf.get_default_graph() as graph:
    graph.get_tensor_by_name('fc1:0')
    fc1 = tf.stop_gradient(fc1)
    # add new procedure on fc1
相關文章
相關標籤/搜索