[源碼解析] 機器學習參數服務器ps-lite (1) ----- PostOffice

[源碼解析] 機器學習參數服務器ps-lite 之(1) ----- PostOffice

0x00 摘要

參數服務器是機器學習訓練一種範式,是爲了解決分佈式機器學習問題的一個編程框架,其主要包括服務器端,客戶端和調度器,與其餘範式相比,參數服務器把模型參數存儲和更新提高爲主要組件,而且使用多種方法提升了處理能力。node

本文是參數服務器系列第一篇,介紹ps-lite的整體設計和基礎模塊 Postoffice。
lpython

0x01 概要

1.1 參數服務器是什麼

若是作一個類比,參數服務器是機器學習領域的分佈式內存數據庫,其做用是存儲模型和更新模型c++

咱們來看看機器學習的幾個步驟,這些步驟不斷循環往復。git

  1. 準備數據:訓練進程拿到權重 weight 和數據(data + label);
  2. 前向計算:訓練進程使用數據進行前向計算,獲得 loss = f(weight, data & label);
  3. 反向求導:經過對 loss 反向求導,獲得導數 grad = b(loss, weight, data & label);
  4. 更新權重:weight -= grad * lr;
  5. 來到1,進行下一次迭代;

若是使用參數服務器訓練,咱們能夠把如上步驟對應以下:github

  1. 參數下發:參數服務器服務端 將 weight 發給 每一個worker(或者worker自行拉取),worker就是參數服務器Client端;
  2. 並行計算:每一個worker 分別完成本身的計算(包括前向計算和反向求導);
  3. grad 收集:參數服務器服務端 從每一個 Worker 處獲得 grad,完成歸併(或者worker自行推送);
  4. 更新權重:參數服務器服務端 自行將 grad 應用到 weight 上;
  5. 來到1,進行下一次迭代;

具體以下圖:算法

FP/BP    +--------+  Gather/Sum                                       FP/BP            +-------+    Gather/Sum
      +----------> | grad 1 +------+                                    +----------------------> |grad 2 +-----------+
      |            +--------+      |                                    |                        +-------+           |
+-----+----+                       v                     +--------------+-------------------+                        v
|          |                   +---+----------+  Update  |                                  |                 +------+-----+ Update   +------------------+
| weight 1 |                   | total grad 1 +--------->+weight 2 = weight 1 - total grad 1|                 |total grad 2+--------> |weight 2 = ...... |
|          |                   +---+----------+          |                                  |                 +------+-----+          +------------------+
+-----+----+                       ^                     +--------------+-------------------+                        ^
      |   FP/BP    +--------+      |                                    |       FP/BP            +-------+           |
      +----------> | grad 2 +------+                                    +----------------------> |grad 2 +-----------+
                   +--------+  Gather/Sum                                                        +-------+    Gather/Sum

手機以下:shell

所以咱們能夠推導出參數服務器之中各個模塊的做用:數據庫

  • 服務器端(Server ):存放機器學習模型參數,接收客戶端發送的梯度,完成歸併,對本地模型參數進行更新。
  • 客戶端(Client 或者 Worker):
    • 從服務器端獲取當前最新的參數;
    • 使用訓練數據和從最新參數計算獲得預測值,根據損失函數來計算關於訓練參數的梯度;
    • 將梯度發送給服務器端;
  • 調度器(Scheduler):管理服務器/客戶端節點,完成節點之間數據同步,節點添加/刪除等功能。

1.2 歷史溯源

參數服務器屬於機器學習訓練的一個範式,具體能夠分爲三代(目前各大公司應該有本身內部最新實現,能夠算爲第四代)。編程

在參數服務器以前,大部分分佈式機器學習算法是經過按期同步來實現的,好比集合通訊的all-reduce,或者 map-reduce類系統的reduce步驟。可是按期同步有兩個問題:bash

  • 同步時期只能作同步,不能訓練。
  • straggler問題:因爲一些軟硬件的緣由,節點的計算能力每每不盡相同。對於迭代問題來講,每一輪結束時算得快的節點都需等待算得慢的節點算完,再進行下一輪迭代。這種等待在節點數增多時將變得尤其明顯,從而拖慢總體的性能。

所以,當async sgd出現以後,就有人提出了參數服務器。

參數服務器的概念最先來自於Alex Smola於2010年提出的並行LDA的框架。它經過採用一個分佈式的Memcached做爲存放共享參數的存儲,這樣就提供了有效的機制用於分佈式系統中不一樣的Worker之間同步模型參數,而每一個Worker只須要保存他計算時因此來的一小部分參數便可,也避免了全部進程在一個時間點上都停下來同步。可是獨立的kv對帶來了很大的通訊開銷,並且服務端端難以編程。

第二代由Google的Jeff Dean進一步提出了第一代Google大腦的解決方案:DistBelief。DistBelief將巨大的深度學習模型分佈存儲在全局的參數服務器中,計算節點經過參數服務器進行信息傳遞,很好地解決了SGD和L-BFGS算法的分佈式訓練問題。

再後來就是李沐所在的DMLC組所設計的參數服務器。根據論文中所寫,該parameter server屬於第三代參數服務器,就是提供了更加通用的設計。架構上包括一個Server Group和若干個Worker Group。

1.3 論文架構

咱們首先用沐神論文中的圖來看看系統架構。

