TensorFlow 學習初步- 變量,模型的存儲和讀取

1. High level API checkpointspython

只針對與 estimatorlua

設置檢查點的時間頻率和總個數spa

my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.
    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.
)

實例化時傳遞給 estimator 的 config 參數rest

model_dir 設置存儲路徑code

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)

一旦檢查點文件存在,TensorFlow 總會在你調用 train() 、 evaluation() 或 predict() 時重建模型教程

------------------------------------------------------------------------------------------------------------get

2.Low level API tf.train.Saverit

-------------------------------------------------------------------------------------------------------------io

Saver.save 存儲 model 中的全部變量class

import tensorflow as tf

# 建立變量
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)


# 添加初始化變量的操做
init_op = tf.global_variables_initializer()

# 添加保存和恢復這些變量的操做
saver = tf.train.Saver()

# 而後,加載模型,初始化變量,完成一些工做,並保存這些變量到磁盤中
with tf.Session() as sess:
  sess.run(init_op)
  # 使用模型完成一些工做
  var.op.run()

  # 將變量保存到磁盤中
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer)

# tf.get_variable: Gets an existing variable with these parameters or create a new one.
# shape: Shape of the new or existing variable
# initializer: Initializer for the variable if one is created. tf.zeros_initializer 賦值爲0 [0 0 0]
saver = tf.train.Saver() # Saver 來管理模型中的全部變量,注意是全部變量
tf.Session() # A class for running TensorFlow operations.
with...as...
#執行 with 後面的語句,若是能夠執行則將賦值給 as 後的語句。若是出現錯誤則執行 with 後語句中的 __exit__
#來報錯。相似與 try if,可是更方便

Saver.save 選擇性的存儲變量

saver = tf.train.Saver({'var2':var2})

-------------------------------------------------------------------------------------------------------------

Saver.restore 加載路徑中的全部變量

import tensorflow as tf

tf.reset_default_graph()

# 建立一些變量
var = tf.get_variable("var", shape=[3])

# 添加保存和恢復這些變量的操做
saver = tf.train.Saver()

# 而後,加載模型,使用 saver 從磁盤中恢復變量,並使用變量完成一些工做
with tf.Session() as sess:
  # 從磁盤中恢復變量
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # 檢查變量的值
  print("var : %s" % var.eval())

-------------------------------------------------------------------------------------------------------------

inspector_checkpoint 檢查存儲的變量

加載 inspect_checkpoints

from tensorflow.python.tools import inspect_checkpoint as chkp

打印存儲起來的全部變量

chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True, all_tensor_names=False)

注意其中的參數 all_tensor_names 教程中並未添加這個參數,運行時持續報錯 missing

打印製定的變量

chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='var1', all_tensors=False, all_tensor_names=False)
相關文章
相關標籤/搜索