【學習筆記】分佈式Tensorflow

Tensorflow的一個特點就是分佈式計算。分佈式Tensorflow是由高性能的gRPC框架做爲底層技術來支持的。這是一個通訊框架gRPC(google remote procedure call),是一個高性能、跨平臺的RPC框架。RPC協議,即遠程過程調用協議,是指經過網絡從遠程計算機程序上請求服務。git

分佈式原理

Tensorflow分佈式是由多個服務器進程和客戶端進程組成。有幾種部署方式,例如單機多卡和多機多卡(分佈式)。算法

單機多卡

單機多卡是指單臺服務器有多塊GPU設備。假設一臺機器上有4塊GPU,單機多GPU的訓練過程以下:shell

  • 在單機單GPU的訓練中,數據是一個batch一個batch的訓練。 在單機多GPU中,數據一次處理4個batch(假設是4個GPU訓練), 每一個GPU處理一個batch的數據計算。
  • 變量,或者說參數,保存在CPU上。數據由CPU分發給4個GPU,在GPU上完成計算,獲得每一個批次要更新的梯度
  • 在CPU上收集完4個GPU上要更新的梯度,計算一下平均梯度,而後更新。
  • 循環進行上面步驟

多機多卡(分佈式)

而分佈式是指有多臺計算機,充分使用多臺計算機的性能,處理數據的能力。能夠根據不一樣計算機劃分不一樣的工做節點。當數據量或者計算量達到超過一臺計算機處理能力的上限的話,必須使用分佈式。api

分佈式的架構

當咱們知道的基本的分佈式原理以後,咱們來看看分佈式的架構的組成。分佈式架構的組成能夠說是一個集羣的組成方式。那麼通常咱們在進行Tensorflow分佈式時,須要創建一個集羣。一般是咱們分佈式的做業集合。一個做業中又包含了不少的任務(工做結點),每一個任務由一個工做進程來執行。服務器

節點之間的關係

通常來講,在分佈式機器學習框架中,咱們會把做業分紅參數做業(parameter job)和工做結點做業(worker job)。運行參數做業的服務器咱們稱之爲參數服務器(parameter server,PS),負責管理參數的存儲和更新,工做結點做業負責主要從事計算的任務,如運行操做。網絡

參數服務器,當模型愈來愈大時,模型的參數愈來愈多,多到一臺機器的性能不夠完成對模型參數的更新的時候,就須要把參數分開放到不一樣的機器去存儲和更新。參數服務器能夠是由多臺機器組成的集羣。工做節點是進行模型的計算的。Tensorflow的分佈式實現了做業間的數據傳輸,也就是參數做業到工做結點做業的前向傳播,以及工做節點到參數做業的反向傳播。session

分佈式的模式

在訓練一個模型的過程當中,有哪些部分能夠分開,放在不一樣的機器上運行呢?在這裏就要接觸到數據並行的概念。架構

數據並行

數據並總的原理很簡單。其中CPU主要負責梯度平均和參數更新,而GPU主要負責訓練模型副本。app

  • 模型副本定義在GPU上
  • 對於每個GPU,都是從CPU得到數據,前向傳播進行計算,獲得損失,並計算出梯度
  • CPU接到GPU的梯度,取平均值,而後進行梯度更新

每個設備的計算速度不同,有的快有的滿,那麼CPU在更新變量的時候,是應該等待每個設備的一個batch進行完成,而後求和取平均來更新呢?仍是讓一部分先計算完的就先更新,後計算完的將前面的覆蓋呢?這就由同步更新和異步更新的問題。

同步更新和異步更新

更新參數分爲同步和異步兩種方式,即異步隨機梯度降低法(Async-SGD)和同步隨機梯度降低法(Sync-SGD)

  • 同步隨即梯度降低法的含義是在進行訓練時,每一個節點的工做任務須要讀入共享參數,執行並行的梯度計算,同步須要等待全部工做節點把局部的梯度算好,而後將全部共享參數進行合併、累加,再一次性更新到模型的參數;下一個批次中,全部工做節點拿到模型更新後的參數再進行訓練。這種方案的優點是,每一個訓練批次都考慮了全部工做節點的訓練狀況,損失降低比較穩定;劣勢是,性能瓶頸在於最慢的工做結點上。
  • 異步隨機梯度降低法的含義是每一個工做結點上的任務獨立計算局部梯度,並異步更新到模型的參數中,不須要執行協調和等待操做。這種方案的優點是,性能不存在瓶頸;劣勢是,每一個工做節點計算的梯度值發送回參數服務器會有參數更新的衝突,必定程度上會影響算法的收斂速度,在損失降低的過程當中抖動較大。

分佈式API