解釋一下圖中總體架構中每一個模塊:

  • resource manager:資源分配及管理器。參數服務器使用業界現有的資源管理系統,好比yarn,k8s。
  • training data:幾十上百億的訓練數據通常存儲在分佈式文件系統上(好比HDFS),resource manager會均勻的分配到每一個worker上。
  • 參數服務器的節點被劃分到一個 server group 和多個 worker group。
  • server group:一次訓練任務中申請的servers,用於模型參數的更新和pull應答。
    • server group 中的每一個 server 只負責本身分到的部分全局共享參數(server 共同維持一個全局共享參數),通常優化器在此實現。
    • server 之間相互通訊以便進行參數的備份/遷移。
    • server group 有一個 server manager node,負責維護 server 元數據的一致性,例如節點狀態,參數的分配狀況。通常不會有什麼邏輯,只有當有server node加入或退出的時候,爲了維持一致性哈希而作一些調整。
  • worker group:一次訓練任務中申請的workers,用於前向過程和梯度計算。
    • 每一個 worker group 運行一個計算任務,worker group 中的 每一個worker 使用部分數據進行訓練。
    • 分紅多個group,這樣就能夠支持多任務的並行計算。
    • 每一個 worker group 有一個 task scheduler,負責向 worker 分配任務,並監控他們的運行狀況,當有 worker 進入或者退出時,task scheduler 從新分配未完成的任務。
    • worker 之間沒有通訊,只和對應的 server 通訊進行參數更新。

在分佈式計算梯度時,系統的數據流以下:

圖中每一個步驟的做用爲:

  1. worker 節點 基於該 batch 內的樣本計算模型權重的梯度;
  2. worker將梯度以key-value的形式推送給server;
  3. server按指定的優化器對模型權重進行梯度更新;
  4. worker從server中拉取最新的模型權重;

上面兩個圖的依據是其原始代碼。ps-lite 是後來的精簡版代碼,因此有些功能在 ps-lite 之中沒有提供。

1.4 ps-lite發展歷程

從網上找到了一些 ps-lite發展歷程,能夠看到其演進的思路。

第一代是parameter,針對特定算法(如邏輯迴歸和LDA)進行了設計和優化,以知足規模龐大的工業機器學習任務(數百億個示例和10-100TB數據大小的功能)。

後來嘗試爲機器學習算法構建一個開源通用框架。 該項目位於dmlc / parameter_server。

鑑於其餘項目的需求不斷增加,建立了ps-lite,它提供了一個乾淨的數據通訊API和一個輕量級的實現。 該實現基於dmlc / parameter_server,但爲不一樣的項目重構了做業啓動器,文件IO和機器學習算法代碼,如dmlc-core和wormhole

根據在開發dmlc / mxnet期間學到的經驗,從v1進一步重構了API和實現。 主要變化包括:

  • 庫依賴性較少;
  • 更靈活的用戶定義回調,便於其餘語言綁定;
  • 讓用戶(如mxnet的依賴引擎)管理數據一致性;

1.5 ps-lite 系統整體

ps-lite 實際上是Paramter Server的實現的一個框架,其中參數處理具體相關策略需用戶本身實現

Parameter Server包含三種角色:Worker,Server,Scheduler。具體關係以下圖:

具體角色功能爲:

  • worker(工做節點):若干個,執行data pipeline、前向和梯度計算,以key-value的形式將模型權重梯度push到server節點以及從server節點拉取模型最新權重;
  • server(服務節點):若干個,負責對worker的push和pull請求作response,存儲,維護和更新模型權重以供各個worker使用(每一個server僅維護模型的一部分);
  • scheduler(控制節點):系統內只有一個。負責全部節點的心跳監測、節點id分配和worker&server間的通訊創建,它還可用於將控制信號發送到其餘節點並收集其進度。

其中引入scheduler的好處以下:

  • 引入一個 scheduler 模塊,則會造成一個比較經典的三角色分佈式系統架構;worker 和 server 的角色和職責不變,而 scheduler 模塊則有比較多的選擇:
    • 只承擔和下層資源調度系統般若(相似 yarn、mesos)的交互;
    • 額外增長對 worker、server 心跳監控、流程控制的功能;
  • 引入 scheduler 模塊的另外一個好處是給實現模型並行留出了空間;
  • scheduler 模塊不只有利於實現模型並行訓練範式,還有其餘好處:好比經過針對特定模型參數相關性的理解,對參數訓練過程進行細粒度的調度,能夠進一步加快模型收斂速度,甚至有機會提高模型指標。

熟悉分佈式系統的同窗可能會擔憂 scheduler 模塊的單點問題,這個經過 raft、zab 等 paxos 協議能夠獲得比較好的解決。

1.6 基礎模塊

ps-lite系統中的一些基礎模塊以下:

  • Environment:一個單例模式的環境變量類,它經過一個 std::unordered_map<std::string, std::string> kvs 維護了一組 kvs 藉以保存全部環境變量名以及值;

  • PostOffice:一個單例模式的全局管理類,一個 node 在生命期內具備一個PostOffice,依賴它的類成員對Node進行管理;

  • Van:通訊模塊,負責與其餘節點的網絡通訊和Message的實際收發工做。PostOffice持有一個Van成員;

  • SimpleApp:KVServer和KVWorker的父類,它提供了簡單的Request, Wait, Response,Process功能;KVServer和KVWorker分別根據本身的使命重寫了這些功能;

  • Customer:每一個SimpleApp對象持有一個Customer類的成員,且Customer須要在PostOffice進行註冊,該類主要負責:

    • 跟蹤由SimpleApp發送出去的消息的回覆狀況;
    • 維護一個Node的消息隊列,爲Node接收消息;
  • Node :信息類,存儲了本節點的對應信息,每一個 Node 可使用 hostname + port 來惟一標識。

