1. High level API checkpointspython
只針對與 estimatorlua
設置檢查點的時間頻率和總個數spa
my_checkpointing_config = tf.estimator.RunConfig( save_checkpoints_secs = 20*60, # Save checkpoints every 20 minutes. keep_checkpoint_max = 10, # Retain the 10 most recent checkpoints. )
實例化時傳遞給 estimator 的 config 參數rest
model_dir 設置存儲路徑code
classifier = tf.estimator.DNNClassifier( feature_columns=my_feature_columns, hidden_units=[10, 10], n_classes=3, model_dir='models/iris', config=my_checkpointing_config)
一旦檢查點文件存在,TensorFlow 總會在你調用 train()
、 evaluation()
或 predict()
時重建模型教程
------------------------------------------------------------------------------------------------------------get
2.Low level API tf.train.Saverit
-------------------------------------------------------------------------------------------------------------io
Saver.save 存儲 model 中的全部變量class
import tensorflow as tf # 建立變量 var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer) # 添加初始化變量的操做 init_op = tf.global_variables_initializer() # 添加保存和恢復這些變量的操做 saver = tf.train.Saver() # 而後,加載模型,初始化變量,完成一些工做,並保存這些變量到磁盤中 with tf.Session() as sess: sess.run(init_op) # 使用模型完成一些工做 var.op.run() # 將變量保存到磁盤中 save_path = saver.save(sess, "/tmp/model.ckpt") print("Model saved in path: %s" % save_path)
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer) # tf.get_variable: Gets an existing variable with these parameters or create a new one. # shape: Shape of the new or existing variable # initializer: Initializer for the variable if one is created. tf.zeros_initializer 賦值爲0 [0 0 0]
saver = tf.train.Saver() # Saver 來管理模型中的全部變量,注意是全部變量
tf.Session() # A class for running TensorFlow operations.
with...as... #執行 with 後面的語句,若是能夠執行則將賦值給 as 後的語句。若是出現錯誤則執行 with 後語句中的 __exit__ #來報錯。相似與 try if,可是更方便
Saver.save 選擇性的存儲變量
saver = tf.train.Saver({'var2':var2})
-------------------------------------------------------------------------------------------------------------
Saver.restore 加載路徑中的全部變量
import tensorflow as tf tf.reset_default_graph() # 建立一些變量 var = tf.get_variable("var", shape=[3]) # 添加保存和恢復這些變量的操做 saver = tf.train.Saver() # 而後,加載模型,使用 saver 從磁盤中恢復變量,並使用變量完成一些工做 with tf.Session() as sess: # 從磁盤中恢復變量 saver.restore(sess, "/tmp/model.ckpt") print("Model restored.") # 檢查變量的值 print("var : %s" % var.eval())
-------------------------------------------------------------------------------------------------------------
inspector_checkpoint 檢查存儲的變量
加載 inspect_checkpoints
from tensorflow.python.tools import inspect_checkpoint as chkp
打印存儲起來的全部變量
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True, all_tensor_names=False)
注意其中的參數 all_tensor_names 教程中並未添加這個參數,運行時持續報錯 missing
打印製定的變量
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='var1', all_tensors=False, all_tensor_names=False)