Tensorflow 模型保存與調用

說明:訓練模型,保存相關參數,以便在之後驗證時直接輸入驗證數據集便可獲得模型模擬結果。html

主要參考了官方教程和博客 http://www.javashuo.com/article/p-mibtkwkd-ed.htmlrest

 

1、 模型存儲htm

mymodel.meta -----------保存完整Tensorflow graph的protocol buffer,好比說,全部的 variables, operations, collections等等blog

mymodel.data-00000-of-00001 ----------.data文件中包含了訓練變量,如權重(weights),偏置(biases),梯度(gradients)和全部其餘保存的變量(variables)。教程

mymodel.indexget

checkpoint -----------記錄最新保存的模型的存儲路徑。博客

 

二、保存模型it

使用tf.train.Saver() 類io

例:saver=tf.train.Saver(tf.global_variables(),max_to_keep=20)import

若是在tf.train.Saver()中沒有指定任何東西,將保存全部變量。

若是不想保存全部的變量,只想保存其中一些變量,能夠在建立tf.train.Saver實例的時候,給它傳遞一個想要保存的變量的list或者字典。

 

三、調用一個已經訓練好的模型

使用tf.train.import_meta_graph()、saver.restore() 和 tf.get_default_graph()

例:with tf.Session() as sess:

              saver=tf.train.import_meta_graph('train.model-1000.meta')     #指定參數的讀取路徑
              saver.restore(sess,('train.model-1000'))                                   #提取參數
              graph = tf.get_default_graph()                                                  #獲取模型結構(張量圖graph)

             #經過變量名加載變量的值

             X=graph.get_tensor_by_name('X:0')

            #注意:若想經過變量名稱加載變量,要求已保存的模型中爲變量指明瞭變量名

 

四、模型再訓練

 在三、中把模型的結構和參數提取出來後,直接按本身的需求編寫模型訓練的代碼便可。

相關文章
相關標籤/搜索