0x02 系統啓動

2.1 如何啓動

從源碼中的例子能夠看出,使用ps-lite 提供的腳本 local.sh 能夠啓動整個系統,這裏 test_connection 爲編譯好的可執行程序。

./local.sh 2 3 ./test_connection

2.2 啓動腳本

具體 local.sh 代碼以下。注意,在shell腳本中,有三個shift,這就讓腳本中始終使用$1。

針對咱們的例子,腳本參數對應了就是

  • DMLC_NUM_SERVER 爲 2;
  • DMLC_NUM_WORKER 爲 3;
  • bin 是 ./test_connection;

能夠從腳本中看到,本腳本作了兩件事:

  • 每次執行應用程序以前,都會依據本次執行的角色來對環境變量進行各類設定,除了DMLC_ROLE設置得不一樣外,其餘變量在每一個節點上都相同。
  • 在本地運行多個不一樣角色。這樣 ps-lite 就用多個不一樣的進程(程序)共同合做完成工做。
    • 首先啓動Scheduler節點。這是要固定好Server和Worker數量,Scheduler節點管理全部節點的地址。
    • 啓動Worker或Server節點。每一個節點要知道Scheduler節點的IP、port。啓動時鏈接Scheduler節點,綁定本地端口,並向Scheduler節點註冊本身信息(報告本身的IP,port)。
    • Scheduler等待全部Worker節點都註冊後,給其分配id,並把節點信息傳送出去(例如Worker節點要知道Server節點IP和端口,Server節點要知道Worker節點的IP和端口)。此時Scheduler節點已經準備好。
    • Worker或Server接收到Scheduler傳送的信息後,創建對應節點的鏈接。此時Worker或Server已經準備好,會正式啓動。

具體以下:

