#coding=utf-8
# Usage:
#   python sync_dist_train.py --job_name=ps --task_index=0 --issync=1
#   python sync_dist_train.py --job_name=worker --task_index=0 --issync=1
#   python sync_dist_train.py --job_name=worker --task_index=1 --issync=1
import time
import numpy as np
import tensorflow as tf

# Define parameters
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('steps_to_validate', 1000,
                            'Steps to validate and print loss')

# For distributed training
tf.app.flags.DEFINE_string("ps_hosts", "127.0.0.1:2222",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "127.0.0.1:2224,127.0.0.1:2225,127.0.0.1:2226",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("issync", 0,
                            "Distributed mode: 1 for synchronous, 0 for asynchronous")

# Hyperparameters
learning_rate = FLAGS.learning_rate
steps_to_validate = FLAGS.steps_to_validate


def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)
    worker_count = len(worker_hosts)
    issync = FLAGS.issync

    if FLAGS.job_name == "ps":
        # Parameter servers just block and serve variables.
        server.join()
    elif FLAGS.job_name == "worker":
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            global_step = tf.Variable(0, name='global_step', trainable=False)

            # Simple linear regression model: y = w * X + b
            X = tf.placeholder(tf.float32)
            Y = tf.placeholder(tf.float32)
            w = tf.Variable(0.0, name="weight")
            b = tf.Variable(0.0, name="bias")
            y = w * X + b
            loss = tf.reduce_mean(tf.square(y - Y))

            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            if issync == 1:
                # Synchronous mode: aggregate gradients from all workers before
                # applying a single update.
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=len(worker_hosts),
                    total_num_replicas=len(worker_hosts),
                    use_locking=True)
                sync_replicas_hook = optimizer.make_session_run_hook(FLAGS.task_index == 0)

            # Compute and apply gradients
            train_op = optimizer.minimize(loss, global_step=global_step)

        hooks = [tf.train.StopAtStepHook(last_step=1000000)]
        if issync == 1:
            hooks.append(sync_replicas_hook)

        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(FLAGS.task_index == 0),
                checkpoint_dir="./train_logs",
                hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                # Generate one training sample from y = 2x + 10 plus noise.
                train_x = np.random.randn(1)
                train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10
                _, loss_v, step = mon_sess.run([train_op, loss, global_step],
                                               feed_dict={X: train_x, Y: train_y})
                if step % steps_to_validate == 0:
                    w_, b_ = mon_sess.run([w, b])
                    print("step: %d, weight: %f, bias: %f, loss: %f"
                          % (step, w_, b_, loss_v))


if __name__ == "__main__":
    tf.app.run()
TensorFlow's synchronous training logic is implemented in tensorflow/python/training/sync_replicas_optimizer.py.
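The core idea in that file is to collect the gradients reported by all replicas for the same global step and apply one averaged update. Below is a minimal conceptual sketch of that aggregation step in plain NumPy; the function and variable names are illustrative and do not come from the TensorFlow source:

import numpy as np

def sync_update(weights, replica_grads, learning_rate):
    # Aggregate: average the gradients contributed by all replicas for this step
    avg_grad = np.mean(replica_grads, axis=0)
    # Apply a single synchronous update using the averaged gradient
    return weights - learning_rate * avg_grad

# Example: three workers each report a gradient for the same global step
weights = np.array([0.0, 0.0])
grads = [np.array([0.2, -0.1]), np.array([0.4, 0.0]), np.array([0.3, -0.2])]
weights = sync_update(weights, grads, learning_rate=0.00003)

In the script above, SyncReplicasOptimizer plays this role: with replicas_to_aggregate equal to the number of workers, each global step only advances once every worker's gradient has been accumulated.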