本文是參數服務器第三篇,介紹ps-lite的Customer模塊。html
目前有了郵局 (PostOffice)和通訊模塊小推車(Van),接下來就要看看郵局的客戶Customer。node
Customer 就是 SimpleApp 在郵局的代理人。由於 worker,server 須要集中精力在算法上,因此把 worker,server 邏輯上與網絡相關的收發消息功能 都總結/轉移到 Customer 之中。python
本系列其餘文章是:c++
[源碼解析] 機器學習參數服務器ps-lite 之(1) ----- PostOfficegit
[源碼解析] 機器學習參數服務器ps-lite(2) ----- 通訊模塊Vangithub
咱們總結一下目前的整體狀態:算法
瞭解一個類的上下文環境可讓咱們更好的理解這個類,因此咱們首先須要看看 Customer 在哪裏使用到,咱們目前已經分析了兩個類,咱們就看看這兩個類中如何使用Customer。shell
在 PostOffice 之中,有以下成員變量:數組
// app_id -> (customer_id -> customer pointer) std::unordered_map<int, std::unordered_map<int, Customer*>> customers_;
以及以下成員函數,就是把Customer註冊到customers_:安全
void Postoffice::AddCustomer(Customer* customer) { std::lock_guard<std::mutex> lk(mu_); int app_id = CHECK_NOTNULL(customer)->app_id(); // check if the customer id has existed int customer_id = CHECK_NOTNULL(customer)->customer_id(); customers_[app_id].insert(std::make_pair(customer_id, customer)); std::unique_lock<std::mutex> ulk(barrier_mu_); barrier_done_[app_id].insert(std::make_pair(customer_id, false)); } Customer* Postoffice::GetCustomer(int app_id, int customer_id, int timeout) const { Customer* obj = nullptr; for (int i = 0; i < timeout * 1000 + 1; ++i) { { std::lock_guard<std::mutex> lk(mu_); const auto it = customers_.find(app_id); if (it != customers_.end()) { std::unordered_map<int, Customer*> customers_in_app = it->second; obj = customers_in_app[customer_id]; break; } } std::this_thread::sleep_for(std::chrono::milliseconds(1)); } return obj; }
所以,咱們能夠看出來幾點:
在 Van 中,咱們能夠看到,當處理數據消息時候,會:
void Van::ProcessDataMsg(Message* msg) { // data msg int app_id = msg->meta.app_id; int customer_id = Postoffice::Get()->is_worker() ? msg->meta.customer_id : app_id; auto* obj = Postoffice::Get()->GetCustomer(app_id, customer_id, 5); obj->Accept(*msg); }
所以咱們知道:
在 Customer 之中咱們能夠看到,Accept 的做用就是往 Customer 的 queue 之中插入消息。
ThreadsafePQueue recv_queue_; inline void Accept(const Message& recved) { recv_queue_.Push(recved); }
Customer對象自己也會啓動一個接受線程 recv_thread_
,使用 Customer::Receiving(),其中調用註冊的recv_handle_
函數對消息進行處理。
std::unique_ptr<std::thread> recv_thread_; recv_thread_ = std::unique_ptr<std::thread>(new std::thread(&Customer::Receiving, this)); void Customer::Receiving() { while (true) { Message recv; recv_queue_.WaitAndPop(&recv); if (!recv.meta.control.empty() && recv.meta.control.cmd == Control::TERMINATE) { break; } recv_handle_(recv); if (!recv.meta.request) { std::lock_guard<std::mutex> lk(tracker_mu_); tracker_[recv.meta.timestamp].second++; tracker_cond_.notify_all(); } } }
所以咱們能夠得出目前邏輯(接受消息邏輯)以下:
Postoffice::start()
。Postoffice::start()
會初始化節點信息,而且調用Van::start()
。Van::start()
啓動一個本地線程,使用Van::Receiving()
來持續監聽收到的message。Van::Receiving()
接收後消息以後,根據不一樣命令執行不一樣動做。針對數據消息,若是須要下一步處理,會調用 ProcessDataMsg
:
Customer
。Customer::Accept
函數。Customer::Accept()
函數將消息添加到一個隊列recv_queue_
;Customer
對象自己也會啓動一個接受線程 recv_thread_
,使用 Customer::Receiving()
recv_queue_
隊列取消息。recv_handle_
函數對消息進行處理。簡要版邏輯以下,數據流按照圖上數字順序進行,咱們也能夠看到, Van,Postoffice,Customer 這三個類彼此之間有些過耦合,可能作一下梳理會更好:
+--------------------------+ | Van | | | DataMessage +-----------> Receiving | | 1 + | +---------------------------+ | | | | Postoffice | | | 2 | | | | v | GetCustomer | | | ProcessDataMsg <------------------> unordered_map customers_| | + | 3 | | | | | +---------------------------+ +--------------------------+ | | | 4 | +-------------------------+ | Customer | | | | | | v | | Accept | | + | | | | | | 5 | | v | | recv_queue_ | | + | | | 6 | | | | | v | | Receiving | | + | | | 7 | | | | | v | | recv_handle_ | | | +-------------------------+
下面咱們就詳細剖析下具體邏輯。
咱們首先要介紹一些基礎類。
SArray 有以下特色:
在ps-lite中,每一個server 擁有一段連續的key,以及這些key對應的value。key和value是分開存儲的,每一個key可能對應多個value,所以須要記錄每一個key的長度,因此就有了 KVPairs。
KVPairs 特色以下:
舉例而言:
定義以下:
struct KVPairs { // /** \brief empty constructor */ // KVPairs() {} /** \brief the list of keys */ SArray<Key> keys; /** \brief the according values */ SArray<Val> vals; /** \brief the according value lengths (could be empty) */ SArray<int> lens; // key對應value的長度vector /** \brief priority */ int priority = 0; };
Node封裝了節點信息,例如角色,ip,端口,是不是恢復節點。
struct Node { /** \brief the empty value */ static const int kEmpty; /** \brief default constructor */ Node() : id(kEmpty), port(kEmpty), is_recovery(false) {} /** \brief node roles */ enum Role { SERVER, WORKER, SCHEDULER }; /** \brief the role of this node */ Role role; /** \brief node id */ int id; /** \brief customer id */ int customer_id; /** \brief hostname or ip */ std::string hostname; /** \brief the port this node is binding */ int port; /** \brief whether this node is created by failover */ bool is_recovery; };
Control :封裝了控制消息的meta信息,barrier_group(用於標識哪些節點須要同步,當command=BARRIER時使用),node(Node類,用於標識控制命令對哪些節點使用)等,方法簽名。
能夠看到,Control 就包含了上面介紹的 Node 類型。
struct Control { /** \brief empty constructor */ Control() : cmd(EMPTY) { } /** \brief return true is empty */ inline bool empty() const { return cmd == EMPTY; } /** \brief all commands */ enum Command { EMPTY, TERMINATE, ADD_NODE, BARRIER, ACK, HEARTBEAT }; /** \brief the command */ Command cmd; /** \brief node infos */ std::vector<Node> node; /** \brief the node group for a barrier, such as kWorkerGroup */ int barrier_group; /** message signature */ uint64_t msg_sig; };
Meta :是消息的元數據部分,包括時間戳,發送者id,接受者id,控制信息Control,消息類型等;
struct Meta { /** \brief the empty value */ static const int kEmpty; /** \brief default constructor */ Meta() : head(kEmpty), app_id(kEmpty), customer_id(kEmpty), timestamp(kEmpty), sender(kEmpty), recver(kEmpty), request(false), push(false), pull(false), simple_app(false) {} /** \brief an int head */ int head; /** \brief the unique id of the application of messsage is for*/ int app_id; /** \brief customer id*/ int customer_id; /** \brief the timestamp of this message */ int timestamp; /** \brief the node id of the sender of this message */ int sender; /** \brief the node id of the receiver of this message */ int recver; /** \brief whether or not this is a request message*/ bool request; /** \brief whether or not a push message */ bool push; /** \brief whether or not a pull message */ bool pull; /** \brief whether or not it's for SimpleApp */ bool simple_app; /** \brief an string body */ std::string body; /** \brief data type of message.data[i] */ std::vector<DataType> data_type; /** \brief system control message */ Control control; /** \brief the byte size */ int data_size = 0; /** \brief message priority */ int priority = 0; };
Message 是要發送的信息,具體以下:
消息頭 meta:就是元數據(使用了Protobuf 進行數據壓縮),包括:
消息體 body:就是發送的數據,使用了自定義的 SArray 共享數據,減小數據拷貝;
幾個類之間的邏輯關係以下:
Message中的某些功能須要依賴Meta來完成,以此類推。
message 包括以下類型:
具體定義以下:
struct Message { /** \brief the meta info of this message */ Meta meta; /** \brief the large chunk of data of this message */ std::vector<SArray<char> > data; /** * \brief push array into data, and add the data type */ template <typename V> void AddData(const SArray<V>& val) { CHECK_EQ(data.size(), meta.data_type.size()); meta.data_type.push_back(GetDataType<V>()); SArray<char> bytes(val); meta.data_size += bytes.size(); data.push_back(bytes); } };
每次發送消息時,消息就按這個格式封裝好,負責發送消息的類成員(Customer類)就會按照Meta之中的信息將消息送貨上門。
Customer 其實有兩個功能:
具體特色以下:
每一個SimpleApp對象持有一個Customer類的成員,且Customer須要在PostOffice進行註冊。
由於 Customer 同時又要處理Message 可是其自己並無接管網絡,所以實際的Response和Message須要外部調用者告訴它,因此功能和職責上有點分裂。
每個鏈接對應一個Customer實例,每一個Customer都與某個node id相綁定,表明當前節點發送到對應node id節點。鏈接對方的id和Customer實例的id相同。
新建一次request,會返回一個timestamp,這個timestamp會做爲此次request的id,每次請求會自增1,相應的res也會自增1,調用wait時會保證 後續好比作Wait以此爲ID識別。
咱們首先看看Customer的成員變量。
須要注意,這裏對於變量功能的理解,咱們能夠從消息流程來看,即若是有一個接受消息,則這個流程數據流以下,因此咱們把 Customer 的成員變量也按照這個順序梳理 :
Van::ProcessDataMsg ---> Customer::Accept ---> Customer::recv_queue_ ---> Customer::recv_thread_ ---> Customer::recv_handle_
主要成員變量以下:
ThreadsafePQueue recv_queue_ :線程安全的消息隊列;
std::unique_ptr< std::thread> recv_thread_ : 不斷從 recv_queue 讀取message並調用 recv_handle_;
RecvHandle recv_handle_ :worker 或者 server 的消息處理函數。
std::vector<std::pair<int, int>> tracker_ :request & response 的同步變量。
具體定義以下:
class Customer { public: /** * \brief the handle for a received message * \param recved the received message */ using RecvHandle = std::function<void(const Message& recved)>; /** * \brief constructor * \param app_id the globally unique id indicating the application the postoffice * serving for * \param customer_id the locally unique id indicating the customer of a postoffice * \param recv_handle the functino for processing a received message */ Customer(int app_id, int customer_id, const RecvHandle& recv_handle); /** * \brief desconstructor */ ~Customer(); /** * \brief return the globally unique application id */ inline int app_id() { return app_id_; } /** * \brief return the locally unique customer id */ inline int customer_id() { return customer_id_; } /** * \brief get a timestamp for a new request. threadsafe * \param recver the receive node id of this request * \return the timestamp of this request */ int NewRequest(int recver); /** * \brief wait until the request is finished. threadsafe * \param timestamp the timestamp of the request */ void WaitRequest(int timestamp); /** * \brief return the number of responses received for the request. threadsafe * \param timestamp the timestamp of the request */ int NumResponse(int timestamp); /** * \brief add a number of responses to timestamp */ void AddResponse(int timestamp, int num = 1); /** * \brief accept a received message from \ref Van. threadsafe * \param recved the received the message */ inline void Accept(const Message& recved) { recv_queue_.Push(recved); } private: /** * \brief the thread function */ void Receiving(); int app_id_; int customer_id_; RecvHandle recv_handle_; ThreadsafePQueue recv_queue_; std::unique_ptr<std::thread> recv_thread_; std::mutex tracker_mu_; std::condition_variable tracker_cond_; std::vector<std::pair<int, int>> tracker_; DISALLOW_COPY_AND_ASSIGN(Customer); };
在構建函數中,會創建接受線程。
recv_thread_ = std::unique_ptr<std::thread>(new std::thread(&Customer::Receiving, this));
線程處理函數以下,具體邏輯就是:
void Customer::Receiving() { while (true) { Message recv; recv_queue_.WaitAndPop(&recv); if (!recv.meta.control.empty() && recv.meta.control.cmd == Control::TERMINATE) { break; } recv_handle_(recv); if (!recv.meta.request) { std::lock_guard<std::mutex> lk(tracker_mu_); tracker_[recv.meta.timestamp].second++; tracker_cond_.notify_all(); } } }
由於是使用 recv_handle_ 來進行具體的業務邏輯,因此咱們下面就看看 recv_handle_ 如何設置,其實也就是 Customer 如何構建,使用。
咱們須要提早使用下文將要分析的一些類,由於他們是 Customer 的使用者,耦合的太緊密了。
首先咱們看看SimpleApp,這是具體邏輯功能節點的基類。
每一個SimpleApp對象持有一個Customer類的成員,且Customer須要在PostOffice進行註冊,
這裏就是 新建一個Custom對象初始化obj_成員。
inline SimpleApp::SimpleApp(int app_id, int customer_id) : SimpleApp() { using namespace std::placeholders; obj_ = new Customer(app_id, customer_id, std::bind(&SimpleApp::Process, this, _1)); }
咱們再看看SimpleApp的兩個子類。
KVServer類主要用來保存key-values數據,進行一些業務操做,好比梯度更新。主要方法爲:Process() 和Response()。
在其構造函數中會:
Customer:: recv_handle_
;構造函數以下:
/** * \brief constructor * \param app_id the app id, should match with \ref KVWorker's id */ explicit KVServer(int app_id) : SimpleApp() { using namespace std::placeholders; obj_ = new Customer(app_id, app_id, std::bind(&KVServer<Val>::Process, this, _1)); }
KVWorker類 主要用來向Server Push/Pull 本身的 key-value 數據。包括以下方法: Push(),Pull(),Wait()。
在其構造函數中會:
/** * \brief constructor * * \param app_id the app id, should match with \ref KVServer's id * \param customer_id the customer id which is unique locally */ explicit KVWorker(int app_id, int customer_id) : SimpleApp() { using namespace std::placeholders; slicer_ = std::bind(&KVWorker<Val>::DefaultSlicer, this, _1, _2, _3); obj_ = new Customer(app_id, customer_id, std::bind(&KVWorker<Val>::Process, this, _1)); }
構建函數邏輯以下:
app_id_, custom_id_ , recv_handle
成員具體構建函數以下:
Customer::Customer(int app_id, int customer_id, const Customer::RecvHandle& recv_handle) : app_id_(app_id), customer_id_(customer_id), recv_handle_(recv_handle) { Postoffice::Get()->AddCustomer(this); recv_thread_ = std::unique_ptr<std::thread>(new std::thread(&Customer::Receiving, this)); }
你們可能對 app_id 和 customer_id 有些疑問,好比:
在 KVWorker 構建函數中有:
在 KVServer 構建函數中有:
咱們使用源碼自帶的 tests/test_kv_app_multi_workers.cc 來梳理一下 app_id 與 customer_id 的邏輯關係。
咱們提早劇透:worker是用 customer_id 來肯定本身的身份。customer id 在 worker 代碼中被用來肯定 本worker 對應的 key 的範圍。
從腳本中能夠看出來,使用以下作測試:
find test_* -type f -executable -exec ./repeat.sh 4 ./local.sh 2 2 ./{} \;
文件中啓動了一個 server 和 兩個 worker。
所以,咱們能夠理出來:
具體代碼以下:
#include <cmath> #include "ps/ps.h" using namespace ps; void StartServer() { // 啓動服務 if (!IsServer()) return; auto server = new KVServer<float>(0); server->set_request_handle(KVServerDefaultHandle<float>()); RegisterExitCallback([server](){ delete server; }); } void RunWorker(int customer_id) { // 啓動worker Start(customer_id); if (!IsWorker()) { return; } KVWorker<float> kv(0, customer_id); // init int num = 10000; std::vector<Key> keys(num); std::vector<float> vals(num); int rank = MyRank(); srand(rank + 7); for (int i = 0; i < num; ++i) { keys[i] = kMaxKey / num * i + customer_id; vals[i] = (rand() % 1000); } // push int repeat = 50; std::vector<int> ts; for (int i = 0; i < repeat; ++i) { ts.push_back(kv.Push(keys, vals)); // to avoid too frequency push, which leads huge memory usage if (i > 10) kv.Wait(ts[ts.size()-10]); } for (int t : ts) kv.Wait(t); // pull std::vector<float> rets; kv.Wait(kv.Pull(keys, &rets)); // pushpull std::vector<float> outs; for (int i = 0; i < repeat; ++i) { kv.Wait(kv.PushPull(keys, vals, &outs)); } float res = 0; float res2 = 0; for (int i = 0; i < num; ++i) { res += fabs(rets[i] - vals[i] * repeat); res += fabs(outs[i] - vals[i] * 2 * repeat); } CHECK_LT(res / repeat, 1e-5); CHECK_LT(res2 / (2 * repeat), 1e-5); LL << "error: " << res / repeat << ", " << res2 / (2 * repeat); // stop system Finalize(customer_id, true); } int main(int argc, char *argv[]) { // start system bool isWorker = (strcmp(argv[1], "worker") == 0); if (!isWorker) { Start(0); // setup server nodes,啓動server節點 StartServer(); Finalize(0, true); return 0; } // run worker nodes,啓動兩個worker節點 std::thread t0(RunWorker, 0); std::thread t1(RunWorker, 1); t0.join(); t1.join(); return 0; }
咱們再回憶下 Postoffice 的初始化,能夠看到,啓動時候,worker是用 customer_id 來肯定本身的身份。因而,customer id 在 worker 代碼中被用來肯定 本worker 對應的 key 的範圍。
void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) { // init node info. // 對於全部的worker,進行node設置 for (int i = 0; i < num_workers_; ++i) { int id = WorkerRankToID(i); for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup, kWorkerGroup + kScheduler, kWorkerGroup + kServerGroup + kScheduler}) { node_ids_[g].push_back(id); } } // 對於全部的server,進行node設置 for (int i = 0; i < num_servers_; ++i) { int id = ServerRankToID(i); for (int g : {id, kServerGroup, kWorkerGroup + kServerGroup, kServerGroup + kScheduler, kWorkerGroup + kServerGroup + kScheduler}) { node_ids_[g].push_back(id); } } // 設置scheduler的node for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup, kScheduler + kWorkerGroup, kScheduler + kServerGroup}) { node_ids_[g].push_back(kScheduler); } init_stage_++; } // start van van_->Start(customer_id); // 這裏有 customer_id ...... // do a barrier here,這裏有 customer_id if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler); }
再看看 Van 的初始化,也是用 customer_id 來肯定本身的身份。
void Van::Start(int customer_id) { if (init_stage == 0) { // get my node info if (is_scheduler_) { my_node_ = scheduler_; } else { my_node_.hostname = ip; my_node_.role = role; my_node_.port = port; my_node_.id = Node::kEmpty; my_node_.customer_id = customer_id; // 這裏有 customer_id } } if (!is_scheduler_) { // let the scheduler know myself Message msg; Node customer_specific_node = my_node_; customer_specific_node.customer_id = customer_id; // 這裏有 customer_id msg.meta.recver = kScheduler; msg.meta.control.cmd = Control::ADD_NODE; msg.meta.control.node.push_back(customer_specific_node); msg.meta.timestamp = timestamp_++; Send(msg); } ...... }
因此,也可以解釋了爲何在 KVWorker 發送消息時候使用 app_id 和 customer_id。
template <typename Val> void KVWorker<Val>::Send(int timestamp, bool push, bool pull, int cmd, const KVPairs<Val>& kvs) { ..... for (size_t i = 0; i < sliced.size(); ++i) { Message msg; msg.meta.app_id = obj_->app_id(); // 注意這裏 msg.meta.customer_id = obj_->customer_id();// 注意這裏 msg.meta.request = true; ...... Postoffice::Get()->van()->Send(msg); } }
在 KVServer 之中,也須要在迴應消息時候,使用 app_id 和 customer_id。
template <typename Val> void KVServer<Val>::Response(const KVMeta& req, const KVPairs<Val>& res) { Message msg; msg.meta.app_id = obj_->app_id();// 注意這裏 msg.meta.customer_id = req.customer_id;// 注意這裏 msg.meta.request = false; msg.meta.push = req.push; msg.meta.pull = req.pull; msg.meta.head = req.cmd; msg.meta.timestamp = req.timestamp; msg.meta.recver = req.sender; ...... Postoffice::Get()->van()->Send(msg); }
那麼問題來了,爲何 Server 端,app_id 與 customer_id 相等?
由於目前沒有 ps 的最初代碼,因此猜想是:
在 ps 代碼中,Server 端也是有多個 cusomer,可是出於精簡目的,在 ps-lite 之中刪除了這部分功能,所以在 ps-lite 之中,app_id 與 customer_id 相等。
所以咱們再次梳理流程(接受消息邏輯)以下:
worker節點 或者 server節點 在程序的最開始會執行Postoffice::start()
。
Postoffice::start()
會初始化節點信息,而且調用Van::start()
。
Van::start()
啓動一個本地線程,使用Van::Receiving()
來持續監聽收到的message。
Van::Receiving()
接收後消息以後,根據不一樣命令執行不一樣動做。針對數據消息,若是須要下一步處理,會調用 ProcessDataMsg:
Customer::Accept
函數。Customer::Accept() 函數將消息添加到一個隊列recv_queue_
;
Customer 對象自己也會啓動一個接受線程 recv_thread_
,使用 Customer::Receiving()
recv_queue_
隊列取消息。tracker_[req.timestamp].second++
recv_handle_
函數對消息進行處理。對於worker來講,其註冊的recv_handle_
是KVWorker::Process()
函數。由於worker的recv thread接受到的消息主要是從server處pull下來的KV對,所以該Process()
主要是接收message中的KV對;
而對於Server來講,其註冊的recv_handle_
是KVServer::Process()
函數。由於server接受的是worker們push上來的KV對,須要對其進行處理,所以該Process()
函數中調用的用戶經過KVServer::set_request_handle()
傳入的函數對象。
目前邏輯以下圖,在 第 8 步,recv_handle_ 指向 KVServer
+--------------------------+ | Van | | | DataMessage +-----------> Receiving | | 1 + | +---------------------------+ | | | | Postoffice | | | 2 | | | | v | GetCustomer | | | ProcessDataMsg <------------------> unordered_map customers_| | + | 3 | | | | | +---------------------------+ +--------------------------+ | | | 4 | +-------------------------+ | Customer | | | | | | v | | Accept | | + | | | | | | 5 | | v | | recv_queue_ | +-----------------+ | + | |KVWorker | | | 6 | +--------> | | | | | | 8 | Process | | v | | +-----------------+ | Receiving | | | + | | | | 7 | | | | | | +-----------------+ | v | | |KVServer | | recv_handle_+---------+--------> | | | | 8 | Process | +-------------------------+ +-----------------+
如下這些 Customer 函數都是被其餘模塊調用。
此函數的做用是:當發送一個 request 時候,新增對此 request 的計數。因此,當咱們須要給一個Resquest計數的時候,使用此函數。
特色以下:
每次發送消息前,先修改此條消息 應收到的 Response數量。
recver表示接收者的node_id,由於ps-lite中一個整數可能對應於多個node_id,因此使用Postoffice解碼得到全部的真實node_id 的數目。
好比給 kServerGroup 發消息,kServerGroup 裏面有3 個 server,則 num 爲 3,就是應該收到 3 個response。tracker_ 對應的item 就是 [3,0],表示應該收到 3個,目前收到 0 個。
函數的返回值能夠認爲是一個時間戳,這個時間戳 會做爲此次request的id,調用wait時會保證後續Wait以此爲ID識別。
int Customer::NewRequest(int recver) { std::lock_guard<std::mutex> lk(tracker_mu_); int num = Postoffice::Get()->GetNodeIDs(recver).size(); // recver 可能會表明一個group。 tracker_.push_back(std::make_pair(num, 0)); return tracker_.size() - 1; // 表明這次請求的時間戳timestamp,後續customer使用這個值表明這個request }
具體調用舉例就是在 worker 向 server 推送時候。
int ZPush(const SArray<Key>& keys, const SArray<Val>& vals, const SArray<int>& lens = {}, int cmd = 0, const Callback& cb = nullptr, int priority = 0) { int ts = obj_->NewRequest(kServerGroup); // 這裏會調用 AddCallback(ts, cb); KVPairs<Val> kvs; kvs.keys = keys; kvs.vals = vals; kvs.lens = lens; kvs.priority = priority; Send(ts, true, false, cmd, kvs); return ts; }
做用是:針對request已經返回response進行計數。
特色以下:
當外部調用者收到Response時,調用AddResponse告訴Customer對象。
主動增長某次請求實際收到的Response數,主要用於客戶端發送請求時,有時可跳過與某些server的通訊(這次通訊的keys沒有分佈在這些server上),在客戶端就可直接認爲已接收到Response。
另外,在Customer::Receiving
中,當處理了一條非request請求後,也會增長對應的請求的Response數。 tracker_[recv.meta.timestamp].second++;
這個類有個缺陷,對於過時的之後不會再用到的Request信息,沒有刪除操做。而這個類的單個對象的生存週期又近乎等於進程的生存週期。所以,基於ps-lite程序跑的時間久了基本都會OOM。
void Customer::AddResponse(int timestamp, int num) { std::lock_guard<std::mutex> lk(tracker_mu_); tracker_[timestamp].second += num; }
在 KVWorker 的 Send 方法會調用,由於某些狀況下,(這次通訊的keys沒有分佈在這些server上),在客戶端就可直接認爲已接收到Response,因此要跳過。
template <typename Val> void KVWorker<Val>::Send(int timestamp, bool push, bool pull, int cmd, const KVPairs<Val>& kvs) { // slice the message SlicedKVs sliced; slicer_(kvs, Postoffice::Get()->GetServerKeyRanges(), &sliced); // need to add response first, since it will not always trigger the callback int skipped = 0; for (size_t i = 0; i < sliced.size(); ++i) { if (!sliced[i].first) ++skipped; } obj_->AddResponse(timestamp, skipped); // 這裏調用 if ((size_t)skipped == sliced.size()) { RunCallback(timestamp); } for (size_t i = 0; i < sliced.size(); ++i) { const auto& s = sliced[i]; if (!s.first) continue; Message msg; msg.meta.app_id = obj_->app_id(); msg.meta.customer_id = obj_->customer_id(); msg.meta.request = true; msg.meta.push = push; msg.meta.pull = pull; msg.meta.head = cmd; msg.meta.timestamp = timestamp; msg.meta.recver = Postoffice::Get()->ServerRankToID(i); msg.meta.priority = kvs.priority; const auto& kvs = s.second; if (kvs.keys.size()) { msg.AddData(kvs.keys); msg.AddData(kvs.vals); if (kvs.lens.size()) { msg.AddData(kvs.lens); } } Postoffice::Get()->van()->Send(msg); } }
功能是:當咱們須要等待某個發出去的Request對應的Response所有收到時,使用此函數會阻塞等待,直到 應收到Response數 等於 實際收到的Response數。
wait操做的過程就是tracker_cond_一直阻塞等待,直到發送出去的數量和已經返回的數量相等。
void Customer::WaitRequest(int timestamp) { std::unique_lock<std::mutex> lk(tracker_mu_); tracker_cond_.wait(lk, [this, timestamp]{ return tracker_[timestamp].first == tracker_[timestamp].second; }); }
Wait 函數就是使用 WaitRequest 來確保操做完成。
/** * \brief Waits until a push or pull has been finished * * Sample usage: * \code * int ts = w.Pull(keys, &vals); * Wait(ts); * // now vals is ready for use * \endcode * * \param timestamp the timestamp returned by the push or pull */ void Wait(int timestamp) { obj_->WaitRequest(timestamp); }
可是具體如何調用,則是用戶自行決定,好比:
for (int i = 0; i < repeat; ++i) { kv.Wait(kv.Push(keys, vals)); }
因而這就來到了同步策略的問題。
不一樣的worker同時並行運算的時候,可能由於網絡、機器配置等外界緣由,致使不一樣的worker的進度是不同的,如何控制worker的同步機制是一個比較重要的課題。
通常來講,有三個級別的異步控制協議:BSP(Bulk Synchronous Parallel),SSP(Stalness Synchronous Parallel)和ASP(Asynchronous Parallel),它們的同步限制依次放寬。爲了追求更快的計算速度,算法能夠選擇更寬鬆的同步協議。
爲了解決性能的問題,業界開始探索這裏的一致性模型,最早出來的版本是ASP模式,在ASP以後提出了另外一種相對極端的同步協議BSP,後來有人提出將ASP和BSP作一下折中,就是SSP。
這三個協議具體以下:
ASP:task之間徹底不用相互等待,徹底不顧worker之間的順序,每一個worker按照本身的節奏走,跑完一個迭代就update,先完成的task,繼續下一輪的訓練。
優勢:消除了等待慢task的時間,減小了GPU的空閒時間,所以與BSP相比提升了硬件效率。計算速度快,最大限度利用了集羣的計算能力,全部的worker所在的機器都不用等待
缺點:
BSP:是通常分佈式計算採用的同步協議,每一輪迭代中都須要等待全部的task計算完成。每一個worker都必須在同一個迭代運行,只有一個迭代任務全部的worker都完成了,纔會進行一次worker和server之間的同步和分片更新。
BSP的模式和單機串行由於僅僅是batch size的區別,因此在模型收斂性上是徹底同樣的。同時,由於每一個worker在一個週期內是能夠並行計算的,因此有了必定的並行能力。spark用的就是這種方式。
優勢:適用範圍廣;每一輪迭代收斂質量高
缺點:每一輪迭代中,,BSP要求每一個worker等待或暫停來自其餘worker的梯度,這樣就須要等待最慢的task,從而顯著下降了硬件效率,致使總體任務計算時間長。整個worker group的性能由其中最慢的worker決定;這個worker通常稱爲straggler。
SSP:容許必定程度的task進度不一致,但這個不一致有一個上限,稱爲staleness值,即最快的task最多領先最慢的task staleness輪迭代。
就是把將ASP和BSP作一下折中。既然ASP是容許不一樣worker之間的迭代次數間隔任意大,而BSP則只容許爲0,那我就取一個常數s。有了SSP,BSP就能夠經過指定s=0而獲得。而ASP一樣能夠經過制定s=∞來達到。
優勢:必定程度減小了task之間的等待時間,計算速度較快。
缺點:每一輪迭代的收斂質量不如BSP,達到一樣的收斂效果可能須要更多輪的迭代,適用性也不如BSP,部分算法不適用。
沐神在論文中提到,parameter server 爲用戶提供了多種任務依賴方式:
Sequential: 這裏實際上是 synchronous task,任務之間是有順序的,只有上一個任務完成,才能開始下一個任務;
Eventual: 跟 sequential 相反,全部任務之間沒有順序,各自獨立完成本身的任務,
Bounded Delay:這是sequential 跟 eventual 之間的trade-off,能夠設置一個 \(\tau\) 做爲最大的延時時間。也就是說,只有 \(>\tau\) 以前的任務都被完成了,才能開始一個新的任務;極端的狀況:
ps-lite裏面有幾個涉及到等待同步的地方:
更復雜的好比Asp,bsp,ssp能夠經過增長相應的Command來完成。
假設咱們要解決如下問題
其中 (yi, xi) 是一個樣本對,w是模型權重。
咱們考慮使用批量大小爲b的小批量隨機梯度降低(SGD)來解決上述問題。 在步驟 t,該算法首先隨機選取b個樣本,而後經過下面公式更新權重w
咱們使用兩個例子來展現在ps-lite之中如何實現一個分佈式優化算法。
第一個示例中,咱們將SGD擴展爲異步SGD。 服務器會維護模型權重w,其中server k 將得到權重w的第k個階段,由 wk 表示。 一旦Server從worker收到梯度,server k將更新它所維護的權重。
t = 0; while (Received(&grad)) { w_k -= eta(t) * grad; t++; }
對於一個worker來講,每個步驟會作四件事情
Read(&X, &Y); // 讀取一個 minibatch 數據 Pull(&w); // 從服務器拉去最新的權重 ComputeGrad(X, Y, w, &grad); // 計算梯度 Push(grad); // 把權重推送給服務器
ps-lite將提供push和pull函數,worker 將與具備正確部分數據的server通訊。
請注意:異步SGD在算法模式上與單機版本不一樣。 因爲worker之間沒有通訊,所以有可能在一個worker計算梯度的時候,其餘worker就更新了服務器上的權重。 即,每一個worker可能會用到延遲的權重。
與異步版本不一樣,同步版本在語義上與單機算法相同。 就是每一次迭代都要全部的worker計算好梯度,而且同步到server中。
咱們使用scheduler 來管理數據同步。
for (t = 0, t < num_iteration; ++t) { for (i = 0; i < num_worker; ++i) { IssueComputeGrad(i, t); } for (i = 0; i < num_server; ++i) { IssueUpdateWeight(i, t); } WaitAllFinished(); }
IssueComputeGrad
和 IssueUpdateWeight
會發送命令給 worker 和 servers,而後 scheduler 會調用 WaitAllFinished
等待全部發送的命令結束。
對於一個worker接受到一個命令,它會作以下:
ExecComputeGrad(i, t) { Read(&X, &Y); // 讀取數據 minibatch = batch / num_workers 個樣本 Pull(&w); // 從服務器拉取最新權重 ComputeGrad(X, Y, w, &grad); // 計算梯度 Push(grad); // 把權重推送給服務器 }
這個算法和ASGD幾乎相同,只是每次步驟中,只有 b/num_workers個樣本被處理。
在 server 節點,與ASGD相比,多了一個聚合步驟。是把全部worker的梯度累計起來以後,再配合 學習速率進行迭代。
ExecUpdateWeight(i, t) { for (j = 0; j < num_workers; ++j) { Receive(&grad); aggregated_grad += grad; } w_i -= eta(t) * aggregated_grad; }
PostOffice:一個單例模式的全局管理類,每個 node (每一個 Node 可使用 hostname + port 來惟一標識)在生命期內具備一個PostOffice,直接從字面意義能夠知道,PostOffice就是郵局;
Van:通訊模塊,負責與其餘節點的網絡通訊和Message的實際收發工做。PostOffice持有一個Van成員,直接從字面意義能夠知道,Van就是小推車,用來提供送信的功能;
SimpleApp:KVServer和KVWorker的父類,它提供了簡單的Request, Wait, Response,Process功能;KVServer和KVWorker分別根據本身的使命重寫了這些功能;
Customer:每一個SimpleApp對象持有一個Customer類的成員,且Customer須要在PostOffice進行註冊,該類主要負責:
Customer 由名字就能夠知道,是郵局的客戶,就是 SimpleApp 在郵局的代理人。由於須要 worker,server 須要集中精力爲算法上,因此把 worker,server 邏輯上與網絡相關的收發消息功能都總結/轉移到 Customer 之中。
下面給出了邏輯圖。
+--------------------------+ | Van | | | DataMessage +-----------> Receiving | | 1 + | +---------------------------+ | | | | Postoffice | | | 2 | | | | v | GetCustomer | | | ProcessDataMsg <------------------> unordered_map customers_| | + | 3 | | | | | +---------------------------+ +--------------------------+ | | | 4 | +-------------------------+ | Customer | | | | | | v | | Accept | | + | | | | | | 5 | | v | | recv_queue_ | +-----------------+ | + | |KVWorker | | | 6 | +--------> | | | | | | 8 | Process | | v | | +-----------------+ | Receiving | | | + | | | | 7 | | | | | | +-----------------+ | v | | |KVServer | | recv_handle_+---------+--------> | | | | 8 | Process | +-------------------------+ +-----------------+
★★★★★★關於生活和技術的思考★★★★★★
微信公衆帳號:羅西的思考
若是您想及時獲得我的撰寫文章的消息推送,或者想看看我的推薦的技術資料,敬請關注。
****
https://www.cs.cmu.edu/~muli/file/parameter_server_osdi14.pdf
sona:Spark on Angel大規模分佈式機器學習平臺介紹
基於Parameter Server的可擴展分佈式機器學習架構
Mu Li. Scaling Distributed Machine Learning with the Parameter Server.
CMU. http://parameterserver.org/
Joseph E.Gonzalez. Emerging Systems For Large-scale Machine Learning.
【分佈式計算】MapReduce的替代者-Parameter Server
Parameter Server for Distributed Machine Learning
PS-Lite Documents
ps-lite源碼剖析
http://blog.csdn.net/stdcoutzyx/article/details/51241868
http://blog.csdn.net/cyh_24/article/details/50545780
https://www.zybuluo.com/Dounm/note/529299
http://blog.csdn.net/KangRoger/article/details/73307685
http://www.cnblogs.com/heguanyou/p/7868596.html
MXNet之ps-lite及parameter server原理
ps-lite學些系列之一 ----- mac安裝ps-lite
https://www.zhihu.com/topic/20175752/top-answers
Large Scale Machine Learning--An Engineering Perspective--目錄
ps-lite學些系列之3 --- ps-lite的簡介(1. Overview)
https://www.zhihu.com/topic/20175752/top-answers
https://blog.csdn.net/zkwdn/article/details/53840091