#!/bin/bash
# set -x
if [ $# -lt 3 ]; then
    echo "usage: $0 num_servers num_workers bin [args..]"
    exit -1;
fi

# 對環境變量進行各類配置,此後不一樣節點都會從這些環境變量中獲取信息
export DMLC_NUM_SERVER=$1
shift
export DMLC_NUM_WORKER=$1
shift
bin=$1
shift
arg="$@"

# start the scheduler
export DMLC_PS_ROOT_URI='127.0.0.1'
export DMLC_PS_ROOT_PORT=8000
export DMLC_ROLE='scheduler'
${bin} ${arg} &


# start servers
export DMLC_ROLE='server'
for ((i=0; i<${DMLC_NUM_SERVER}; ++i)); do
    export HEAPPROFILE=./S${i}
    ${bin} ${arg} &
done

# start workers
export DMLC_ROLE='worker'
for ((i=0; i<${DMLC_NUM_WORKER}; ++i)); do
    export HEAPPROFILE=./W${i}
    ${bin} ${arg} &
done

wait

2.3 示例程序

咱們依然使用官方例子看看。

ps-lite 使用的是 C++語言,其中 worker, server, scheduler 都使用同一套代碼。這會讓習慣於Java,python的同窗很是不適應,你們須要適應一個階段。

針對這個示例程序,起初會讓人疑惑,爲何每次程序運行,代碼中都會啓動 scheduler,worker,server?其實,從下面註釋就能看出來,具體執行是依據環境變量來決定。若是環境變量設置了本次角色是 server,則不會啓動 scheduler 和 worker。

#include <cmath>
#include "ps/ps.h"

using namespace ps;

void StartServer() {
  if (!IsServer()) {
    return;
  }
  auto server = new KVServer<float>(0);
  server->set_request_handle(KVServerDefaultHandle<float>()); //註冊functor
  RegisterExitCallback([server](){ delete server; });
}

void RunWorker() {
  if (!IsWorker()) return;
  KVWorker<float> kv(0, 0);

  // init
  int num = 10000;
  std::vector<Key> keys(num);
  std::vector<float> vals(num);

  int rank = MyRank();
  srand(rank + 7);
  for (int i = 0; i < num; ++i) {
    keys[i] = kMaxKey / num * i + rank;
    vals[i] = (rand() % 1000);
  }

  // push
  int repeat = 50;
  std::vector<int> ts;
  for (int i = 0; i < repeat; ++i) {
    ts.push_back(kv.Push(keys, vals)); //kv.Push()返回的是該請求的timestamp

    // to avoid too frequency push, which leads huge memory usage
    if (i > 10) kv.Wait(ts[ts.size()-10]);
  }
  for (int t : ts) kv.Wait(t);

  // pull
  std::vector<float> rets;
  kv.Wait(kv.Pull(keys, &rets));

  // pushpull
  std::vector<float> outs;
  for (int i = 0; i < repeat; ++i) {
    // PushPull on the same keys should be called serially
    kv.Wait(kv.PushPull(keys, vals, &outs));
  }

  float res = 0;
  float res2 = 0;
  for (int i = 0; i < num; ++i) {
    res += std::fabs(rets[i] - vals[i] * repeat);
    res2 += std::fabs(outs[i] - vals[i] * 2 * repeat);
  }
  CHECK_LT(res / repeat, 1e-5);
  CHECK_LT(res2 / (2 * repeat), 1e-5);
  LL << "error: " << res / repeat << ", " << res2 / (2 * repeat);
}

int main(int argc, char *argv[]) {
  // start system
  Start(0); // Postoffice::start(),每一個node都會調用到這裏,可是在 Start 函數之中,會依據本次設定的角色來不一樣處理,只有角色爲 scheduler 纔會啓動 Scheduler。
  // setup server nodes
  StartServer(); // Server會在其中作有效執行,其餘節點不會有效執行。
  // run worker nodes
  RunWorker(); // Worker 會在其中作有效執行,其餘節點不會有效執行。
  // stop system
  Finalize(0, true); //結束。每一個節點都須要執行這個函數。
  return 0;
}

其中KVServerDefaultHandle是functor,用與處理server收到的來自worker的請求,具體以下:

/**
 * \brief an example handle adding pushed kv into store
 */
template <typename Val>
struct KVServerDefaultHandle { //functor,用與處理server收到的來自worker的請求
    // req_meta 是存儲該請求的一些元信息,好比請求來自於哪一個節點,發送給哪一個節點等等
    // req_data 是發送過來的數據
    // server 是指向當前server對象的指針  
  void operator()(
      const KVMeta& req_meta, const KVPairs<Val>& req_data, KVServer<Val>* server) {
    size_t n = req_data.keys.size();
    KVPairs<Val> res;
    if (!req_meta.pull) { //收到的是pull請求
      CHECK_EQ(n, req_data.vals.size());
    } else { //收到的是push請求
      res.keys = req_data.keys; res.vals.resize(n);
    }
    for (size_t i = 0; i < n; ++i) {
      Key key = req_data.keys[i];
      if (req_meta.push) { //push請求
        store[key] += req_data.vals[i]; //此處的操做是將相同key的value相加
      }
      if (req_meta.pull) {  //pull請求
        res.vals[i] = store[key];
      }
    }
    server->Response(req_meta, res);
  }
  std::unordered_map<Key, Val> store;
};

0x03 Postoffice

Postoffice 是一個單例模式的全局管理類,其維護了系統的一個全局信息,具備以下特色:

  • 三種Node角色都依賴 Postoffice 進行管理,每個 node 在生命期內具備一個單例 PostOffice。
  • 如咱們以前所說,ps-lite的特色是 worker, server, scheduler 都使用同一套代碼,Postoffice也是如此,因此咱們最好分開描述。
  • 在 Scheduler側,顧名思義,Postoffice 是郵局,能夠認爲是一個地址簿,一個調控中心,其記錄了系統(由scheduler,server, worker 集體構成的這個系統)中全部節點的信息。具體功能以下:
    • 維護了一個Van對象,負責整個網絡的拉起、通訊、命令管理如增長節點、移除節點、恢復節點等等;
    • 負責整個集羣基本信息的管理,好比worker、server數的獲取,管理全部節點的地址,server 端 feature分佈的獲取,worker/server Rank與node id的互轉,節點角色身份等等;
    • 負責 Barrier 功能;
  • 在 Server / Worker 端,負責:
    • 配置當前node的一些信息,例如當前node是哪一種類型(server,worker),nodeid是啥,以及worker/server 的rank 到 node id的轉換。
    • 路由功能:負責 key 與 server 的對應關係。
    • Barrier 功能;

請注意:這些代碼都是在 Postoffice 類內,沒有按照角色分開成多個模塊。

3.1 定義

類 UML 圖以下:

下面咱們只給出關鍵變量和成員函數說明,由於每一個節點都包含一個 PostOffice,因此 PostOffice 的數據結構中包括了各類節點所須要的變量,會顯得比較繁雜。

主要變量做用以下:

  • van_ :底層通信對象;
  • customers_ :本節點目前有哪些 customer;
  • node_ids_ :node id 映射表;
  • server_key_ranges_ :Server key 區間範圍對象
  • is_worker_, is_server_, is_scheduler_ :標註了本節點類型;
  • heartbeats_ :節點心跳對象;
  • barrier_done_ : Barrier 同步變量;

主要函數做用以下:

  • InitEnvironment :初始化環境變量,建立 van 對象;
  • Start :創建通訊初始化;
  • Finalize :節點阻塞退出;
  • Manage :退出 barrier 阻塞狀態;
  • Barrier :進入 barrier 阻塞狀態;
  • UpdateHeartbeat :
  • GetDeadNodes :根據 heartbeats_ 獲取已經 dead 的節點;

具體以下:

class Postoffice {
  /**
   * \brief start the system
   *
   * This function will block until every nodes are started.
   * \param argv0 the program name, used for logging.
   * \param do_barrier whether to block until every nodes are started.
   */
  void Start(int customer_id, const char* argv0, const bool do_barrier);
  /**
   * \brief terminate the system
   *
   * All nodes should call this function before existing.
   * \param do_barrier whether to do block until every node is finalized, default true.
   */
  void Finalize(const int customer_id, const bool do_barrier = true);
  /**
   * \brief barrier
   * \param node_id the barrier group id
   */
  void Barrier(int customer_id, int node_group);
  /**
   * \brief process a control message, called by van
   * \param the received message
   */
  void Manage(const Message& recv);
  /**
   * \brief update the heartbeat record map
   * \param node_id the \ref Node id
   * \param t the last received heartbeat time
   */
  void UpdateHeartbeat(int node_id, time_t t) {
    std::lock_guard<std::mutex> lk(heartbeat_mu_);
    heartbeats_[node_id] = t;
  }
  /**
   * \brief get node ids that haven't reported heartbeats for over t seconds
   * \param t timeout in sec
   */
  std::vector<int> GetDeadNodes(int t = 60);  
 private:  
 void InitEnvironment();  
  Van* van_;
  mutable std::mutex mu_;
  // app_id -> (customer_id -> customer pointer)
  std::unordered_map<int, std::unordered_map<int, Customer*>> customers_;
  std::unordered_map<int, std::vector<int>> node_ids_;
  std::mutex server_key_ranges_mu_;
  std::vector<Range> server_key_ranges_;
  bool is_worker_, is_server_, is_scheduler_;
  int num_servers_, num_workers_;
  std::unordered_map<int, std::unordered_map<int, bool> > barrier_done_;
  int verbose_;
  std::mutex barrier_mu_;
  std::condition_variable barrier_cond_;
  std::mutex heartbeat_mu_;
  std::mutex start_mu_;
  int init_stage_ = 0;
  std::unordered_map<int, time_t> heartbeats_;
  Callback exit_callback_;
  /** \brief Holding a shared_ptr to prevent it from being destructed too early */
  std::shared_ptr<Environment> env_ref_;
  time_t start_time_;
  DISALLOW_COPY_AND_ASSIGN(Postoffice);
};

3.2 ID 映射功能

首先咱們介紹下 node id 映射功能,就是如何在邏輯節點和物理節點之間作映射,如何把物理節點劃分紅各個邏輯組,如何用簡便的方法作到給組內物理節點統一發消息

  • 1,2,4分別標識Scheduler, ServerGroup, WorkerGroup。
  • SingleWorker:rank * 2 + 9;SingleServer:rank * 2 + 8。
  • 任意一組節點均可以用單個id標識,等於全部id之和。

3.2.1 概念

  • Rank 是一個邏輯概念,是每個節點(scheduler,work,server)內部的惟一邏輯標示
  • Node id 是物理節點的惟一標識,能夠和一個 host + port 的二元組惟一對應
  • Node Group 是一個邏輯概念,每個 group 能夠包含多個 node id。ps-lite 一共有三組 group : scheduler 組,server 組,worker 組。
  • Node group id 是 是節點組的惟一標示。
    • ps-lite 使用 1,2,4 這三個數字分別標識 Scheduler,ServerGroup,WorkerGroup。每個數字都表明着一組節點,等於全部該類型節點 id 之和。好比 2 就表明server 組,就是全部 server node 的組合。
    • 爲何選擇這三個數字?由於在二進制下這三個數值分別是 "001, 010, 100",這樣若是想給多個 group 發消息,直接把 幾個 node group id 作 或操做 就行。
    • 即 1-7 內任意一個數字都表明的是Scheduler / ServerGroup / WorkerGroup的某一種組合。
      • 若是想把某一個請求發送給全部的 worker node,把請求目標節點 id 設置爲 4 便可。
      • 假設某一個 worker 但願向全部的 server 節點 和 scheduler 節點同時發送請求,則只要把請求目標節點的 id 設置爲 3 便可,由於 3 = 2 + 1 = kServerGroup + kScheduler。
      • 若是想給全部節點發送消息,則設置爲 7 便可。

3.2.2 邏輯組的實現

三個邏輯組的定義以下:

/** \brief node ID for the scheduler */
static const int kScheduler = 1;
/**
 * \brief the server node group ID
 *
 * group id can be combined:
 * - kServerGroup + kScheduler means all server nodes and the scheuduler
 * - kServerGroup + kWorkerGroup means all server and worker nodes
 */
static const int kServerGroup = 2;
/** \brief the worker node group ID */
static const int kWorkerGroup = 4;

3.2.3 Rank vs node id

node id 是物理節點的惟一標示,rank 是每個邏輯概念(scheduler,work,server)內部的惟一標示。這兩個標示由一個算法來肯定。

以下面代碼所示,若是配置了 3 個worker,則 worker 的 rank 從 0 ~ 2,那麼這幾個 worker 實際對應的 物理 node ID 就會使用 WorkerRankToID 來計算出來。

for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }

