原文地址:https://blog.csdn.net/mrr1ght/article/details/81011280 。本文有刪減。html
tf.train.SessionRunHook()是一個類;用來定義Hooks;python
Hooks是什麼,官方文檔中關於training hooks的定義是:session
Hooks are tools that run in the process of training/evaluation of the model.函數
Hooks是在模型訓練/測試過程當中的工具。Pytorch中也常常會有這個概念出現,其實也就跟keras裏的callbacks同樣,hook和callback都是在訓練過程當中執行特定的任務。工具
例如判斷是否須要中止訓練的EarlyStopping;改變學習率的LearningRateScheduler,他們都有一個共性,就是在每一個step開始/結束或者每一個epoch開始/結束時須要執行某個操做。如每一個epoch結束都保存一次checkpoint;每一個epoch結束時都判斷一次loss有沒有降低,若是loss沒有降低的輪數大於提取設定的閾值,就終止訓練。固然以上的功能咱們均可以本身徹底重頭實現。可是這些keras和tersorflow提供了更好的工具就是hook和callback,而且一些經常使用的功能都已經實現好了。說到底每一個hook和callback都是按照固定格式定義了在每一個step開始/結束要執行的操做,每一個epoch開始/結束執行的操做。學習
Hooks都是繼承自父類tf.train.SessionRunHook()
,首先看一下這個父類的定義源碼;測試
tf.train.SessionRunHook()
類定義在tensorflow/python/training/session_run_hook.py
,類中每一個函數的做用與何時調用都已加入函數註釋中;lua
class SessionRunHook(object): """Hook to extend calls to MonitoredSession.run().""" def begin(self): """再建立會話以前調用 調用begin()時,default graph會被建立, 可在此處向default graph增長新op,begin()調用後,default graph不能再被修改 """ pass def after_create_session(self, session, coord): # pylint: disable=unused-argument """tf.Session被建立後調用 調用後會指示全部的Hooks有一個新的會話被建立 Args: session: A TensorFlow Session that has been created. coord: A Coordinator object which keeps track of all threads. """ pass def before_run(self, run_context): # pylint: disable=unused-argument """調用在每一個sess.run()執行以前 能夠返回一個tf.train.SessRunArgs(op/tensor),在即將運行的會話中加入這些op/tensor; 加入的op/tensor會和sess.run()中已定義的op/tensor合併,而後一塊兒執行; Args: run_context: A `SessionRunContext` object. Returns: None or a `SessionRunArgs` object. """ return None def after_run(self, run_context, # pylint: disable=unused-argument run_values): # pylint: disable=unused-argument """調用在每一個sess.run()以後 參數run_values是befor_run()中要求的op/tensor的返回值; 能夠調用run_context.qeruest_stop()用於中止迭代 sess.run拋出任何異常after_run不會被調用 Args: run_context: A `SessionRunContext` object. run_values: A SessionRunValues object. """ pass def end(self, session): # pylint: disable=unused-argument """在會話結束時調用 end()常被用於Hook想要執行最後的操做,如保存最後一個checkpoint 若是sess.run()拋出除了表明迭代結束的OutOfRange/StopIteration異常外, end()不會被調用 Args: session: A TensorFlow Session that will be soon closed. """ pass
tf.train.SessionRunHook()
類中定義的方法的參數run_context
,run_values
,run_args
,包含sess.run()
會話運行所需的一切信息,spa
run_context
:類tf.train.SessRunContext
的實例run_values
:類tf.train.SessRunValues
的實例run_args
:類tf.train.SessRunArgs
的實例.這三個類會在下面詳細介紹.net
(1)可使用tf中已經預約義好的Hook,其都是tf.train.SessionRunHook()的子類;如
(2)也可用tf.train.SessionRunHook()定義本身的Hook,並重寫類中的方法;而後把想要使用的Hook(預約義好的或者本身定義的)放到tf.train.MonitorTrainingSession()參數[Hook]列表中;
關於tf.train.MonitorTrainingSession()
參見tf.train.MonitoredTrainingSession()解析。
給一個定義本身Hook的栗子,來自cifar10
class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time#duration持續的時間 self._start_time = current_time loss_value = run_values.results examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print (format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch))
這三個類都服務於sess.run(),區別以下:
(1) tf.train.SessRunArgs類
提供給會話運行的參數,與sess.run()參數定義同樣:
fethes,feeds,option
(2) tf.train.SessRunValues
用於保存sess.run()的結果,其中resluts是sess.run()返回值中對應於SessRunArgs()的返回值,
(3) tf.train.SessRunContext
SessRunContext包含sess.run()所需的一切信息
屬性:
方法:
equest_stop(): 設置_stop_request值爲True
tf.train.SessionRunHook()和tf.train.MonitorTrainingSession()通常一塊兒使用,下面是cifar10中的使用實例
class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time#duration持續的時間 self._start_time = current_time loss_value = run_values.results examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print (format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) #monitored 被監控的 with tf.train.MonitoredTrainingSession( checkpoint_dir=FLAGS.train_dir, hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), _LoggerHook()], config=tf.ConfigProto( log_device_placement=FLAGS.log_device_placement)) as mon_sess: while not mon_sess.should_stop(): mon_sess.run(train_op)