分佈式TensorFlow由高性能gRPC庫底層技術支持。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。python
分佈式原理。分佈式集羣 由多個服務器進程、客戶端進程組成。部署方式,單機多卡、分佈式(多機多卡)。多機多卡TensorFlow分佈式。git
單機多卡,單臺服務器多塊GPU。訓練過程:在單機單GPU訓練,數據一個批次(batch)一個批次訓練。單機多GPU,一次處理多個批次數據,每一個GPU處理一個批次數據計算。變量參數保存在CPU,數據由CPU分發給多個GPU,GPU計算每一個批次更新梯度。CPU收集完多個GPU更新梯度,計算平均梯度,更新參數。繼續計算更新梯度。處理速度取決最慢GPU速度。github
分佈式,訓練在多個工做節點(worker)。工做節點,實現計算單元。計算服務器單卡,指服務器。計算服務器多卡,多個GPU劃分多個工做節點。數據量大,超過一臺機器處理能力,須用分佈式。算法
分佈式TensorFlow底層通訊,gRPC(google remote procedure call)。gRPC,谷歌開源高性能、跨語言RPC框架。RPC協議,遠程過程調用協議,網絡從遠程計算機程度請求服務。數據庫
分佈式部署方式。分佈式運行,多個計算單元(工做節點),後端服務器部署單工做節點、多工做節點。後端
單工做節點部署。每臺服務器運行一個工做節點,服務器多個GPU,一個工做節點能夠訪問多塊GPU卡。代碼tf.device()指定運行操做設備。優點,單機多GPU間通訊,效率高。劣勢,手動代碼指定設備。服務器
多工做節點部署。一臺服務器運行多個工做節點。微信
設置CUDA_VISIBLE_DEVICES環境變量,限制各個工做節點只可見一個GPU,啓動進程添加環境變量。用tf.device()指定特定GPU。多工做節點部署優點,代碼簡單,提升GPU使用率。劣勢,工做節點通訊,需部署多個工做節點。https://github.com/tobegit3hub/tensorflow_examples/tree/master/distributed_tensorflow 。網絡
CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0
CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1
CUDA_VISIBLE_DEVICES='0' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0
CUDA_VISIBLE_DEVICES='1' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1session
分佈式架構。https://www.tensorflow.org/extend/architecture 。客戶端(client)、服務端(server),服務端包括主節點(master)、工做節點(worker)組成。
客戶端、主節點、工做節點關係。TensorFlow,客戶端會話聯繫主節點,實際工做由工做節點實現,每一個工做節點佔一臺設備(TensorFlow具體計算硬件抽象,CPU或GPU)。單機模式,客戶端、主節點、工做節點在同一臺服務器。分佈模式,可不一樣服務器。客戶端->主節點->工做節點/job:worker/task:0->/job:ps/task:0。
客戶端。創建TensorFlow計算圖,創建與集羣交互會話層。代碼包含Session()。一個客戶端可同時與多個服務端相連,一具服務端也可與多個客戶端相連。
服務端。運行tf.train.Server實例進程,TensroFlow執行任務集羣(cluster)一部分。有主節點服務(Master service)和工做節點服務(Worker service)。運行中,一個主節點進程和數個工做節點進程,主節點進程和工做接點進程經過接口通訊。單機多卡和分佈式結構相同,只須要更改通訊接口實現切換。
主節點服務。實現tensorflow::Session接口。經過RPC服務程序鏈接工做節點,與工做節點服務進程工做任務通訊。TensorFlow服務端,task_index爲0做業(job)。
工做節點服務。實現worker_service.proto接口,本地設備計算部分圖。TensorFlow服務端,全部工做節點包含工做節點服務邏輯。每一個工做節點負責管理一個或多個設備。工做節點能夠是本地不一樣端口不一樣進程,或多臺服務多個進程。運行TensorFlow分佈式執行任務集,一個或多個做業(job)。每一個做業,一個或多個相同目的任務(task)。每一個任務,一個工做進程執行。做業是任務集合,集羣是做業集合。
分佈式機器學習框架,做業分參數做業(parameter job)和工做節點做業(worker job)。參數做業運行服務器爲參數服務器(parameter server,PS),管理參數存儲、更新。工做節點做業,管理無狀態主要從事計算任務。模型越大,參數越多,模型參數更新超過一臺機器性能,須要把參數分開到不一樣機器存儲更新。參數服務,多臺機器組成集羣,相似分佈式存儲架構,涉及數據同步、一致性,參數存儲爲鍵值對(key-value)。分佈式鍵值內存數據庫,加參數更新操做。李沐《Parameter Server for Distributed Machine Learning》http://www.cs.cmu.edu/~muli/file/ps.pdf 。
參數存儲更新在參數做業進行,模型計算在工做節點做業進行。TensorFlow分佈式實現做業間數據傳輸,參數做業到工做節點做業前向傳播,工做節點做業到參數做業反向傳播。
任務。特定TensorFlow服務器獨立進程,在做業中擁有對應序號。一個任務對應一個工做節點。集羣->做業->任務->工做節點。
客戶端、主節點、工做節點交互過程。單機多卡交互,客戶端->會話運行->主節點->執行子圖->工做節點->GPU0、GPU1。分佈式交互,客戶端->會話運行->主節點進程->執行子圖1->工做節點進程1->GPU0、GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》https://arxiv.org/abs/1603.04467v1 。
分佈式模式。
數據並行。https://www.tensorflow.org/tutorials/deep_cnn 。CPU負責梯度平均、參數更新,不一樣GPU訓練模型副本(model replica)。基於訓練樣例子集訓練,模型有獨立性。
步驟:不一樣GPU分別定義模型網絡結構。單個GPU從數據管道讀取不一樣數據塊,前向傳播,計算損失,計算當前變量梯度。全部GPU輸出梯度數據轉移到CPU,梯度求平均操做,模型變量更新。重複,直到模型變量收斂。
數據並行,提升SGD效率。SGD mini-batch樣本,切成多份,模型複製多份,在多個模型上同時計算。多個模型計算速度不一致,CPU更新變量有同步、異步兩個方案。
同步更新、異步更新。分佈式隨機梯度降低法,模型參數分佈式存儲在不一樣參數服務上,工做節點並行訓練數據,和參數服務器通訊獲取模型參數。
同步隨機梯度降低法(Sync-SGD,同步更新、同步訓練),訓練時,每一個節點上工做任務讀入共享參數,執行並行梯度計算,同步須要等待全部工做節點把局部梯度處好,將全部共享參數合併、累加,再一次性更新到模型參數,下一批次,全部工做節點用模型更新後參數訓練。優點,每一個訓練批次考慮全部工做節點訓練情部,損失降低穩定。劣勢,性能瓶頸在最慢工做節點。異楹設備,工做節點性能不一樣,劣勢明顯。
異步隨機梯度降低法(Async-SGD,異步更新、異步訓練),每一個工做節點任務獨立計算局部梯度,異步更新到模型參數,不需執行協調、等待操做。優點,性能不存在瓶頸。劣勢,每一個工做節點計算梯度值發磅回參數服務器有參數更新衝突,影響算法收劍速度,損失降低過程抖動較大。
同步更新、異步更新實現區別於更新參數服務器參數策略。數據量小,各節點計算能力較均衡,用同步模型。數據量大,各機器計算性能良莠不齊,用異步模式。
帶備份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz論文《Revisiting Distributed Synchronous SGD》https://arxiv.org/abs/1604.00981 。增長工做節點,解決部分工做節點計算慢問題。工做節點總數n+n*5%,n爲集羣工做節點數。異步更新設定接受到n個工做節點參數直接更新參數服務器模型參數,進入下一批次模型訓練。計算較慢節點訓練參數直接丟棄。
同步更新、異步更新有圖內模式(in-graph pattern)和圖間模式(between-graph pattern),獨立於圖內(in-graph)、圖間(between-graph)概念。
圖內複製(in-grasph replication),全部操做(operation)在同一個圖中,用一個客戶端來生成圖,把全部操做分配到集羣全部參數服務器和工做節點上。國內複製和單機多卡相似,擴展到多機多卡,數據分發仍是在客戶端一個節點上。優點,計算節點只須要調用join()函數等待任務,客戶端隨時提交數據就能夠訓練。劣勢,訓練數據分發在一個節點上,要分發給不一樣工做節點,嚴重影響併發訓練速度。
圖間複製(between-graph replication),每個工做節點建立一個圖,訓練參數保存在參數服務器,數據不分發,各個工做節點獨立計算,計算完成把要更新參數告訴參數服務器,參數服務器更新參數。優點,不須要數據分發,各個工做節點都建立圖和讀取數據訓練。劣勢,工做節點既是圖建立者又是計算任務執行者,某個工做節點宕機影響集羣工做。大數據相關深度學習推薦使用圖間模式。
模型並行。切分模型,模型不一樣部分執行在不一樣設備上,一個批次樣本能夠在不一樣設備同時執行。TensorFlow儘可能讓相鄰計算在同一臺設備上完成節省網絡開銷。Martin Abadi、Ashish Agarwal、Paul Barham論文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》https://arxiv.org/abs/1603.04467v1 。
模型並行、數據並行,TensorFlow中,計算能夠分離,參數能夠分離。能夠在每一個設備上分配計算節點,讓對應參數也在該設備上,計算參數放一塊兒。
分佈式API。https://www.tensorflow.org/deploy/distributed 。
建立集羣,每一個任務(task)啓動一個服務(工做節點服務或主節點服務)。任務能夠分佈不一樣機器,能夠同一臺機器啓動多個任務,用不一樣GPU運行。每一個任務完成工做:建立一個tf.train.ClusterSpec,對集羣全部任務進行描述,描述內容對全部任務相同。建立一個tf.train.Server,建立一個服務,運行相應做業計算任務。
TensorFlow分佈式開發API。tf.train.ClusterSpec({"ps":ps_hosts,"worker":worke_hosts})。建立TensorFlow集羣描述信息,ps、worker爲做業名稱,ps_phsts、worker_hosts爲做業任務所在節點地址信息。tf.train.ClusterSpec傳入參數,做業和任務間關係映射,映射關係任務經過IP地址、端口號表示。
結構 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
可用任務 /job:local/task:0、/job:local/task:1。
結構 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})
可用任務 /job:worker/task:0、 /job:worker/task:1、 /job:worker/task:2、 /job:ps/task:0、 /job:ps/task:1
tf.train.Server(cluster,job_name,task_index)。建立服務(主節點服務或工做節點服務),運行做業計算任務,運行任務在task_index指定機器啓動。
#任務0
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server = tr.train.Server(cluster,job_name="local",task_index=0)
#任務1
cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
server = tr.train.Server(cluster,job_name="local",task_index=1)。
自動化管理節點、監控節點工具。集羣管理工具Kubernetes。
tf.device(device_name_or_function)。設定指定設備執行張量運算,批定代碼運行CPU、GPU。
#指定在task0所在機器執行Tensor操做運算
with tf.device("/job:ps/task:0"):
weights_1 = tf.Variable(…)
biases_1 = tf.Variable(…)
分佈式訓練代碼框架。建立TensorFlow服務器集羣,在該集羣分佈式計算數據流圖。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/deploy/distributed.md 。
import argparse import sys import tensorflow as tf FLAGS = None def main(_): # 第1步:命令行參數解析,獲取集羣信息ps_hosts、worker_hosts # 當前節點角色信息job_name、task_index ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts.split(",") # 第2步:建立當前任務節點服務器 # Create a cluster from the parameter server and worker hosts. cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) # Create and start a server for the local task. server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) # 第3步:若是當前節點是參數服務器,調用server.join()無休止等待;若是是工做節點,執行第4步 if FLAGS.job_name == "ps": server.join() # 第4步:構建要訓練模型,構建計算圖 elif FLAGS.job_name == "worker": # Assigns ops to the local worker by default. with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): # Build model... loss = ... global_step = tf.contrib.framework.get_or_create_global_step() train_op = tf.train.AdagradOptimizer(0.01).minimize( loss, global_step=global_step) # The StopAtStepHook handles stopping after running given steps. # 第5步管理模型訓練過程 hooks=[tf.train.StopAtStepHook(last_step=1000000)] # The MonitoredTrainingSession takes care of session initialization, # restoring from a checkpoint, saving to a checkpoint, and closing when done # or an error occurs. with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(FLAGS.task_index == 0), checkpoint_dir="/tmp/train_logs", hooks=hooks) as mon_sess: while not mon_sess.should_stop(): # Run a training step asynchronously. # See `tf.train.SyncReplicasOptimizer` for additional details on how to # perform *synchronous* training. # mon_sess.run handles AbortedError in case of preempted PS. # 訓練模型 mon_sess.run(train_op) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") # Flags for defining the tf.train.ClusterSpec parser.add_argument( "--ps_hosts", type=str, default="", help="Comma-separated list of hostname:port pairs" ) parser.add_argument( "--worker_hosts", type=str, default="", help="Comma-separated list of hostname:port pairs" ) parser.add_argument( "--job_name", type=str, default="", help="One of 'ps', 'worker'" ) # Flags for defining the tf.train.Server parser.add_argument( "--task_index", type=int, default=0, help="Index of task within the job" ) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
分佈式最佳實踐。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py 。
MNIST數據集分佈式訓練。開設3個端口做分佈式工做節點部署,2222端口參數服務器,2223端口工做節點0,2224端口工做節點1。參數服務器執行參數更新任務,工做節點0、工做節點1執行圖模型訓練計算任務。參數服務器/job:ps/task:0 cocalhost:2222,工做節點/job:worker/task:0 cocalhost:2223,工做節點/job:worker/task:1 cocalhost:2224。
運行代碼。
python mnist_replica.py --job_name="ps" --task_index=0 python mnist_replica.py --job_name="worker" --task_index=0 python mnist_replica.py --job_name="worker" --task_index=1 from __future__ import absolute_import from __future__ import division from __future__ import print_function import math import sys import tempfile import time import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 定義常量,用於建立數據流圖 flags = tf.app.flags flags.DEFINE_string("data_dir", "/tmp/mnist-data", "Directory for storing mnist data") # 只下載數據,不作其餘操做 flags.DEFINE_boolean("download_only", False, "Only perform downloading of data; Do not proceed to " "session preparation, model definition or training") # task_index從0開始。0表明用來初始化變量的第一個任務 flags.DEFINE_integer("task_index", None, "Worker task index, should be >= 0. task_index=0 is " "the master worker task the performs the variable " "initialization ") # 每臺機器GPU個數,機器沒有GPU爲0 flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine." "If you don't use GPU, please set it to '0'") # 同步訓練模型下,設置收集工做節點數量。默認工做節點總數 flags.DEFINE_integer("replicas_to_aggregate", None, "Number of replicas to aggregate before parameter update" "is applied (For sync_replicas mode only; default: " "num_workers)") flags.DEFINE_integer("hidden_units", 100, "Number of units in the hidden layer of the NN") # 訓練次數 flags.DEFINE_integer("train_steps", 200, "Number of (global) training steps to perform") flags.DEFINE_integer("batch_size", 100, "Training batch size") flags.DEFINE_float("learning_rate", 0.01, "Learning rate") # 使用同步訓練、異步訓練 flags.DEFINE_boolean("sync_replicas", False, "Use the sync_replicas (synchronized replicas) mode, " "wherein the parameter updates from workers are aggregated " "before applied to avoid stale gradients") # 若是服務器已經存在,採用gRPC協議通訊;若是不存在,採用進程間通訊 flags.DEFINE_boolean( "existing_servers", False, "Whether servers already exists. If True, " "will use the worker hosts via their GRPC URLs (one client process " "per worker host). Otherwise, will create an in-process TensorFlow " "server.") # 參數服務器主機 flags.DEFINE_string("ps_hosts","localhost:2222", "Comma-separated list of hostname:port pairs") # 工做節點主機 flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "Comma-separated list of hostname:port pairs") # 本做業是工做節點仍是參數服務器 flags.DEFINE_string("job_name", None,"job name: worker or ps") FLAGS = flags.FLAGS IMAGE_PIXELS = 28 def main(unused_argv): mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) if FLAGS.download_only: sys.exit(0) if FLAGS.job_name is None or FLAGS.job_name == "": raise ValueError("Must specify an explicit `job_name`") if FLAGS.task_index is None or FLAGS.task_index =="": raise ValueError("Must specify an explicit `task_index`") print("job name = %s" % FLAGS.job_name) print("task index = %d" % FLAGS.task_index) #Construct the cluster and start the server # 讀取集羣描述信息 ps_spec = FLAGS.ps_hosts.split(",") worker_spec = FLAGS.worker_hosts.split(",") # Get the number of workers. num_workers = len(worker_spec) # 建立TensorFlow集羣描述對象 cluster = tf.train.ClusterSpec({ "ps": ps_spec, "worker": worker_spec}) # 爲本地執行任務建立TensorFlow Server對象。 if not FLAGS.existing_servers: # Not using existing servers. Create an in-process server. # 建立本地Sever對象,從tf.train.Server這個定義開始,每一個節點開始不一樣 # 根據執行的命令的參數(做業名字)不一樣,決定這個任務是哪一個任務 # 若是做業名字是ps,進程就加入這裏,做爲參數更新的服務,等待其餘工做節點給它提交參數更新的數據 # 若是做業名字是worker,就執行後面的計算任務 server = tf.train.Server( cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) # 若是是參數服務器,直接啓動便可。這裏,進程就會阻塞在這裏 # 下面的tf.train.replica_device_setter代碼會將參數批定給ps_server保管 if FLAGS.job_name == "ps": server.join() # 處理工做節點 # 找出worker的主節點,即task_index爲0的點 is_chief = (FLAGS.task_index == 0) # 若是使用gpu if FLAGS.num_gpus > 0: # Avoid gpu allocation conflict: now allocate task_num -> #gpu # for each worker in the corresponding machine gpu = (FLAGS.task_index % FLAGS.num_gpus) # 分配worker到指定gpu上運行 worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu) # 若是使用cpu elif FLAGS.num_gpus == 0: # Just allocate the CPU to worker server # 把cpu分配給worker cpu = 0 worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu) # The device setter will automatically place Variables ops on separate # parameter servers (ps). The non-Variable ops will be placed on the workers. # The ps use CPU and workers use corresponding GPU # 用tf.train.replica_device_setter將涉及變量操做分配到參數服務器上,使用CPU。將涉及非變量操做分配到工做節點上,使用上一步worker_device值。 # 在這個with語句之下定義的參數,會自動分配到參數服務器上去定義。若是有多個參數服務器,就輪流循環分配 with tf.device( tf.train.replica_device_setter( worker_device=worker_device, ps_device="/job:ps/cpu:0", cluster=cluster)): # 定義全局步長,默認值爲0 global_step = tf.Variable(0, name="global_step", trainable=False) # Variables of the hidden layer # 定義隱藏層參數變量,這裏是全鏈接神經網絡隱藏層 hid_w = tf.Variable( tf.truncated_normal( [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w") hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b") # Variables of the softmax layer # 定義Softmax 迴歸層參數變量 sm_w = tf.Variable( tf.truncated_normal( [FLAGS.hidden_units, 10], stddev=1.0 / math.sqrt(FLAGS.hidden_units)), name="sm_w") sm_b = tf.Variable(tf.zeros([10]), name="sm_b") # Ops: located on the worker specified with FLAGS.task_index # 定義模型輸入數據變量 x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS]) y_ = tf.placeholder(tf.float32, [None, 10]) # 構建隱藏層 hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) hid = tf.nn.relu(hid_lin) # 構建損失函數和優化器 y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) # 異步訓練模式:本身計算完成梯度就去更新參數,不一樣副本之間不會去協調進度 opt = tf.train.AdamOptimizer(FLAGS.learning_rate) # 同步訓練模式 if FLAGS.sync_replicas: if FLAGS.replicas_to_aggregate is None: replicas_to_aggregate = num_workers else: replicas_to_aggregate = FLAGS.replicas_to_aggregate # 使用SyncReplicasOptimizer做優化器,而且是在圖間複製狀況下 # 在圖內複製狀況下將全部梯度平均 opt = tf.train.SyncReplicasOptimizer( opt, replicas_to_aggregate=replicas_to_aggregate, total_num_replicas=num_workers, name="mnist_sync_replicas") train_step = opt.minimize(cross_entropy, global_step=global_step) if FLAGS.sync_replicas: local_init_op = opt.local_step_init_op if is_chief: # 全部進行計算工做節點裏一個主工做節點(chief) # 主節點負責初始化參數、模型保存、概要保存 local_init_op = opt.chief_init_op ready_for_local_init_op = opt.ready_for_local_init_op # Initial token and chief queue runners required by the sync_replicas mode # 同步訓練模式所需初始令牌、主隊列 chief_queue_runner = opt.get_chief_queue_runner() sync_init_op = opt.get_init_tokens_op() init_op = tf.global_variables_initializer() train_dir = tempfile.mkdtemp() if FLAGS.sync_replicas: # 建立一個監管程序,用於統計訓練模型過程當中的信息 # lodger 是保存和加載模型路徑 # 啓動就會去這個logdir目錄看是否有檢查點文件,有的話就自動加載 # 沒有就用init_op指定初始化參數 # 主工做節點(chief)負責模型參數初始化工做 # 過程當中,其餘工做節點等待主節眯完成初始化工做,初始化完成後,一塊兒開始訓練數據 # global_step值是全部計算節點共享的 # 在執行損失函數最小值時自動加1,經過global_step知道全部計算節點一共計算多少步 sv = tf.train.Supervisor( is_chief=is_chief, logdir=train_dir, init_op=init_op, local_init_op=local_init_op, ready_for_local_init_op=ready_for_local_init_op, recovery_wait_secs=1, global_step=global_step) else: sv = tf.train.Supervisor( is_chief=is_chief, logdir=train_dir, init_op=init_op, recovery_wait_secs=1, global_step=global_step) # 建立會話,設置屬性allow_soft_placement爲True # 全部操做默認使用被指定設置,如GPU # 若是該操做函數沒有GPU實現,自動使用CPU設備 sess_config = tf.ConfigProto( allow_soft_placement=True, log_device_placement=False, device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index]) # The chief worker (task_index==0) session will prepare the session, # while the remaining workers will wait for the preparation to complete. # 主工做節點(chief),task_index爲0節點初始化會話 # 其他工做節點等待會話被初始化後進行計算 if is_chief: print("Worker %d: Initializing session..." % FLAGS.task_index) else: print("Worker %d: Waiting for session to be initialized..." % FLAGS.task_index) if FLAGS.existing_servers: server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index] print("Using existing server at: %s" % server_grpc_url) # 建立TensorFlow會話對象,用於執行TensorFlow圖計算 # prepare_or_wait_for_session須要參數初始化完成且主節點準備好後,纔開始訓練 sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config) else: sess = sv.prepare_or_wait_for_session(server.target, config=sess_config) print("Worker %d: Session initialization complete." % FLAGS.task_index) if FLAGS.sync_replicas and is_chief: # Chief worker will start the chief queue runner and call the init op. sess.run(sync_init_op) sv.start_queue_runners(sess, [chief_queue_runner]) # Perform training # 執行分佈式模型訓練 time_begin = time.time() print("Training begins @ %f" % time_begin) local_step = 0 while True: # Training feed # 讀入MNIST訓練數據,默認每批次100張圖片 batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size) train_feed = {x: batch_xs, y_: batch_ys} _, step = sess.run([train_step, global_step], feed_dict=train_feed) local_step += 1 now = time.time() print("%f: Worker %d: training step %d done (global step: %d)" % (now, FLAGS.task_index, local_step, step)) if step >= FLAGS.train_steps: break time_end = time.time() print("Training ends @ %f" % time_end) training_time = time_end - time_begin print("Training elapsed time: %f s" % training_time) # Validation feed # 讀入MNIST驗證數據,計算驗證的交叉熵 val_feed = {x: mnist.validation.images, y_: mnist.validation.labels} val_xent = sess.run(cross_entropy, feed_dict=val_feed) print("After %d training step(s), validation cross entropy = %g" % (FLAGS.train_steps, val_xent)) if __name__ == "__main__": tf.app.run()
參考資料:
《TensorFlow技術解析與實戰》
歡迎推薦上海機器學習工做機會,個人微信:qingxingfengzi