具體計算規則以下:

/**
   * \brief convert from a worker rank into a node id
   * \param rank the worker rank
   */
  static inline int WorkerRankToID(int rank) {
    return rank * 2 + 9;
  }
  /**
   * \brief convert from a server rank into a node id
   * \param rank the server rank
   */
  static inline int ServerRankToID(int rank) {
    return rank * 2 + 8;
  }
  /**
   * \brief convert from a node id into a server or worker rank
   * \param id the node id
   */
  static inline int IDtoRank(int id) {
#ifdef _MSC_VER
#undef max
#endif
    return std::max((id - 8) / 2, 0);
  }

這樣咱們能夠知道,1-7 的id表示的是node group,單個節點的id 就從 8 開始。

並且這個算法保證server id爲偶數,node id爲奇數。

  • SingleWorker:rank * 2 + 9;
  • SingleServer:rank * 2 + 8;

3.2.4 Group vs node

由於有時請求要發送給多個節點,因此ps-lite用了一個 map 來存儲每一個 node group / single node 對應的實際的node節點集合,即 肯定每一個id值對應的節點id集。

std::unordered_map<int, std::vector<int>> node_ids_

如何使用這個node_ids_?咱們仍是須要看以前的代碼:

for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }

咱們回憶一下以前的節點信息:

  • 1 ~ 7 的 id 表示的是 node group;
  • 後續的 id(8,9,10,11 ...)表示單個的 node。其中雙數 8,10,12... 表示 worker 0, worker 1, worker 2,... 即(2n + 8),9,11,13,...,表示 server 0, server 1,server 2,...,即(2n + 9);

