tensorflow分佈式運行

一、知識點服務器

""" 單機多卡:一臺服務器上多臺設備(GPU) 參數服務器:更新參數,保存參數 工做服務器:主要功能是去計算 更新參數的模式: 一、同步模型更新 二、異步模型更新 工做服務器會默認一個機器做爲老大,建立會話 tensorflow設備命名規則: /job:ps/task:0 job:ps,服務器類型 task:0,服務器第幾臺 /job:worker/task:0/cpu:0 /job:worker/task:0/gpu:0 /job:worker/task:0/gpu:1 設備使用: 一、對集羣當中的一些ps,worker進行指定 二、建立對應的服務, ps:建立ps服務 join() worker建立worker服務,運行模型,程序,初始化會話等等 指定一個默認的worker去作 三、worker使用設備: with tf.device("/job:worker/task:0/gup:0"): 計算操做 四、分佈式使用設備: tf.train.replica_device_setter(worker_device=worker_device,cluster=cluster) 做用:經過此函數協調不一樣設備上的初始化操做 worker_device:爲指定設備, 「/job:worker/task:0/cpu:0」 or "/job:worker/task:0/gpu:0" cluster:集羣描述對象 API: 一、分佈式會話函數:MonitoredTrainingSession(master="",is_chief=True,checkpoint_dir=None,    hooks=None,save_checkpoint_secs=600,save_summaries_steps=USE_DEFAULT,save_summaries_secs=USE_DEFAULT,config=None) 參數: master:指定運行會話協議IP和端口(用於分佈式) "grpc://192.168.0.1:2000" is_chief:是否爲主worker(用於分佈式)若是True,它將負責初始化和恢復基礎的TensorFlow會話。 若是False,它將等待一位負責人初始化或恢復TensorFlow會話。 checkpoint_dir:檢查點文件目錄,同時也是events目錄 config:會話運行的配置項, tf.ConfigProto(log_device_placement=True) hooks:可選SessionRunHook對象列表 should_stop():是否異常中止 run():跟session同樣能夠運行op 二、tf.train.SessionRunHook Hook to extend calls to MonitoredSession.run() 一、begin():在會話以前,作初始化工做 二、before_run(run_context)在每次調用run()以前調用,以添加run()中的參數。 ARGS: run_context:一個SessionRunContext對象,包含會話運行信息 return:一個SessionRunArgs對象,例如:tf.train.SessionRunArgs(loss) 三、after_run(run_context,run_values)在每次調用run()後調用,通常用於運行以後的結果處理 該run_values參數包含所請求的操做/張量的結果 before_run()。 該run_context參數是相同的一個發送到before_run呼叫。  ARGS: run_context:一個SessionRunContext對象 run_values一個SessionRunValues對象, run_values.results 注:再添加鉤子類的時候,繼承SessionRunHook 三、tf.train.StopAtStepHook(last_step=5000)指定執行的訓練輪數也就是max_step,超過了就會拋出異常 tf.train.NanTensorHook(loss)判斷指定Tensor是否爲NaN,爲NaN則結束 注:在使用鉤子的時候須要定義一個全局步數:global_step = tf.contrib.framework.get_or_create_global_step() """

二、代碼session

import tensorflow as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string("job_name", " ", "啓動服務的類型ps or worker") tf.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker當中的那一臺服務器以task:0 ,task:1") def main(argv): # 定義全集計數的op ,給鉤子列表當中的訓練步數使用
    global_step = tf.contrib.framework.get_or_create_global_step() # 一、指定集羣描述對象, ps , worker
    cluster = tf.train.ClusterSpec({"ps": ["10.211.55.3:2223"], "worker": ["192.168.65.44:2222"]}) # 二、建立不一樣的服務, ps, worker
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) # 根據不一樣服務作不一樣的事情 ps:去更新保存參數 worker:指定設備去運行模型計算
    if FLAGS.job_name == "ps": # 參數服務器什麼都不用幹,是須要等待worker傳遞參數
 server.join() else: worker_device = "/job:worker/task:0/cpu:0/"

        # 三、能夠指定設備取運行
 with tf.device(tf.train.replica_device_setter( worker_device=worker_device, cluster=cluster )): # 簡單作一個矩陣乘法運算
            x = tf.Variable([[1, 2, 3, 4]]) w = tf.Variable([[2], [2], [2], [2]]) mat = tf.matmul(x, w) # 四、建立分佈式會話
 with tf.train.MonitoredTrainingSession( master= "grpc://192.168.65.44:2222", # 指定主worker
            is_chief= (FLAGS.task_index == 0),# 判斷是不是主worker
            config=tf.ConfigProto(log_device_placement=True),# 打印設備信息
            hooks=[tf.train.StopAtStepHook(last_step=200)] ) as mon_sess: while not mon_sess.should_stop(): print(mon_sess.run(mat)) if __name__ == "__main__": tf.app.run()

三、分佈式架構圖架構

相關文章
相關標籤/搜索