【轉】tf.train.MonitoredTrainingSession()解析

原文地址:https://blog.csdn.net/mrr1ght/article/details/81006343。 本文有刪減。python

MonitoredTrainingSession定義

首先,tf.train.MonitorSession()從單詞的字面意思理解是用於監控訓練的回話,返回值是tf.train.MonitorSession()類的一個實例Object, tf.train.MonitorSession()會在下面講。session

MonitoredTrainingSession(
    master='',
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100

Args:分佈式

  • is_chief:用於分佈式系統中,用於判斷該系統是不是chief,若是爲True,它將負責初始化並恢復底層TensorFlow會話。若是爲False,它將等待chief初始化或恢復TensorFlow會話。
  • checkpoint_dir:一個字符串。指定一個用於恢復變量的checkpoint文件路徑。
    scaffold:用於收集或創建支持性操做的腳手架。若是未指定,則會建立默認一個默認的scaffold。它用於完成圖表
    hooks:SessionRunHook對象的可選列表。可本身定義SessionRunHook對象,也可用已經預約義好的SessionRunHook對象,如:tf.train.StopAtStepHook()設置中止訓練的條件;tf.train.NanTensorHook(loss):若是loss的值爲Nan則中止訓練;
    chief_only_hooks:SessionRunHook對象列表。若是is_chief== True,則激活這些掛鉤,不然忽略。
    save_checkpoint_secs:用默認的checkpoint saver保存checkpoint的頻率(以秒爲單位)。若是save_checkpoint_secs設置爲None,不保存checkpoint。
  • save_summaries_steps:使用默認summaries saver將摘要寫入磁盤的頻率(以全局步數表示)。若是save_summaries_steps和save_summaries_secs都設置爲None,則不使用默認的summaries saver保存summaries。默認爲100
  • save_summaries_secs:使用默認summaries saver將摘要寫入磁盤的頻率(以秒爲單位)。若是save_summaries_steps和save_summaries_secs都設置爲None,則不使用默認的摘要保存。默認未啓用。
  • config:用於配置會話的tf.ConfigProtoproto的實例。它是tf.Session的構造函數的config參數。
    stop_grace_period_secs:調用close()後線程中止的秒數。
    log_step_count_steps:記錄全局步/秒的全局步數的頻率

Returns:  一個·MonitoredSession(·) 實例。函數

tf.train.MonitoredSession()使用示例

saver_hook = CheckpointSaverHook(...)
summary_hook = SummarySaverHook(...)
with MonitoredSession(session_creator=ChiefSessionCreator(...),
                      hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
        sess.run(train_op)

Args:fetch

  • session_creator:制定用於建立回話的ChiefSessionCreator
  • hooks:tf.train.SessionRunHook()實例的列表

Returns: 一個MonitoredSession 實例。spa

 

  • 初始化:在建立一個MonitoredSession時,會按順序執行如下操做:.net

    • 調用[Hooks]列表中每個Hook的begin()函數
    • 經過scaffold.finalize()完成圖graph的定義
    • 建立會話
    • 用Scaffold提供的初始化操做(op)來初始化模型
    • 若是給定checkpoint_dir中存在checkpoint文件,則用checkpoint恢復變量
    • 啓動隊列線程
    • 調用hook.after_create_session()
  • Run:當調用run()函數時,按順序執行如下操做線程

    • 調用hook.before_run()
    • 用合併後的fetches 和feed_dict調用TensorFlow的session.run() (這裏是真正調用tf.Session().run(fetches ,feed_dict))
    • 調用hook.after_run()
    • 返回用戶須要的session.run()的結果
    • 若是發生了AbortedError或者UnavailableError,則在再次執行run()以前恢復或者從新初始化會話
  • Exit:當調用close()退出時,按順序執行下列操做
    • 調用hook.end()
    • 關閉隊列線程queuerunners和會話session
    • 在monitored_session的上下文中,抑制因爲處理完全部輸入拋出的OutOf Range錯誤。


MARSGGBO原創




2019-10-21 11:23:38

相關文章
相關標籤/搜索