因此,爲了實現 「設置 1-7 內任意一個數字 能夠發送給其對應的 全部node」 這個功能,對於每個新節點,須要將其對應多個id(node,node group)上,這些id組就是本節點能夠與之通信的節點。例如對於 worker 2 來講,其 node id 是 2 * 2 + 8 = 12,因此須要將它與

  • 12(自己)
  • 4(kWorkerGroup)li
  • 4+1(kWorkerGroup + kScheduler)
  • 4+2(kWorkerGroup + kServerGroup)
  • 4+1+2,(kWorkerGroup + kServerGroup + kScheduler )

這 5 個id 相對應,即須要在 node_ids_ 這個映射表中對應的 4, 4 + 1, 4 + 2, 4 +1 + 2, 12 這五個 item 之中添加。就是上面代碼中的內部 for 循環條件。即,node_ids_ [4], node_ids_ [5],node_ids_ [6],node_ids_ [7] ,node_ids_ [12] 之中,都須要把 12 添加到 vector 最後。

3.3 參數表示

workers 跟 servers 之間經過 pushpull 來通訊。worker 經過 push 將計算好的梯度發送到server,而後經過 pull 從server更新參數。

3.3.1 KV格式

parameter server 中,參數都是能夠被表示成(key, value)的集合,好比一個最小化損失函數的問題,key就是feature ID,而value就是它的權值。對於稀疏參數來講,value不存在的key,就能夠認爲value是0。

把參數表示成 k-v, 形式更天然,易於理解和編程實現。

3.3.2 key-values

分佈式算法有兩個額外成本:數據通訊成本,負載均衡不理想和機器性能差別致使的同步成本。

對於高維機器學習訓練來講,由於高頻特徵更新極爲頻繁,所會致使網絡壓力極大。若是每個參數都設一個key而且按key更新,那麼會使得通訊變得更加頻繁低效,爲了抹平這個問題,就須要有折衷和平衡,即,
利用機器學習算法的特性,給每一個key對應的value賦予一個向量或者矩陣,這樣就能夠一次性傳遞多個參數,權衡了融合與同步的成本。

作這樣的操做的前提是假設參數是有順序的。缺點是在對於稀疏模型來講,總會在向量或者矩陣裏會有參數爲0,這在單個參數狀態下是不用存的,因此,形成了數據的冗餘。

但這樣作有兩點好處:

  • 下降網絡通訊
  • 使得向量層面的操做變得可行,從而不少線性庫的優化特性能夠利用的上,好比BLAS、LAPACK、ATLAS等。

3.3.3 Range 操做

爲了提升計算性能和帶寬效率,參數服務器也會採用批次更新的辦法,來減輕高頻 key 的壓力。好比把minibatch之中高頻key合併成一個minibatch進行更新。

ps-lite 容許用戶使用 Range PushRange Pull 操做。

3.4 路由功能(keyslice)

路由功能指的就是:Worker 在作 Push/Pull 時候,如何知道把消息發送給哪些 Servers

咱們知道,ps-lite 是多 Server 架構,一個很重要的問題是如何分佈多個參數。好比給定一個參數的鍵,如何肯定其存儲在哪一臺 Server 上。因此必然有一個路由邏輯用來確立 key與server的對應關係。

PS Lite 將路由邏輯放置在 Worker 端,採用範圍劃分的策略,即每個 Server 有本身固定負責的鍵的範圍。這個範圍是在 Worker 啓動的時候肯定的。細節以下:

  • 根據編譯 PS Lite 時是否設定的宏 USE_KEY32 來決定參數的鍵的數據類型,要麼是 32 位無符號整數,要麼是 64 位的。
  • 根據鍵的數據類型,肯定其值域的上界。例如 uint32_t 的上界是 4294967295。
  • 根據鍵域的上界和啓動時獲取的 Server 數量(即環境變量 DMLC_NUM_SERVER 的值)來劃分範圍。
  • 每一個server維護的key範圍按 uint32_t / uint64_t 從小到大等距分區間。給定上界 MAX 和 Server 數量 N,第 i 個 Server 負責的範圍是 [MAX/N*i, MAX/N*(i+1))
  • 對key的hash值構造有必定的要求以免server間的key傾斜(如32位、16位、8位、4位、2位高低位對調)。
  • Worker push和pull的key按升序排列進行slice以實現zero copy。

須要注意的是,在不能恰好整除的狀況下,鍵域上界的一小段被丟棄了。

具體實現以下:

首先,ps-lite的key只支持int類型。

#if USE_KEY32
/*! \brief Use unsigned 32-bit int as the key type */
using Key = uint32_t;
#else
/*! \brief Use unsigned 64-bit int as the key type */
using Key = uint64_t;
#endif
/*! \brief The maximal allowed key value */
static const Key kMaxKey = std::numeric_limits<Key>::max();

其次,將int範圍均分便可