建立集羣的方法是爲每個任務啓動一個服務,這些任務能夠分佈在不一樣的機器上,也能夠同一臺機器上啓動多個任務,使用不一樣的GPU等來運行。每一個任務都會建立完成如下工做

  • 一、建立一個tf.train.ClusterSpec,用於對集羣中的全部任務進行描述,該描述內容對全部任務應該是相同的
  • 二、建立一個tf.train.Server,用於建立一個任務,並運行相應做業上的計算任務。

Tensorflow的分佈式API使用以下:

  • tf.train.ClusterSpec()

建立ClusterSpec,表示參與分佈式TensorFlow計算的一組進程

cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222", /job:worker/task:0
                                           "worker1.example.com:2222", /job:worker/task:1
                                           "worker2.example.com:2222"],/job:worker/task:2
                                "ps": ["ps0.example.com:2222",       /job:ps/task:0
                                       "ps1.example.com:2222"]})        /job:ps/task:1

建立Tensorflow的集羣描述信息,其中ps和worker爲做業名稱,經過指定ip地址加端口建立

  • tf.train.Server(server_or_cluster_def, job_name=None, task_index=None, protocol=None, config=None, start=True)
    • server_or_cluster_def: 集羣描述
    • job_name: 任務類型名稱
    • task_index: 任務數

建立一個服務(主節點或者工做節點服務),用於運行相應做業上的計算任務,運行的任務在task_index指定的機器上啓動,例如在不一樣的ip+端口上啓動兩個工做任務

  • 屬性:target
    • 返回tf.Session鏈接到此服務器的目標
  • 方法:join()
    • 參數服務器端等待接受參數任務,直到服務器關閉

tf.device(device_name_or_function):選擇指定設備或者設備函數

  • if device_name:
    • 指定設備
    • 例如:"/job:worker/task:0/cpu:0」
  • if function:
    • tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster)
    • 做用:經過此函數協調不一樣設備上的初始化操做
    • worker_device:爲指定設備, 「/job:worker/task:0/cpu:0」 or "/job:worker/task:0/gpu:0"
    • cluster:集羣描述對象

分佈式案例

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string("job_name", "worker", "啓動服務類型,ps或者worker")
tf.app.flags.DEFINE_integer("task_index", 0, "指定是哪一臺服務器索引")


def main(argv):

    # 集羣描述
    cluster = tf.train.ClusterSpec({
        "ps": ["127.0.0.1:4466"],
        "worker": ["127.0.0.1:4455"]
    })

    # 建立不一樣的服務
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()
    else:
        work_device = "/job:worker/task:0/cpu:0"
        with tf.device(tf.train.replica_device_setter(
            worker_device=work_device,
            cluster=cluster
        )):

            # 全局計數器
            global_step = tf.train.get_or_create_global_step()

            # 準備數據
            mnist = input_data.read_data_sets("./data/mnist/", one_hot=True)

            # 創建數據的佔位符
            with tf.variable_scope("data"):
                x = tf.placeholder(tf.float32, [None, 28 * 28])
                y_true = tf.placeholder(tf.float32, [None, 10])

            # 創建全鏈接層的神經網絡
            with tf.variable_scope("fc_model"):
                # 隨機初始化權重和偏重
                weight = tf.Variable(tf.random_normal([28 * 28, 10], mean=0.0, stddev=1.0), name="w")
                bias = tf.Variable(tf.constant(0.0, shape=[10]))
                # 預測結果
                y_predict = tf.matmul(x, weight) + bias

            # 全部樣本損失值的平均值
            with tf.variable_scope("soft_loss"):
                loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))

            # 梯度降低
            with tf.variable_scope("optimizer"):
                train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

            # 計算準確率
            with tf.variable_scope("acc"):
                equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
                accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

        # 建立分佈式會話
        with tf.train.MonitoredTrainingSession(
            checkpoint_dir="./temp/ckpt/test",
            master="grpc://127.0.0.1:4455",
            is_chief=(FLAGS.task_index == 0),
            config=tf.ConfigProto(log_device_placement=True),
            hooks=[tf.train.StopAtStepHook(last_step=100)]
        ) as mon_sess:
            while not mon_sess.should_stop():
                mnist_x, mnist_y = mnist.train.next_batch(4000)

                mon_sess.run(train_op, feed_dict={x: mnist_x, y_true: mnist_y})

                print("訓練第%d步, 準確率爲%f" % (global_step.eval(session=mon_sess), mon_sess.run(accuracy, feed_dict={x: mnist_x, y_true: mnist_y})))


if __name__ == '__main__':
    tf.app.run()

運行參數服務器:

$ python zfx.py --job_name=ps

運行worker服務器:

$ python zfx.py --job_name=worker
相關文章
相關標籤/搜索