The TensorFlow API provides Cluster, Server, and Supervisor to support distributed training of models.
For an introduction to distributed training in TensorFlow, see Distributed Tensorflow. A brief summary:
Every node in a TensorFlow distributed cluster runs the same code. Distributed training code follows a fixed pattern:
```python
# Step 1: parse the command-line arguments to obtain the cluster description
#         (ps_hosts and worker_hosts) and the current node's role
#         (job_name and task_index).

# Step 2: create the Server for the current task.
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)

# Step 3: if the current node is a ps, call server.join() and wait forever;
#         if it is a worker, continue with step 4.
if FLAGS.job_name == "ps":
  server.join()

# Step 4: build the model to be trained.
# build tensorflow graph model

# Step 5: create a tf.train.Supervisor to manage the training process.
# Create a "supervisor", which oversees the training process.
sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), logdir="/tmp/train_logs")
# The supervisor takes care of session initialization and restoring from a checkpoint.
sess = sv.prepare_or_wait_for_session(server.target)
# Loop until the supervisor shuts down.
while not sv.should_stop():
  # train model
  ...
```
Following the fixed pattern described above, a distributed TensorFlow program can be written with the framework below.
```python
import tensorflow as tf

# Flags for defining the tf.train.ClusterSpec
tf.app.flags.DEFINE_string("ps_hosts", "",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "",
                           "Comma-separated list of hostname:port pairs")

# Flags for defining the tf.train.Server
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")

FLAGS = tf.app.flags.FLAGS


def main(_):
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  if FLAGS.job_name == "ps":
    server.join()
  elif FLAGS.job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

      # Build model...
      loss = ...
      global_step = tf.Variable(0)

      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

      saver = tf.train.Saver()
      summary_op = tf.merge_all_summaries()
      init_op = tf.initialize_all_variables()

    # Create a "supervisor", which oversees the training process.
    sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                             logdir="/tmp/train_logs",
                             init_op=init_op,
                             summary_op=summary_op,
                             saver=saver,
                             global_step=global_step,
                             save_model_secs=600)

    # The supervisor takes care of session initialization and restoring from
    # a checkpoint.
    sess = sv.prepare_or_wait_for_session(server.target)

    # Start queue runners for the input pipelines (if any).
    sv.start_queue_runners(sess)

    # Loop until the supervisor shuts down (or 1000000 steps have completed).
    step = 0
    while not sv.should_stop() and step < 1000000:
      # Run a training step asynchronously.
      # See `tf.train.SyncReplicasOptimizer` for additional details on how to
      # perform *synchronous* training.
      _, step = sess.run([train_op, global_step])


if __name__ == "__main__":
  tf.app.run()
```
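The training loop above applies updates asynchronously: each worker pushes its gradients to the parameter servers as soon as they are computed. If synchronous training is wanted, the comment in the loop points to tf.train.SyncReplicasOptimizer; a minimal sketch of how the framework's optimizer line could be wrapped is shown below. The choice `replicas_to_aggregate=len(worker_hosts)` is illustrative, and the chief worker additionally has to run the extra initialization and queue-runner ops that the wrapper creates.

```python
# Hedged sketch: synchronous updates by wrapping the optimizer.
# replicas_to_aggregate is illustrative; with this setting every worker must
# contribute a gradient before a global parameter update is applied.
opt = tf.train.AdagradOptimizer(0.01)
opt = tf.train.SyncReplicasOptimizer(opt,
                                     replicas_to_aggregate=len(worker_hosts),
                                     total_num_replicas=len(worker_hosts))
train_op = opt.minimize(loss, global_step=global_step)
```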
For all distributed TensorFlow code, only two parts vary:

1. the code that builds the TensorFlow graph (the model);
2. the code executed for each training step.
To verify this, we build a distributed MNIST example by modifying mnist_softmax.py from tensorflow/tensorflow. See mnist_dist.py for the modified code.
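The model-building part of that modification is essentially the softmax regression from mnist_softmax.py placed inside the replica_device_setter block of the framework above. A rough sketch of what that section may look like is given below; the data directory and variable names are illustrative, and the actual code is in mnist_dist.py.

```python
# Rough sketch of the model section adapted from mnist_softmax.py
# (placed inside the tf.device(tf.train.replica_device_setter(...)) block).
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

x = tf.placeholder(tf.float32, [None, 784])   # flattened 28x28 images
y_ = tf.placeholder(tf.float32, [None, 10])   # one-hot labels

W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
global_step = tf.Variable(0)
train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)
```

In the training loop, each step would then feed a batch, e.g. `batch_xs, batch_ys = mnist.train.next_batch(100)` followed by `sess.run([train_op, global_step], feed_dict={x: batch_xs, y_: batch_ys})`.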
As before, we start a container from the TensorFlow Docker image for the verification.
```bash
$ docker run -d -v /path/to/your/code:/tensorflow/mnist --name tensorflow tensorflow/tensorflow
```
After the tensorflow container has started, open 4 terminals, then use the commands below in each one to enter the container and switch to the /tensorflow/mnist directory:
```bash
$ docker exec -ti tensorflow /bin/bash
$ cd /tensorflow/mnist
```
Then, in each of the four terminals, run one of the following commands to start one task node of the TensorFlow cluster:
```bash
# Start ps 0
python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=ps --task_index=0

# Start ps 1
python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=ps --task_index=1

# Start worker 0
python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=worker --task_index=0

# Start worker 1
python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=worker --task_index=1
```
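If opening four terminals is inconvenient, the same four tasks can also be launched from a single shell inside the container, for example with a small script along the following lines (the log file names are illustrative):

```bash
#!/bin/bash
# Launch 2 ps tasks and 2 worker tasks in the background from one terminal.
HOSTS="--ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224"

python mnist_dist.py $HOSTS --job_name=ps     --task_index=0 > ps0.log     2>&1 &
python mnist_dist.py $HOSTS --job_name=ps     --task_index=1 > ps1.log     2>&1 &
python mnist_dist.py $HOSTS --job_name=worker --task_index=0 > worker0.log 2>&1 &
python mnist_dist.py $HOSTS --job_name=worker --task_index=1 > worker1.log 2>&1 &

wait
```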
You can verify the results yourself.