const std::vector<Range>& Postoffice::GetServerKeyRanges() {
  if (server_key_ranges_.empty()) {
    for (int i = 0; i < num_servers_; ++i) {
      server_key_ranges_.push_back(Range(
          kMaxKey / num_servers_ * i,
          kMaxKey / num_servers_ * (i+1)));
    }
  }
  return server_key_ranges_;
}

3.5 初始化環境

從以前分析中咱們能夠知道,ps-lite 是經過環境變量來控制具體節點。

具體某個節點屬於哪種取決於啓動節點以前設置了哪些環境變量以及其數值。

環境變量包括:節點角色,worker&server個數、ip、port等。

InitEnvironment 函數就是建立了 Van,獲得了 worker 和 server 的數量,獲得了本節點的類型。

void Postoffice::InitEnvironment() {
  const char* val = NULL;
  std::string van_type = GetEnv("DMLC_PS_VAN_TYPE", "zmq");
  van_ = Van::Create(van_type);
  val = CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_WORKER"));
  num_workers_ = atoi(val);
  val =  CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_SERVER"));
  num_servers_ = atoi(val);
  val = CHECK_NOTNULL(Environment::Get()->find("DMLC_ROLE"));
  std::string role(val);
  is_worker_ = role == "worker";
  is_server_ = role == "server";
  is_scheduler_ = role == "scheduler";
  verbose_ = GetEnv("PS_VERBOSE", 0);
}

3.6 啓動

主要就是:

  • 調用 InitEnvironment() 來初始化環境,建立 VAN 對象;
  • node_ids_初始化。根據worker和server節點個數,肯定每一個id值對應的節點id集。具體邏輯咱們前面有分析。
  • 啓動 van,這裏會進行各類交互(有一個 ADD_NODE 同步等待,與後面的 barrier 等待不一樣);
  • 若是是第一次調用PostOffice::Start,初始化start_time_成員;
  • 若是設置了須要 barrier,則調用 Barrier 來進行 等待/處理 最終系通通一啓動。即 全部Node準備並向Scheduler發送要求同步的Message,進行第一次同步;

具體代碼以下:

void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) {
  start_mu_.lock();
  if (init_stage_ == 0) {
    InitEnvironment();

    // init node info.
    // 對於全部的worker,進行node設置
    for (int i = 0; i < num_workers_; ++i) {
      int id = WorkerRankToID(i);
      for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
                    kWorkerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }
		// 對於全部的server,進行node設置
    for (int i = 0; i < num_servers_; ++i) {
      int id = ServerRankToID(i);
      for (int g : {id, kServerGroup, kWorkerGroup + kServerGroup,
                    kServerGroup + kScheduler,
                    kWorkerGroup + kServerGroup + kScheduler}) {
        node_ids_[g].push_back(id);
      }
    }
		// 設置scheduler的node
    for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
                  kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
      node_ids_[g].push_back(kScheduler);
    }
    init_stage_++;
  }
  start_mu_.unlock();

  // start van
  van_->Start(customer_id);

  start_mu_.lock();
  if (init_stage_ == 1) {
    // record start time
    start_time_ = time(NULL);
    init_stage_++;
  }
  start_mu_.unlock();
  // do a barrier here
  if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
}

3.7 Barrier

3.7.1 同步

總的來說,schedular節點經過計數的方式實現各個節點的同步。具體來講就是:

  • 每一個節點在本身指定的命令運行完後會向schedular節點發送一個Control::BARRIER命令的請求並本身阻塞直到收到schedular對應的返回後才解除阻塞;
  • schedular節點收到請求後則會在本地計數,看收到的請求數是否和barrier_group的數量是否相等,相等則表示每一個機器都運行完指定的命令了,此時schedular節點會向barrier_group的每一個機器發送一個返回的信息,並解除其阻塞。

3.7.2 初始化

ps-lite 使用 Barrier 來控制系統的初始化,就是你們都準備好了再一塊兒前進。這是一個可選項。具體以下:

  • Scheduler等待全部的worker和server發送BARRIER信息;
  • 在完成ADD_NODE後,各個節點會進入指定 group 的Barrier阻塞同步機制(發送 BARRIER 給 Scheduler),以保證上述過程每一個節點都已經完成;
  • 全部節點(worker和server,包括scheduler) 等待scheduler收到全部節點 BARRIER 信息後的應答;
  • 最終全部節點收到scheduler 應答的Barrier message後退出阻塞狀態;
3.7.2.1 等待 BARRIER 消息

Node會調用 Barrier 函數 告知Scheduler,隨即本身進入等待狀態。

注意,調用時候是

if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);

這就是說,等待全部的 group,即 scheduler 節點也要給本身發送消息。

void Postoffice::Barrier(int customer_id, int node_group) {
  if (GetNodeIDs(node_group).size() <= 1) return;
  auto role = van_->my_node().role;
  if (role == Node::SCHEDULER) {
    CHECK(node_group & kScheduler);
  } else if (role == Node::WORKER) {
    CHECK(node_group & kWorkerGroup);
  } else if (role == Node::SERVER) {
    CHECK(node_group & kServerGroup);
  }

  std::unique_lock<std::mutex> ulk(barrier_mu_);
  barrier_done_[0][customer_id] = false;
  Message req;
  req.meta.recver = kScheduler;
  req.meta.request = true;
  req.meta.control.cmd = Control::BARRIER;
  req.meta.app_id = 0;
  req.meta.customer_id = customer_id;
  req.meta.control.barrier_group = node_group; // 記錄了等待哪些
  req.meta.timestamp = van_->GetTimestamp();
  van_->Send(req); // 給 scheduler 發給 BARRIER
  barrier_cond_.wait(ulk, [this, customer_id] { // 而後等待
      return barrier_done_[0][customer_id];
    });
}
3.7.2.2 處理 BARRIER 消息

