Paracel是豆瓣開發的一個分佈式計算框架,它基於參數服務器範式來解決機器學習的問題:邏輯迴歸、SVD、矩陣分解(BFGS,sgd,als,cg),LDA,Lasso...。html
Paracel支持數據和模型的並行,爲用戶提供簡單易用的通訊接口,比mapreduce式的系統要更加靈活。Paracel同時支持異步的訓練模式,使迭代問題收斂地更快。此外,Paracel程序的結構與串行程序十分類似,用戶能夠更加專一於算法自己,不需將精力過多放在分佈式邏輯上。node
由於咱們以前已經用ps-lite對參數服務器的基本功能作了介紹,因此在本文中,咱們主要與ps-lite比對大的方面和一些關鍵技術點(paracel沒有開源容錯機制,是個不小的遺憾),而不會像對 ps-lite 那樣作較詳細的分析。python
對於本文來講,ps-lite的主要邏輯以下:c++
本系列其餘文章是:算法
[源碼解析] 機器學習參數服務器ps-lite 之(1) ----- PostOfficeshell
[源碼解析] 機器學習參數服務器ps-lite(2) ----- 通訊模塊Vanapache
[源碼解析] 機器學習參數服務器ps-lite 之(3) ----- 代理人Customerjson
[源碼解析]機器學習參數服務器ps-lite(4) ----- 應用節點實現服務器
本文在解析時候會刪除部分非主體代碼。微信
咱們首先經過源碼提供的LR算法看看如何使用。
咱們從源碼中找到 LR 相關部分來看,如下就是一些必要配置,在其中我作了部分翻譯,須要留意的是:用一條命令能夠啓動若干不一樣類型的實例,實例運行的都是可執行程序 lr。
- Enter Paracel's home directory 進入Paracel工做目錄
```cd paracel;```
- Generate training dataset for classification 產生訓練數據集
```python ./tool/datagen.py -m classification -o training.dat -n 2500 -k 100```
- Set up link library path: 設置連接庫路徑
```export LD_LIBRARY_PATH=your_paracel_install_path/lib```
Create a json file named
cfg.json
, see example in Parameters section below. 建立配置文件Run (4 workers, local mode in the following example) 運行(4個worker,2個參數服務器)
```./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr```
Default parameters are set in a JSON format file. For example, we create a cfg.json as below(modify
your_paracel_install_path
):{
"training_input" : "training.dat", 訓練集
"test_input" : "training.dat", 驗證集
"predict_input" : "training.dat", label數據
"output" : "./lr_result/",
"update_file" : "your_paracel_install_path/lib/liblr_update.so",
"update_func" : "lr_theta_update", 更新函數
"method" : "ipm",
"rounds" : 100,
"alpha" : 0.001,
"beta" : 0.01,
"debug" : false
}
經過makefile咱們能夠看到,是把 lr_driver.cpp, lr.cpp一塊兒編譯成爲 lr 可執行文件,這就是worker進程。
同時也把 update.cpp 編譯成庫,將會被server加載調用。
add_library(lr_update SHARED update.cpp) # 參數服務器如何更新 target_link_libraries(lr_update ${CMAKE_DL_LIBS}) install(TARGETS lr_update LIBRARY DESTINATION lib) add_library(lr_method SHARED lr.cpp) # 算法代碼 target_link_libraries(lr_method ${Boost_LIBRARIES} comm scheduler) install(TARGETS lr_method LIBRARY DESTINATION lib) add_executable(lr lr_driver.cpp) # 驅動代碼 target_link_libraries(lr ${Boost_LIBRARIES} comm scheduler lr_method) install(TARGETS lr RUNTIME DESTINATION bin)
對於 LR,有四種 大規模深度神經網絡的隨機梯度降低法 能夠選擇
dgd: distributed gradient descent learning
ipm: iterative parameter mixtures learning
downpour: asynchrounous gradient descent learning
agd: slow asynchronous gradient descent learning
咱們選擇 agd 算法來學習分析:http://www.eecs.berkeley.edu/~brecht/papers/hogwildTR.pdf
首先,咱們看看驅動代碼 lr_driver.cpp,邏輯就是:
DEFINE_string(server_info, "host1:7777PARACELhost2:8888", "hosts name string of paracel-servers.\n"); DEFINE_string(cfg_file, "", "config json file with absolute path.\n"); int main(int argc, char *argv[]) { // 配置運行環境和通訊 paracel::main_env comm_main_env(argc, argv); paracel::Comm comm(MPI_COMM_WORLD); google::SetUsageMessage("[options]\n\t--server_info\n\t--cfg_file\n"); google::ParseCommandLineFlags(&argc, &argv, true); // 讀取分析參數 paracel::json_parser pt(FLAGS_cfg_file); std::string training_input, test_input, predict_input, output, update_file, update_func, method; try { training_input = pt.check_parse<std::string>("training_input"); test_input = pt.check_parse<std::string>("test_input"); predict_input = pt.check_parse<std::string>("predict_input"); output = pt.parse<std::string>("output"); update_file = pt.check_parse<std::string>("update_file"); update_func = pt.parse<std::string>("update_func"); method = pt.parse<std::string>("method"); } catch (const std::invalid_argument & e) { std::cerr << e.what(); return 1; } int rounds = pt.parse<int>("rounds"); double alpha = pt.parse<double>("alpha"); double beta = pt.parse<double>("beta"); bool debug = pt.parse<bool>("debug"); // 生成 logistic_regression,進行訓練,驗證,預測 paracel::alg::logistic_regression lr_solver(comm, FLAGS_server_info, training_input, output, update_file, update_func, method, rounds, alpha, beta, debug); lr_solver.solve(); std::cout << "final loss: " << lr_solver.calc_loss() << std::endl; lr_solver.test(test_input); lr_solver.predict(predict_input); lr_solver.dump_result(); return 0; }
從以前的配置中咱們知道更新部分是:
"update_file" : "your_paracel_install_path/lib/liblr_update.so", "update_func" : "lr_theta_update",
因此咱們從 alg/classification/logistic_regression/update.cpp 中獲得更新函數以下:
具體就是合併兩個參數而後返回。這部分代碼被編譯成庫,在server之中被加載運行。
#include <vector> #include "proxy.hpp" #include "paracel_types.hpp" using std::vector; extern "C" { extern paracel::update_result lr_theta_update; } vector<double> local_update(vector<double> a, vector<double> b) { vector<double> r; for(int i = 0; i < (int)a.size(); ++i) { r.push_back(a[i] + b[i]); } return r; } paracel::update_result lr_theta_update = paracel::update_proxy(local_update);
logistic_regression 是類定義,位於lr.hpp。logistic_regression 須要繼承 paracel::paralg 才能使用。
namespace paracel { namespace alg { class logistic_regression: public paracel::paralg { public: logistic_regression(paracel::Comm, string, string _input, string output, string update_file_name, string update_func_name, string = "ipm", int _rounds = 1, double _alpha = 0.002, double _beta = 0.1, bool _debug = false); virtual ~logistic_regression(); double lr_hypothesis(const vector<double> &); void dgd_learning(); // distributed gradient descent learning void ipm_learning(); // by default: iterative parameter mixtures learning void downpour_learning(); // asynchronous gradient descent learning void agd_learning(); // slow asynchronous gradient descent learning virtual void solve(); double calc_loss(); void dump_result(); void print(const vector<double> &); void test(const std::string &); void predict(const std::string &); private: void local_parser(const vector<string> &, const char); void local_parser_pred(const vector<string> &, const char); private: string input; string update_file, update_func; std::string learning_method; int worker_id; int rounds; double alpha, beta; bool debug = false; vector<vector<double> > samples, pred_samples; vector<double> labels; vector<double> theta; vector<double> loss_error; vector<std::pair<vector<double>, double> > predv; int kdim; // not contain 1 }; } // namespace alg } // namespace paracel
solve 是主體代碼,依據不一樣配置選擇不一樣的隨機梯度降低法來訓練。
void logistic_regression::solve() { auto lines = paracel_load(input); local_parser(lines); paracel_sync(); if(learning_method == "dgd") { dgd_learning(); } else if(learning_method == "ipm") { ipm_learning(); } else if(learning_method == "downpour") { downpour_learning(); } else if(learning_method == "agd") { agd_learning(); } else { ERROR_ABORT("method do not support"); } paracel_sync(); }
咱們找出論文中的算法比對:
下面代碼和論文算法基本一一對應,邏輯以下。
void logistic_regression::agd_learning() { int data_sz = samples.size(); int data_dim = samples[0].size(); theta = paracel::random_double_list(data_dim); paracel_write("theta", theta); // first push // 首先把 theta 推送到參數服務器 vector<int> idx; for(int i = 0; i < data_sz; ++i) { idx.push_back(i); } paracel_register_bupdate(update_file, update_func); double coff2 = 2. * beta * alpha; vector<double> delta(data_dim); unsigned time_seed = std::chrono::system_clock::now().time_since_epoch().count(); // train loop for(int rd = 0; rd < rounds; ++rd) { std::shuffle(idx.begin(), idx.end(), std::default_random_engine(time_seed)); theta = paracel_read<vector<double> >("theta"); // 從參數服務器讀取最新的 theta vector<double> theta_old(theta); // traverse data for(auto sample_id : idx) { theta = paracel_read<vector<double> >("theta"); theta_old = theta; double coff1 = alpha * (labels[sample_id] - lr_hypothesis(samples[sample_id])); for(int i = 0; i < data_dim; ++i) { double t = coff1 * samples[sample_id][i] - coff2 * theta[i]; theta[i] += t; } if(debug) { loss_error.push_back(calc_loss()); } for(int i = 0; i < data_dim; ++i) { delta[i] = theta[i] - theta_old[i]; } // 把計算結果推送到參數服務器 paracel_bupdate("theta", delta); // you could push a batch of delta into a queue to optimize } // traverse } // rounds theta = paracel_read<vector<double> >("theta"); // last pull // 獲得最終結果 }
lr的邏輯圖以下:
+------------+ +-------------------------------------------------+ | lr_driver | |logistic_regression | | | | | | +---------------------------------------> solve | +------------+ lr_solver.solve() | + | | | | | | | | | | | +---------------------+-----------------------+ | | | agd_learning | | | | +-----------------------+ | | | | | | | | | | | v | | | | | theta = paracel_read("theta") | | | | | | | | | | | | | | | | | v | | | | | | | | | | delta[i] = theta[i] - theta_old[i] | | | | | + | | | | | | | | | | | | | | | | | v | | | | | paracel_bupdate("theta", delta) | | | | | + + | | | | | | | | | | | +-----------------------+ | | | | +---------------------------------------------+ | | | | +-------------------------------------------------+ | Worker | +------------------------------------------------------------------------------------+ Server | +---------------------+ | Server | | | | | | v | | local_update | | | +---------------------+
至此,咱們知道了Paracel如何使用,實現是以driver爲核心進行展開,用戶須要編寫 update函數和算法函數。可是距離深刻了解還差得很遠。
咱們目前有幾個問題須要解決:
咱們須要經過啓動部分來繼續研究。
如前所述./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr
是啓動命令,paracel 經過 prun.py 進入系統,因此咱們分析這個腳本。
下面咱們省略一些非主體代碼,好比處理參數,邏輯以下:
if __name__ == '__main__': optpar = OptionParser() # 省略處理參數 (options, args) = optpar.parse_args() nsrv = 1 nworker = 1 if options.parasrv_num: nsrv = options.parasrv_num if options.worker_num: nworker = options.worker_num if not options.method_server: options.method_server = options.method if not options.ppn_server: options.ppn_server = options.ppn if not options.mem_limit_server: options.mem_limit_server = options.mem_limit if not options.hostfile_server: options.hostfile_server = options.hostfile # 利用 init_starter 獲得如何啓動server,worker,構建出相應字符串 server_starter = init_starter(options.method_server, str(options.mem_limit_server), str(options.ppn_server), options.hostfile_server, options.server_group) worker_starter = init_starter(options.method, str(options.mem_limit), str(options.ppn), options.hostfile, options.worker_group) #initport = random.randint(30000, 65000) #initport = get_free_port() initport = 11777 start_parasrv_cmd_lst = [server_starter, str(nsrv), os.path.join(PARACEL_INSTALL_PREFIX, 'bin/start_server --start_host'), socket.gethostname(), ' --init_port', str(initport)] start_parasrv_cmd = ' '.join(start_parasrv_cmd_lst) # 利用 subprocess.Popen 啓動server,其中server的執行程序是 bin/start_server procs = subprocess.Popen(start_parasrv_cmd, shell=True, preexec_fn=os.setpgrp) try: serverinfo = paracelrun_cpp_proxy(nsrv, initport) entry_cmd = '' if args: entry_cmd = ' '.join(args) alg_cmd_lst = [worker_starter, str(nworker), entry_cmd, '--server_info', serverinfo, '--cfg_file', options.config] alg_cmd = ' '.join(alg_cmd_lst) # 利用 os.system 啓動 worker os.system(alg_cmd) os.killpg(procs.pid, 9) except Exception as e: logger.exception(e) os.killpg(procs.pid, 9)
init_starter 函數會依據配置構建一個字符串。其中 paracel 有三種啓動方式:
The –m_server and -m options above refer to what type of cluster you use. Paracel support mesos clusters, mpi clusters and multiprocessers in a single machine.
咱們利用前面horovod文章的知識能夠知道,mpirun 是能夠啓動多個進程。
結合以前的命令行,./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr
,能夠知道 local 就是 mpirun,因此paracel 經過 mpirun 來啓動了 4 個 lr 進程。
具體代碼以下:
def init_starter(method, mem_limit, ppn, hostfile, group): '''Assemble commands for running paracel programs''' starter = '' if not hostfile: hostfile = '~/.mpi/large.18' if method == 'mesos': if group: starter = '%s/mrun -m %s -p %s -g %s -n ' % (PARACEL_INSTALL_PREFIX, mem_limit, ppn, group) else: starter = '%s/mrun -m %s -p %s -n ' % (PARACEL_INSTALL_PREFIX, mem_limit, ppn) elif method == 'mpi': starter = 'mpirun --hostfile %s -n ' % hostfile elif method == 'local': starter = 'mpirun -n ' else: print 'method %s not supported.' % method sys.exit(1) return starter
前面提到,server 執行程序對應的是 bin/start_server。
咱們看看其構建 src/CMakeLists.txt,因而咱們能夠去查找 start_server.cpp。
add_library(comm SHARED comm.cpp) # 通訊相關庫 install(TARGETS comm LIBRARY DESTINATION lib) add_library(scheduler SHARED scheduler.cpp # 調度 install(TARGETS scheduler LIBRARY DESTINATION lib) add_library(default SHARED default.cpp) # 缺省庫 install(TARGETS default LIBRARY DESTINATION lib) # 這裏能夠看到start_server.cpp add_executable(start_server start_server.cpp) target_link_libraries(start_server ${Boost_LIBRARIES} ${CMAKE_DL_LIBS}) install(TARGETS start_server RUNTIME DESTINATION bin) add_executable(paracelrun_cpp_proxy paracelrun_cpp_proxy.cpp) target_link_libraries(paracelrun_cpp_proxy ${Boost_LIBRARIES} ${CMAKE_DL_LIBS}) install(TARGETS paracelrun_cpp_proxy RUNTIME DESTINATION bin)
src/start_server.cpp 是服務器主體代碼。
結合以前的命令行,./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr
,能夠知道 local 就是 mpirun,因此paracel 經過 mpirun 來啓動了 2 個 start_server 進程,即兩個參數服務器。
#include <gflags/gflags.h> #include "server.hpp" DEFINE_string(start_host, "beater7", "host name of start node\n"); DEFINE_string(init_port, "7773", "init port"); int main(int argc, char *argv[]) { google::SetUsageMessage("[options]\n\ --start_host\tdefault: balin\n\ --init_port\n"); google::ParseCommandLineFlags(&argc, &argv, true); paracel::init_thrds(FLAGS_start_host, FLAGS_init_port); // join inside return 0; }
在 include/server.hpp 文件之中,init_thrds 函數啓動了一系列線程,具體邏輯以下。
// init_host is the hostname of starter void init_thrds(const paracel::str_type & init_host, const paracel::str_type & init_port) { // 構建 zmq 環境 zmq::context_t context(2); zmq::socket_t sock(context, ZMQ_REQ); paracel::str_type info = "tcp://" + init_host + ":" + init_port; sock.connect(info.c_str()); char hostname[1024], freeport[1024]; size_t size = sizeof(freeport); // hostname of servers gethostname(hostname, sizeof(hostname)); paracel::str_type ports = hostname; ports += ":"; // create sock in every thrd 爲每一個線程創建了socket std::vector<zmq::socket_t *> sock_pt_lst; for(int i = 0; i < paracel::threads_num; ++i) { zmq::socket_t *tmp; tmp = new zmq::socket_t(context, ZMQ_REP); sock_pt_lst.push_back(tmp); sock_pt_lst.back()->bind("tcp://*:*"); sock_pt_lst.back()->getsockopt(ZMQ_LAST_ENDPOINT, &freeport, &size); if(i == paracel::threads_num - 1) { ports += local_parse_port(paracel::str_type(freeport)); } else { ports += local_parse_port(std::move(paracel::str_type(freeport))) + ","; } } zmq::message_t request(ports.size()); std::memcpy((void *)request.data(), &ports[0], ports.size()); sock.send(request); zmq::message_t reply; sock.recv(&reply); // 創建服務器處理線程 thrd_exec paracel::list_type<std::thread> threads; for(int i = 0; i < paracel::threads_num - 1; ++i) { threads.push_back(std::thread(thrd_exec, std::ref(*sock_pt_lst[i]))); } // 創建ssp線程 thrd_exec_ssp threads.push_back(std::thread(thrd_exec_ssp, std::ref(*sock_pt_lst.back()))); // 等待線程結束 for(auto & thrd : threads) { thrd.join(); } for(int i = 0; i < paracel::threads_num; ++i) { delete sock_pt_lst[i]; } zmq_ctx_destroy(context); } // init_thrds
./prun.py -w 4 -p 2 -c cfg.json -m local your_paracel_install_path/bin/lr
的對應啓動邏輯圖具體以下:
prun.py + | | | +----------------+ | +--> | start_server | v | +----------------+ server_starter = init_starter +--> mpirun -n 2 +----+ + | +----------------+ | | | start_server | | | | + | | +--> | | | v | | | worker_starter = init_starter +--> mpirun -n 4 | | | + | v | | | init_thrds | | | + | | | | | +-------+----+--+-------+ | | | | | | | | | | | | | | | v | v v v v | thrd_exec | bin/lr bin/lr bin/lr bin/lr | + | | | | | | | | | | | v | | thrd_exec_ssp | +----------------+
目前咱們知道了,worker和server都有多種啓動方式,好比用 mpi 的方式來啓動多個進程。
worker 端就是經過 driver.cpp 爲主體,啓動多個進程。
server端就是經過 start_server 爲主體,啓動多個進程,就是多個進程(參數服務器)組成了一個集羣。
以上這些和ps-lite很是相似。
下面咱們要分別深刻這兩個角色的內部。
經過以前ps-lite咱們知道,參數服務器大多使用 KV 存儲來保存參數,因此咱們先介紹KV存儲。
在 include/kv_def.hpp 給出了server 端使用的KV存儲。
#include "paracel_types.hpp" #include "kv.hpp" namespace paracel { paracel::kvs<paracel::str_type, int> ssp_tbl; // 用來協助實現 SSP paracel::kvs<paracel::str_type, paracel::str_type> tbl_store; // 主要的kv存儲 }
KV 存儲的定義在 include/kv.hpp,下面省略了部分代碼。
能夠看出來,基本功能就是維護了內存table,提供了set系列函數和get系列函數,其中當須要返回 value, unique 的時候,就採用hash函數處理。
template <class K, class V> struct kvs { public: bool contains(const K & k) { return kvdct.count(k); } void set(const K & k, const V & v) { kvdct[k] = v; } void set_multi(const paracel::dict_type<K, V> & kvdict) { for(auto & kv : kvdict) { set(kv.first, kv.second); } } boost::optional<V> get(const K & k) { auto fi = kvdct.find(k); if(fi != kvdct.end()) { return boost::optional<V>(fi->second); } else return boost::none; } bool get(const K & k, V & v) { auto fi = kvdct.find(k); if(fi != kvdct.end()) { v = fi->second; return true; } else { return false; } } paracel::list_type<V> get_multi(const paracel::list_type<K> & keylst) { paracel::list_type<V> valst; for(auto & key : keylst) { valst.push_back(kvdct.at(key)); } return valst; } void get_multi(const paracel::list_type<K> & keylst, paracel::list_type<V> & valst) { for(auto & key : keylst) { valst.push_back(kvdct.at(key)); } } void get_multi(const paracel::list_type<K> & keylst, paracel::dict_type<K, V> & valdct) { valdct.clear(); for(auto & key : keylst) { auto it = kvdct.find(key); if(it != kvdct.end()) { valdct[key] = it->second; } } } // 這裏使用了 hash 函數 // gets(key) -> value, unique boost::optional<std::pair<V, paracel::hash_return_type> > gets(const K & k) { if(auto v = get(k)) { std::pair<V, paracel::hash_return_type> ret(*v, hfunc(*v)); return boost::optional< std::pair<V, paracel::hash_return_type> >(ret); } else { return boost::none; } } // compare-and-set, cas(key, value, unique) -> True/False bool cas(const K & k, const V & v, const paracel::hash_return_type & uniq) { if(auto r = gets(k)) { if(uniq == (*r).second) { set(k, v); return true; } else { return false; } } else { kvdct[k] = v; } return true; } paracel::dict_type<K, V> getall() { return kvdct; } private: //std::tr1::unordered_map<K, V> kvdct; paracel::dict_type<K, V> kvdct; paracel::hash_type<V> hfunc; };
thrd_exec 線程實現了參數服務器的基本處理邏輯:就是針對worker傳來的不一樣的命令進行相關處理(大部分就是針對KV存儲進行處理),好比:
須要注意的是,這裏使用了用戶定義的update函數,即:
下面刪除了部分非主體代碼。
// thread entry void thrd_exec(zmq::socket_t & sock) { paracel::packer<> pk; update_result update_f; filter_result pullall_special_f; filter_result remove_special_f; // 這裏使用了dlopen_update_lambda來對用戶設置的update函數進行生成,賦值爲 update_f auto dlopen_update_lambda = [&](const paracel::str_type & fn, const paracel::str_type & fcn) { void *handler = dlopen(fn.c_str(), RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE); auto local = dlsym(handler, fcn.c_str()); update_f = *(std::function<paracel::str_type(paracel::str_type, paracel::str_type)>*) local; dlclose(handler); }; // 主體邏輯 while(1) { zmq::message_t s; sock.recv(&s); auto scrip = paracel::str_type(static_cast<const char *>(s.data()), s.size()); auto msg = paracel::str_split_by_word(scrip, paracel::seperator); auto indicator = pk.unpack(msg[0]); if(indicator == "pull") { // 若是是從參數服務器讀取參數,則直接返回 auto key = pk.unpack(msg[1]); paracel::str_type result; auto exist = paracel::tbl_store.get(key, result); // 讀取kv if(!exist) { paracel::str_type tmp = "nokey"; rep_send(sock, tmp); } else { rep_send(sock, result); // 返回 } } if(indicator == "pull_multi") { // 讀取多個參數 paracel::packer<paracel::list_type<paracel::str_type> > pk_l; auto key_lst = pk_l.unpack(msg[1]); auto result = paracel::tbl_store.get_multi(key_lst); rep_pack_send(sock, result); } if(indicator == "pullall") { // 讀取全部參數 auto dct = paracel::tbl_store.getall(); rep_pack_send(sock, dct); } mutex.lock(); if(indicator == "push") { // 插入參數 auto key = pk.unpack(msg[1]); paracel::tbl_store.set(key, msg[2]); bool result = true; rep_pack_send(sock, result); } if(indicator == "push_multi") { // 插入多個參數 paracel::packer<paracel::list_type<paracel::str_type> > pk_l; paracel::dict_type<paracel::str_type, paracel::str_type> kv_pairs; auto key_lst = pk_l.unpack(msg[1]); auto val_lst = pk_l.unpack(msg[2]); assert(key_lst.size() == val_lst.size()); for(int i = 0; i < (int)key_lst.size(); ++i) { kv_pairs[key_lst[i]] = val_lst[i]; } paracel::tbl_store.set_multi(kv_pairs); //插入kv bool result = true; rep_pack_send(sock, result); } if(indicator == "update" || indicator == "bupdate") { // 更新參數 if(msg.size() > 3) { if(msg.size() != 5) { ERROR_ABORT("invalid invoke in server end"); } // open request func auto file_name = pk.unpack(msg[3]); auto func_name = pk.unpack(msg[4]); dlopen_update_lambda(file_name, func_name); } else { if(!update_f) { dlopen_update_lambda("../local/build/lib/default.so", "default_incr_i"); } } auto key = pk.unpack(msg[1]); // 這裏使用用戶的update函數來對kv進行處理 std::string result = kv_update(key, msg[2], update_f); rep_send(sock, result); } if(indicator == "remove") { // 刪除參數 auto key = pk.unpack(msg[1]); auto result = paracel::tbl_store.del(key); rep_pack_send(sock, result); } mutex.unlock(); } // while } // thrd_exec
簡化如圖:
+--------------------------------------------------------------------------------------+ | thrd_exec | | | | +---------------------------------> while(1) | | | + | | | | | | | | | | | +----------+----------+--------+--+------+----------+---------+---------+ | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | v v v v v v v v | | | | | | pull pull_multi pullall push push_multi update bupdate remove | | | + + + + + + + + | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | v v v v v v v v | | | +----------+----------+--------+----+----+----------+---------+---------+ | | | | | | | | | | | | | | | | | | +-----------------------------------------+ | | | +--------------------------------------------------------------------------------------+
目前爲止,咱們能夠看到,Paracel和ps-lite也很相似,服務器維護了一個存儲,服務器也能夠處理客戶端的請求。
Worker 就是用來訓練算法的進程。從前面咱們瞭解,算法須要繼承paracel::paralg才能使用參數服務器功能。
namespace paracel { namespace alg { class logistic_regression: public paracel::paralg { .....
paracel::paralg 就能夠認爲是參數服務器的API,或者代理,咱們下面就看看。
Paralg是提供Paracel主要功能的基本類,能夠理解爲一個算法API類,或者對外功能API類。
咱們只給出其成員變量,暫時省略其函數實現。最主要幾個爲:
class paralg { private: class parasrv { // 能夠理解爲是參數服務器類 using l_type = paracel::list_type<paracel::kvclt>; using dl_type = paracel::list_type<paracel::dict_type<paracel::str_type, paracel::str_type> >; public: parasrv(paracel::str_type hosts_dct_str) { // init dct_lst dct_lst = paracel::get_hostnames_dict(hosts_dct_str); // init srv_sz srv_sz = dct_lst.size(); // init kvm for(auto & srv : dct_lst) { paracel::kvclt kvc(srv["host"], srv["ports"]); kvm.push_back(std::move(kvc)); } // init servers for(auto i = 0; i < srv_sz; ++i) { servers.push_back(i); } // init hashring p_ring = new paracel::ring<int>(servers); } virtual ~parasrv() { delete p_ring; } public: dl_type dct_lst; int srv_sz = 1; l_type kvm; paracel::list_type<int> servers; // 具體服務器列表 paracel::ring<int> *p_ring; // hash ring }; // nested class parasrv private: int stale_cache, clock, total_iters; // 同步須要 int clock_server = 0; paracel::Comm worker_comm; //通訊類,好比 MPI 通訊 paracel::str_type output; int nworker = 1; int rounds = 1; int limit_s = 0; bool ssp_switch = false; parasrv *ps_obj; // 能夠理解爲是正式的參數服務器類。 paracel::dict_type<paracel::default_id_type, paracel::default_id_type> rm; paracel::dict_type<paracel::default_id_type, paracel::default_id_type> cm; paracel::dict_type<paracel::default_id_type, paracel::default_id_type> dm; paracel::dict_type<paracel::default_id_type, paracel::default_id_type> col_dm; paracel::dict_type<paracel::str_type, paracel::str_type> keymap; paracel::dict_type<paracel::str_type, boost::any> cached_para; paracel::update_result update_f; int npx = 1, npy = 1; }
編寫一個Paracel程序須要對paralg基類進行子類化,而且必須重寫virtual solve方法。其中一些是SPMD iterfaces 並行接口。
咱們從以前 LR 的實現能夠看到須要繼承 paracel::paralg 。
class logistic_regression: public paracel::paralg
就是說,用戶的solve函數能夠直接調用 Paralg 的函數來完成基本功能。
咱們以 paracel::paracel_read 爲例,能夠看到是使用 parasrv.kvm 的功能,咱們後續會繼續介紹 parasrv。
template <class V> V paracel_read(const paracel::str_type & key, int replica_id = -1) { if(ssp_switch) { // 若是應用ssp,應該如何處理。咱們下文就將具體介紹ssp如何處理 V val; if(clock == 0 || clock == total_iters) { cached_para[key] = boost::any_cast<V>(ps_obj-> kvm[ps_obj->p_ring->get_server(key)]. pull<V>(key)); val = boost::any_cast<V>(cached_para[key]); } else if(stale_cache + limit_s > clock) { val = boost::any_cast<V>(cached_para[key]); } else { while(stale_cache + limit_s < clock) { stale_cache = ps_obj-> kvm[clock_server].pull_int(paracel::str_type("server_clock")); } cached_para[key] = boost::any_cast<V>(ps_obj-> kvm[ps_obj->p_ring->get_server(key)]. pull<V>(key)); val = boost::any_cast<V>(cached_para[key]); } return val; } // 不然直接返回 return ps_obj->kvm[ps_obj->p_ring->get_server(key)].pull<V>(key); }
worker邏輯以下:
+---------------------------------------------------------------------------+ | Algorithm | | ^ +------------------------------v | | | | | | | | | | | v | | | +----------------------------+------------------------------+ | | | | paracel_read | | | | | | | | | | ps_obj+>kvm[ps_obj+>p_ring+>get_server(key)].pull<V>(key) | | | | | | | | | +----------------------------+------------------------------+ | | | | | | | | | | | | | | | v | | | Compute | | | + | | | | | | | | | | | v | | | +---------------------------+-------------------------------+ | | | | paracel_bupdate | | | | | ps_obj->kvm[indx].bupdate | | | | | | | | | +---------------------------+-------------------------------+ | | | | | | | | | | | | | | | | | | +-----<--------------------------+ | | | +---------------------------------------------------------------------------+
Worker端的機理也相似ps-lite,經過read,pull等操做,向服務器提出請求。
在沐神論文中,Ring hash 是與數據一致性,容錯,可擴展等機制聯繫在一塊兒,好比:
parameter server 在數據一致性上,使用的是傳統的一致性哈希算法,參數key與server node id被插入到一個hash ring中。
但惋惜的是,ps-lite 沒有提供這部分代碼,paracel 雖然有 ring hash,但也不齊全,豆瓣沒有開源容錯和一致性等部分。咱們只能基於已有代碼進行學習分析。
這裏只是大體講解下,有需求的同窗能夠去網上搜索詳細文章。
從拗口的技術術語來解釋,一致性哈希的技術關鍵點是:按照經常使用的hash算法來將對應的key哈希到一個具備2^32次方個桶的空間中,即0 ~ (2^32)-1的數字空間。咱們能夠將這些數字頭尾相連,想象成一個閉合的環形。
用通俗白話來理解,這個關鍵點就是:在部署服務器的時候,服務器的序號空間已經配置成了一個固定的很是大的數字 1~2^32(不須要再改變)。服務器能夠分配爲 1~2^32 中任一序號。這樣服務器集羣能夠固定大多數算法規則 (由於序號空間是算法的重要參數),這樣面對擴容等變化只有"分配規則" 須要根據實際系統容量作相應微調。從而對總體系統影響較小。
ring 就是hash 環的實現類,這裏主要功能就是把 服務器 加入到 hash ring 之中,以及從ring之中取出服務器。
// T rep type of server name template <class T> class ring { public: ring(paracel::list_type<T> names) { for(auto & name : names) { add_server(name); } } ring(paracel::list_type<T> names, int cp) : replicas(cp) { for(auto & name : names) { add_server(name); } } void add_server(const T & name) { //std::hash<paracel::str_type> hfunc; paracel::hash_type<paracel::str_type> hfunc; std::ostringstream tmp; tmp << name; auto name_str = tmp.str(); for(int i = 0; i < replicas; ++i) { //對每個副本進行處理 std::ostringstream cvt; cvt << i; auto n = name_str + ":" + cvt.str(); auto key = hfunc(n); // 依據name生成一個key srv_hashring_dct[key] = name; //添加value srv_hashring.push_back(key); //往list添加內容 } // sort srv_hashring std::sort(srv_hashring.begin(), srv_hashring.end()); } void remove_server(const T & name) { //std::hash<paracel::str_type> hfunc; paracel::hash_type<paracel::str_type> hfunc; std::ostringstream tmp; tmp << name; auto name_str = tmp.str(); for(int i = 0; i < replicas; ++i) { // 對每一個副本進行處理 std::ostringstream cvt; cvt << i; auto n = name_str + ":" + cvt.str(); auto key = hfunc(n);// 依據name生成一個key srv_hashring_dct.erase(key);// 刪除value auto iter = std::find(srv_hashring.begin(), srv_hashring.end(), key); if(iter != srv_hashring.end()) { srv_hashring.erase(iter); // 刪除list中的內容 } } } // TODO: relief load of srv_hashring_dct[srv_hashring[0]] template <class P> T get_server(const P & skey) { //std::hash<P> hfunc; paracel::hash_type<P> hfunc; auto key = hfunc(skey);// 依據name生成一個key auto server = srv_hashring[paracel::ring_bsearch(srv_hashring, key)];//獲取server return srv_hashring_dct[server]; } private: int replicas = 32; // 分別用list和dict存儲 paracel::list_type<paracel::hash_return_type> srv_hashring; paracel::dict_type<paracel::hash_return_type, T> srv_hashring_dct; };
咱們使用 paracel_read 來看,能夠發現調用順序是
V paracel_read(const paracel::str_type & key, int replica_id = -1) { ...... ps_obj->kvm[ps_obj->p_ring->get_server(key)].pull<V>(key); }
這裏是和ps-lite的不一樣之處,就是用ring-hash來維護數據一致性,容錯等,好比把 服務器 加入到 hash ring 之中,以及從ring之中取出服務器。
咱們把目前邏輯梳理一下,綜合看看。
如何使用ring hash,須要從 parasrv 提及。
咱們知道,paralg 是基礎API類,其中在 paralg 中有以下定義 以及 構建了 ps_obj , ps_obj是一個 parasrv 類型的實例。
注:如下都是在worker端使用的類型。
// paralg 內代碼 parasrv *ps_obj; // 成員變量定義,參數服務器接口 paralg(paracel::str_type hosts_dct_str, paracel::Comm comm, paracel::str_type _output = "", int _rounds = 1, int _limit_s = 0, bool _ssp_switch = false) : worker_comm(comm), output(_output), nworker(comm.get_size()), rounds(_rounds), limit_s(_limit_s), ssp_switch(_ssp_switch) { ps_obj = new parasrv(hosts_dct_str); // 構建參數服務器,一個parasrv的實例 init_output(_output); clock = 0; stale_cache = 0; clock_server = 0; total_iters = rounds; if(worker_comm.get_rank() == 0) { paracel::str_type key = "worker_sz"; (ps_obj->kvm[clock_server]). push_int(key, worker_comm.get_size()); // 初始化時鐘服務器 } paracel_sync(); // mpi barrier同步一下 }
parasrv 的定義以下,其中 p_ring 就是 ring 實例,使用 p_ring = new paracel::ring<int>(servers)
來完成了構建。
其中p_ring 是 ring hash,kvm是具體的kv存儲列表。
class parasrv { using l_type = paracel::list_type<paracel::kvclt>; using dl_type = paracel::list_type<paracel::dict_type<paracel::str_type, paracel::str_type> >; public: parasrv(paracel::str_type hosts_dct_str) { // 初始化host信息,srv大小,kvm,servers,ring hash // init dct_lst dct_lst = paracel::get_hostnames_dict(hosts_dct_str); // init srv_sz srv_sz = dct_lst.size(); // init kvm for(auto & srv : dct_lst) { paracel::kvclt kvc(srv["host"], srv["ports"]); kvm.push_back(std::move(kvc)); } // init servers for(auto i = 0; i < srv_sz; ++i) { servers.push_back(i); } // init hashring p_ring = new paracel::ring<int>(servers); // 構建 } virtual ~parasrv() { delete p_ring; } public: dl_type dct_lst; int srv_sz = 1; l_type kvm; // 具體KV存儲接口 paracel::list_type<int> servers; paracel::ring<int> *p_ring; // ring hash }; // nested class parasrv
kvm 初始化以下:
// init kvm for(auto & srv : dct_lst) { paracel::kvclt kvc(srv["host"], srv["ports"]); kvm.push_back(std::move(kvc)); }
kvclt 是 kv control 的抽象。
只摘取部分代碼,就是找到對應的服務器進行交互。
namespace paracel { struct kvclt { public: kvclt(paracel::str_type hostname, paracel::str_type ports) : host(hostname), context(1) { ports_lst = paracel::str_split(ports, ','); conn_prefix = "tcp://" + host + ":"; } template <class V, class K> bool pull(const K & key, V & val) { // 從參數服務器拉取 if(p_pull_sock == nullptr) { p_pull_sock.reset(create_req_sock(ports_lst[0])); } auto scrip = paste(paracel::str_type("pull"), key); // paracel::str_type return req_send_recv(*p_pull_sock, scrip, val); } template <class K, class V> bool push(const K & key, const V & val) { // 往參數服務器推送 if(p_push_sock == nullptr) { p_push_sock.reset(create_req_sock(ports_lst[1])); } auto scrip = paste(paracel::str_type("push"), key, val); bool stat; auto r = req_send_recv(*p_push_sock, scrip, stat); return r && stat; } template <class V> bool req_send_recv(zmq::socket_t & sock, const paracel::str_type & scrip, V & val) { zmq::message_t req_msg(scrip.size()); std::memcpy((void *)req_msg.data(), &scrip[0], scrip.size()); sock.send(req_msg); zmq::message_t rep_msg; sock.recv(&rep_msg); paracel::packer<V> pk; if(!rep_msg.size()) { ERROR_ABORT("paracel internal error!"); } else { std::string data = paracel::str_type( static_cast<char*>(rep_msg.data()), rep_msg.size()); if(data == "nokey") return false; val = pk.unpack(data); } return true; } private: paracel::str_type host; paracel::list_type<paracel::str_type> ports_lst; paracel::str_type conn_prefix; zmq::context_t context; std::unique_ptr<zmq::socket_t> p_contains_sock = nullptr; std::unique_ptr<zmq::socket_t> p_pull_sock = nullptr; std::unique_ptr<zmq::socket_t> p_pull_multi_sock = nullptr; std::unique_ptr<zmq::socket_t> p_pullall_sock = nullptr; std::unique_ptr<zmq::socket_t> p_push_sock = nullptr; std::unique_ptr<zmq::socket_t> p_push_multi_sock = nullptr; std::unique_ptr<zmq::socket_t> p_update_sock = nullptr; std::unique_ptr<zmq::socket_t> p_bupdate_sock = nullptr; std::unique_ptr<zmq::socket_t> p_bupdate_multi_sock = nullptr; std::unique_ptr<zmq::socket_t> p_remove_sock = nullptr; std::unique_ptr<zmq::socket_t> p_clear_sock = nullptr; std::unique_ptr<zmq::socket_t> p_ssp_sock = nullptr; }; // struct kvclt } // namespace paracel
因此目前整體邏輯以下:
+------------------+ worker + server | paralg | | | | | | | | | parasrv *ps_obj | | | + | | +------------------+ | | | | | start_server | +------------------+ | | | | | | | | | | | v | | | +------------+-----+ +------------------+ +---------+ | | thrd_exec | | parasrv | |kvclt | | kvclt | | | | | | | | | | | | | | | | host | | | | | thrd_exec_ssp | | servers | | | | | | | | | | | ports_lst | | | | | | | kvm +-----------> | |.....| | | | ssp_tbl | | | | context | | | | | | | p_ring | | | | | | | | | + | | conn_prefix | | | | | tbl_store | | | | | | | | | | | +------------------+ | p_pull_sock+---+ | | | | | | | | | | | | | | | | p_push_sock | | | | | | | | | + | | | | | | | v | | | | | | | | | +------------+------+ +------------------+ | +---------+ | | | | ring | | | | +---+---+----------+ | | | | | ^ ^ | | | | | | | | srv_hashring | | +-----------------------+ | | | +------------------------------------+ | srv_hashring_dct | | | | | +-------------------+ +
手機以下:
★★★★★★關於生活和技術的思考★★★★★★
微信公衆帳號:羅西的思考
若是您想及時獲得我的撰寫文章的消息推送,或者想看看我的推薦的技術資料,敬請關注。