How TensorFlow Synchronous Training Works

I. A synchronous training example:

#coding=utf-8

# Launch one process per cluster member, for example:
#python sync_dist_train.py --job_name=ps --task_index=0 --issync=1
#python sync_dist_train.py --job_name=worker --task_index=0 --issync=1
#python sync_dist_train.py --job_name=worker --task_index=1 --issync=1
#python sync_dist_train.py --job_name=worker --task_index=2 --issync=1

import time
import numpy as np
import tensorflow as tf

# Define parameters
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('steps_to_validate', 1000,
                     'Steps to validate and print loss')

# For distributed
tf.app.flags.DEFINE_string("ps_hosts", "127.0.0.1:2222",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "127.0.0.1:2224,127.0.0.1:2225,127.0.0.1:2226",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("issync", 0, "是否採用分佈式的同步模式,1表示同步模式,0表示異步模式")


# Hyperparameters
learning_rate = FLAGS.learning_rate
steps_to_validate = FLAGS.steps_to_validate

def main(_):
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
  server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

  worker_count = len(worker_hosts)

  issync = FLAGS.issync
  if FLAGS.job_name == "ps":
    server.join()
  elif FLAGS.job_name == "worker":
    with tf.device(tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % FLAGS.task_index,
                    cluster=cluster)):
      # Use the canonical global step so the hooks and SyncReplicasOptimizer can find it.
      global_step = tf.train.get_or_create_global_step()

      X = tf.placeholder(tf.float32)
      Y = tf.placeholder(tf.float32)
      w = tf.Variable(0.0, name="weight")
      b = tf.Variable(0.0, name="reminder")
      y = w * X + b

      loss = tf.reduce_mean(tf.square(y - Y))
      optimizer = tf.train.GradientDescentOptimizer(learning_rate)

      if issync == 1:
        # Synchronous mode: aggregate gradients from all workers before
        # applying a single update to the shared variables.
        optimizer = tf.train.SyncReplicasOptimizer(optimizer,
                                                   replicas_to_aggregate=worker_count,
                                                   total_num_replicas=worker_count,
                                                   use_locking=True)
        sync_replicas_hook = optimizer.make_session_run_hook(FLAGS.task_index == 0)

      # Apply the gradient update (also increments global_step)
      train_op = optimizer.minimize(loss, global_step=global_step)

      hooks = [tf.train.StopAtStepHook(last_step=1000000)]
      if issync == 1:
          hooks.append(sync_replicas_hook)

      with tf.train.MonitoredTrainingSession(
        master=server.target, is_chief=(FLAGS.task_index == 0),
        checkpoint_dir="./train_logs", hooks=hooks) as mon_sess:
          while not mon_sess.should_stop():
            # Draw one synthetic sample from y = 2x + 10 with Gaussian noise.
            train_x = np.random.randn(1)
            train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10
            _, loss_v, step = mon_sess.run([train_op, loss, global_step],
                                           feed_dict={X: train_x, Y: train_y})
            if step % steps_to_validate == 0:
              w_, b_ = mon_sess.run([w, b])
              print("step: %d, weight: %f, bias: %f, loss: %f" % (step, w_, b_, loss_v))



if __name__ == "__main__":
  tf.app.run()
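
When run with --issync=1, global_step only advances after the chief has applied an update aggregated from the gradients of all workers, so the workers move through the same step sequence together; with --issync=0, each worker applies its own updates independently as soon as they are computed.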

II. How synchronous training works:

TensorFlow's synchronous training is implemented in tensorflow/python/training/sync_replicas_optimizer.py.
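
The example above only touches make_session_run_hook, but the optimizer also exposes the pieces the hook wires together. Below is a minimal, self-contained sketch of that surface (TF 1.x API); the tiny one-variable model is just a stand-in so the optimizer plumbing can be built:

import tensorflow as tf

w = tf.Variable(0.0)
loss = tf.square(w - 1.0)
global_step = tf.train.get_or_create_global_step()

opt = tf.train.SyncReplicasOptimizer(
    tf.train.GradientDescentOptimizer(0.01),
    replicas_to_aggregate=2,
    total_num_replicas=2)
train_op = opt.minimize(loss, global_step=global_step)

chief_queue_runner = opt.get_chief_queue_runner()  # background aggregate-and-apply loop
init_tokens_op = opt.get_init_tokens_op()          # seeds the sync token queue
chief_init_op = opt.chief_init_op                  # chief-only initialization
local_init_op = opt.local_step_init_op             # per-worker local_step initialization

# make_session_run_hook(is_chief) bundles the above: on the chief it runs
# chief_init_op and init_tokens_op and starts chief_queue_runner; on the
# other workers it only runs local_init_op.
sync_hook = opt.make_session_run_hook(is_chief=True)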

1. Key components of synchronous training and how they relate:
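
Reading sync_replicas_optimizer.py, the pieces that apply_gradients builds are roughly: one tf.ConditionalAccumulator per trainable variable (placed on the PS) that collects gradients stamped with each worker's local step and drops stale ones; an unbounded FIFOQueue of global_step tokens (the sync token queue) that gates the workers; a per-worker local_step variable used as the staleness stamp; and a chief-only update op, driven by a QueueRunner, that takes the aggregated gradients, applies them, advances global_step, and refills the token queue. The sketch below mimics this mechanism with the same primitives; it is an illustration of the idea, not the library's actual graph, and all names, shapes, and constants are made up:

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()

# One ConditionalAccumulator per trainable variable: workers push gradients
# stamped with their local_step; stale gradients are dropped rather than
# averaged into the update.
grad_accum = tf.ConditionalAccumulator(dtype=tf.float32, shape=(),
                                       shared_name="w_grad_accum")

# The sync token queue: an unbounded FIFO of global_step values. Workers
# block on dequeue() until the chief has applied an update and refilled it.
sync_token_queue = tf.FIFOQueue(-1, dtypes=[tf.int64], shapes=[()],
                                shared_name="sync_token_q")

# Per-worker staleness stamp.
local_step = tf.Variable(0, dtype=tf.int64, trainable=False, name="local_step")

# Worker-side op: push the local gradient, then wait for a token and use it
# to refresh local_step.
local_grad = tf.constant(0.1)  # stand-in for a gradient computed on a mini-batch
push = grad_accum.apply_grad(local_grad, local_step=local_step)
with tf.control_dependencies([push]):
  token = sync_token_queue.dequeue()
worker_train_op = tf.assign(local_step, token)

# Chief-side op (run by a QueueRunner in the library): once N gradients have
# arrived, take their average, apply the update, advance global_step, and
# release N tokens so the workers can start the next step.
N = 3  # replicas_to_aggregate
avg_grad = grad_accum.take_grad(N)
with tf.control_dependencies([avg_grad]):
  # The real code applies avg_grad to the variables here; this sketch only
  # forces the aggregation to happen before global_step advances.
  new_step = tf.assign_add(global_step, 1)
chief_update_op = sync_token_queue.enqueue_many(tf.fill([N], new_step))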

2. The interaction flow of one synchronous training step:
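
Based on the source above, one synchronous step proceeds roughly as follows (staleness handling and extra-replica tokens are simplified):

1. Every worker computes gradients on its own mini-batch.
2. Every worker pushes its gradients into the shared accumulators on the PS, stamped with its local_step, so that gradients computed from stale variable copies can be detected and dropped.
3. Once an accumulator has collected replicas_to_aggregate gradients, the chief's update op (driven by the chief queue runner) takes the averaged gradient out.
4. The chief applies the averaged gradients to the shared variables with the wrapped optimizer and increments global_step.
5. The chief enqueues a batch of tokens (one per replica) carrying the new global_step into the sync token queue.
6. Each worker, blocked after pushing its gradients, dequeues one token, assigns it to its local_step, and only then begins its next step; this is what forces fast workers to wait until the aggregated update has actually been applied.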