處理等待的動做在 Van 類之中,咱們提早放出來。

具體ProcessBarrierCommand邏輯以下:

  • 若是 msg->meta.request 爲true,說明是 scheduler 收到消息進行處理。
    • Scheduler會對Barrier請求進行增長計數。
    • 當 Scheduler 收到最後一個請求時(計數等於此group節點總數),則將計數清零,發送結束Barrier的命令。這時候 meta.request 設置爲 false;
    • 向此group全部節點發送request==falseBARRIER消息。
  • 若是 msg->meta.request 爲 false,說明是收到消息這個 respones,能夠解除barrier了,因而進行處理,調用 Manage 函數 。
    • Manage 函數 將app_id對應的全部costomer的barrier_done_置爲true,而後通知全部等待條件變量barrier_cond_.notify_all()
void Van::ProcessBarrierCommand(Message* msg) {
  auto& ctrl = msg->meta.control;
  if (msg->meta.request) {  // scheduler收到了消息,由於 Postoffice::Barrier函數 會在發送時候作設置爲true。
    if (barrier_count_.empty()) {
      barrier_count_.resize(8, 0);
    }
    int group = ctrl.barrier_group;
    ++barrier_count_[group]; // Scheduler會對Barrier請求進行計數
    if (barrier_count_[group] ==
        static_cast<int>(Postoffice::Get()->GetNodeIDs(group).size())) { // 若是相等,說明已經收到了最後一個請求,因此發送解除 barrier 消息。
      barrier_count_[group] = 0;
      Message res;
      res.meta.request = false; // 回覆時候,這裏就是false
      res.meta.app_id = msg->meta.app_id;
      res.meta.customer_id = msg->meta.customer_id;
      res.meta.control.cmd = Control::BARRIER;
      for (int r : Postoffice::Get()->GetNodeIDs(group)) {
        int recver_id = r;
        if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {
          res.meta.recver = recver_id;
          res.meta.timestamp = timestamp_++;
          Send(res);
        }
      }
    }
  } else { // 說明這裏收到了 barrier respones,能夠解除 barrier了。具體見上面的設置爲false處。
    Postoffice::Get()->Manage(*msg);
  }
}

Manage 函數就是解除了 barrier。

void Postoffice::Manage(const Message& recv) {
  CHECK(!recv.meta.control.empty());
  const auto& ctrl = recv.meta.control;
  if (ctrl.cmd == Control::BARRIER && !recv.meta.request) {
    barrier_mu_.lock();
    auto size = barrier_done_[recv.meta.app_id].size();
    for (size_t customer_id = 0; customer_id < size; customer_id++) {
      barrier_done_[recv.meta.app_id][customer_id] = true;
    }
    barrier_mu_.unlock();
    barrier_cond_.notify_all(); // 這裏解除了barrier
  }
}

具體示意以下:

+
    Scheduler                                       |                  Worker
        +                                           |                     +
        |                                           |                     |
        |                                           |                     |
        +--------------------------------+          |                     +-----------------+
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        |                                v          |                     |                 v
        |                         receiver_thread_  |                     |           receiver_thread_
        |                                +          |                     |                 |
        |                                |          |                     |                 |
        v              BARRIER           |          |   BARRIER           v                 |
Postoffice::Barrier +----------------->  | <---------------------+ Postoffice::Barrier      |
        +                                |          |                     +                 |
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        |                                v          |                     |                 |
        v                                           |                     v                 |
 barrier_cond_.wait          ProcessBarrierCommand  |               barrier_cond_.wait      |
        |                                +          |                     |                 |
        |                                |          |                     |                 |
        |                  All Nodes OK  |          |                     |                 |
        |                                |          |                     |                 |
        |                 +--------------+          |   BARRIER           |                 |
        |                 |              +---------------------------------------------->   |
        |                 |  BARRIER     |          |                     |                 |
        |                 +------------> |          |                     |                 |
        |                                |          |                     |                 |
        |                                |          |                     |                 |
        +<-------------------------------<          |                     | <---------------+
        |          barrier_cond_.notify_all         |                     |    barrier_cond_.notify_all
        v                                           |                     v
                                                    +

手機以下:

至此,Postoffice的分析咱們初步完成,其他功能咱們將會結合 Van 和 Customer 在後續文章中分析。

0xEE 我的信息

★★★★★★關於生活和技術的思考★★★★★★

微信公衆帳號:羅西的思考

若是您想及時獲得我的撰寫文章的消息推送,或者想看看我的推薦的技術資料,敬請關注。

在這裏插入圖片描述

0xFF 參考

MXNet設計和實現簡介

史上最全面的ps-lite理解

ps-lite 深度源碼解讀

ps-lite源碼剖析

基於Parameter Server的可擴展分佈式機器學習架構

ps-lite代碼解析

ps-lite代碼筆記

分佈式TensorFlow入門教程

分佈式機器學習(上)-並行計算與機器學習

分佈式機器學習(中)-並行計算與機器學習

分佈式機器學習(下)-聯邦學習

ps-lite 源代碼分析

Talk - Scaling Distributed Machine Learning with System and Algorithm Co-design 筆記

相關文章
相關標籤/搜索