Session是TensorFlow先後端鏈接的橋樑。用戶利用session使得client可以與master的執行引擎創建鏈接,並經過session.run()來觸發一次計算。它創建了一套上下文環境,封裝了operation計算以及tensor求值的環境。前端
session建立時,系統會分配一些資源,好比graph引用、要鏈接的計算引擎的名稱等。故計算完畢後,須要使用session.close()關閉session,避免引發內存泄漏,特別是graph沒法釋放的問題。能夠顯式調用session.close(),或利用with上下文管理器,或者直接使用InteractiveSession。node
session之間採用共享graph的方式來提升運行效率。一個session只能運行一個graph實例,但一個graph能夠運行在多個session中。通常狀況下,建立session時若是不指定Graph實例,則會使用系統默認Graph。常見狀況下,咱們都是使用一個graph,即默認graph。當session建立時,不會從新建立graph實例,而是默認graph引用計數加1。當session close時,引用計數減1。只有引用計數爲0時,graph纔會被回收。這種graph共享的方式,大大減小了graph建立和回收的資源消耗,優化了TensorFlow運行效率。python
op運算和tensor求值時,若是沒有指定運行在哪一個session中,則會運行在默認session中。經過session.as_default()能夠將本身設置爲默認session。但我的建議最好仍是經過session.run(operator)和session.run(tensor)來進行op運算和tensor求值。c++
operation.run()後端
operation.run()等價於tf.get_default_session().run(operation)api
@tf_export("Operation") class Operation(object): # 經過operation.run()調用,進行operation計算 def run(self, feed_dict=None, session=None): _run_using_default_session(self, feed_dict, self.graph, session) def _run_using_default_session(operation, feed_dict, graph, session=None): # 沒有指定session,則獲取默認session if session is None: session = get_default_session() # 最終仍是經過session.run()進行運行的。tf中任何運算,都是經過session來run的。 # 經過session來創建client和master的鏈接,並將graph發送給master,master再進行執行 session.run(operation, feed_dict)
tensor.eval()數組
tensor.eval()等價於tf.get_default_session().run(tensor), 以下session
@tf_export("Tensor") class Tensor(_TensorLike): # 經過tensor.eval()調用,進行tensor運算 def eval(self, feed_dict=None, session=None): return _eval_using_default_session(self, feed_dict, self.graph, session) def _eval_using_default_session(tensors, feed_dict, graph, session=None): # 若是沒有指定session,則獲取默認session if session is None: session = get_default_session() return session.run(tensors, feed_dict)
默認session的管理app
tf經過運行時維護的session本地線程棧,來管理默認session。故不一樣的線程會有不一樣的默認session,默認session是線程做用域的。框架
# session棧 _default_session_stack = _DefaultStack() # 獲取默認session的接口 @tf_export("get_default_session") def get_default_session(): return _default_session_stack.get_default() # _DefaultStack默認session棧是線程相關的 class _DefaultStack(threading.local): # 默認session棧的建立,其實就是一個list def __init__(self): super(_DefaultStack, self).__init__() self._enforce_nesting = True self.stack = [] # 獲取默認session def get_default(self): return self.stack[-1] if len(self.stack) >= 1 else None
session類圖
會話Session的UML類圖以下
分爲兩種類型,普通Session和交互式InteractiveSession。InteractiveSession和Session基本相同,區別在於
Session和InteractiveSession的代碼邏輯很少,主要邏輯均在其父類BaseSession中。主要代碼以下
@tf_export('Session') class Session(BaseSession): def __init__(self, target='', graph=None, config=None): # session建立的主要邏輯都在其父類BaseSession中 super(Session, self).__init__(target, graph, config=config) self._default_graph_context_manager = None self._default_session_context_manager = None
@tf_export('InteractiveSession') class InteractiveSession(BaseSession): def __init__(self, target='', graph=None, config=None): self._explicitly_closed = False # 將本身設置爲default session self._default_session = self.as_default() self._default_session.enforce_nesting = False # 自動調用上下文管理器的__enter__()方法 self._default_session.__enter__() self._explicit_graph = graph def close(self): super(InteractiveSession, self).close() ## 省略無關代碼 ## 自動調用上下文管理器的__exit__()方法,避免內存泄漏 self._default_session.__exit__(None, None, None) self._default_session = None
BaseSession
BaseSession基本包含了全部的會話實現邏輯。包括會話的整個生命週期,也就是建立 執行 關閉和銷燬四個階段。生命週期後面詳細分析。BaseSession包含的主要成員變量有graph引用,序列化的graph_def, 要鏈接的tf引擎target,session配置信息config等。
在後端master中,根據前端client調用tf.Session(target='', graph=None, config=None)時指定的target,來建立不一樣的Session。target爲要鏈接的tf後端執行引擎,默認爲空字符串。Session建立採用了抽象工廠模式,若是爲空字符串,則建立本地DirectSession,若是以grpc://開頭,則建立分佈式GrpcSession。類圖以下
DirectSession只能利用本地設備,將任務建立到本地的CPU GPU上。而GrpcSession則能夠利用遠端分佈式設備,將任務建立到其餘機器的CPU GPU上,而後經過grpc協議進行通訊。grpc協議是谷歌發明並開源的遠程通訊協議。
Session做爲先後端鏈接的橋樑,以及上下文運行環境,其生命週期尤爲關鍵。大體分爲4個階段
session.__del__()
進行回收。生命週期方法入口基本都在前端Python的BaseSession中,它會經過swig自動生成的函數符號映射關係,調用C層的實現。
5.1 建立
先從BaseSession類的init方法看起,只保留了主要代碼。
def __init__(self, target='', graph=None, config=None): # graph表示構建的圖。TensorFlow的一個session會對應一個圖。這個圖包含了全部涉及到的算子 # graph若是沒有設置(一般都不會設置),則使用默認graph if graph is None: self._graph = ops.get_default_graph() else: self._graph = graph self._opened = False self._closed = False self._current_version = 0 self._extend_lock = threading.Lock() # target爲要鏈接的tf執行引擎 if target is not None: self._target = compat.as_bytes(target) else: self._target = None self._delete_lock = threading.Lock() self._dead_handles = [] # config爲session的配置信息 if config is not None: self._config = config self._add_shapes = config.graph_options.infer_shapes else: self._config = None self._add_shapes = False self._created_with_new_api = ops._USE_C_API # 調用C層來建立session self._session = None opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) self._session = tf_session.TF_NewSession(self._graph._c_graph, opts, status)
BaseSession先進行成員變量的賦值,而後調用TF_NewSession來建立session。TF_NewSession()方法由swig自動生成,在bazel-bin/tensorflow/python/pywrap_tensorflow_internal.py中
def TF_NewSession(graph, opts, status): return _pywrap_tensorflow_internal.TF_NewSession(graph, opts, status)
_pywrap_tensorflow_internal包含了C層函數的符號表。在swig模塊import時,會加載pywrap_tensorflow_internal.so動態連接庫,從而獲得符號表。在pywrap_tensorflow_internal.cc中,註冊了供Python調用的函數的符號表,從而實現Python到C的函數映射和調用。
// c++函數調用的符號表,Python經過它能夠調用到C層代碼。符號表和動態連接庫由swig自動生成 static PyMethodDef SwigMethods[] = { // .. 省略其餘函數定義 // TF_NewSession的符號表,經過這個映射,Python中就能夠調用到C層代碼了。 { (char *)"TF_NewSession", _wrap_TF_NewSession, METH_VARARGS, NULL}, // ... 省略其餘函數定義 }
最終調用到c_api.c中的TF_NewSession()
// TF_NewSession建立session的新實現,在C層後端代碼中 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { Session* session; // 建立session status->status = NewSession(opt->options, &session); if (status->status.ok()) { TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { // 採用了引用計數方式,多個session共享一個圖實例,效率更高。 // session建立時,引用計數加1。session close時引用計數減1。引用計數爲0時,graph纔會被回收。 mutex_lock l(graph->mu); graph->sessions[new_session] = Status::OK(); } return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; } }
session建立時,並建立graph,而是採用共享方式,只是引用計數加1了。這種方式減小了session建立和關閉時的資源消耗,提升了運行效率。NewSession()根據前端傳遞的target,使用sessionFactory建立對應的TensorFlow::Session實例。
Status NewSession(const SessionOptions& options, Session** out_session) { SessionFactory* factory; const Status s = SessionFactory::GetFactory(options, &factory); // 經過sessionFactory建立多態的Session。本地session爲DirectSession,分佈式爲GRPCSession *out_session = factory->NewSession(options); if (!*out_session) { return errors::Internal("Failed to create session."); } return Status::OK(); }
建立session採用了抽象工廠模式。根據client傳遞的target,來建立不一樣的session。若是target爲空字符串,則建立本地DirectSession。若是以grpc://開頭,則建立分佈式GrpcSession。TensorFlow包含本地運行時和分佈式運行時兩種運行模式。
下面來看DirectSessionFactory的NewSession()方法
class DirectSessionFactory : public SessionFactory { public: Session* NewSession(const SessionOptions& options) override { std::vector<Device*> devices; // job在本地執行 const Status s = DeviceFactory::AddDevices( options, "/job:localhost/replica:0/task:0", &devices); if (!s.ok()) { LOG(ERROR) << s; return nullptr; } DirectSession* session = new DirectSession(options, new DeviceMgr(devices), this); { mutex_lock l(sessions_lock_); sessions_.push_back(session); } return session; }
GrpcSessionFactory的NewSession()方法就不詳細分析了,它會將job任務建立在分佈式設備上,各job經過grpc協議通訊。
5.2 運行
經過session.run()能夠啓動graph的執行。入口在BaseSession的run()方法中, 一樣只列出關鍵代碼
class BaseSession(SessionInterface): def run(self, fetches, feed_dict=None, options=None, run_metadata=None): # fetches能夠爲單個變量,或者數組,或者元組。它是圖的一部分,能夠是操做operation,也能夠是數據tensor,或者他們的名字String # feed_dict爲對應placeholder的實際訓練數據,它的類型爲字典 result = self._run(None, fetches, feed_dict, options_ptr,run_metadata_ptr) return result def _run(self, handle, fetches, feed_dict, options, run_metadata): # 建立fetch處理器fetch_handler fetch_handler = _FetchHandler( self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) # 通過不一樣類型的fetch_handler處理,獲得最終的fetches和targets # targets爲要執行的operation,fetches爲要執行的tensor _ = self._update_with_movers(feed_dict_tensor, feed_map) final_fetches = fetch_handler.fetches() final_targets = fetch_handler.targets() # 開始運行 if final_fetches or final_targets or (handle and feed_dict_tensor): results = self._do_run(handle, final_targets, final_fetches, feed_dict_tensor, options, run_metadata) else: results = [] # 輸出結果到results中 return fetch_handler.build_results(self, results) def _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata): # 將要運行的operation添加到graph中 self._extend_graph() # 執行一次運行run,會調用底層C來實現 return tf_session.TF_SessionPRunSetup_wrapper( session, feed_list, fetch_list, target_list, status) # 將要運行的operation添加到graph中 def _extend_graph(self): with self._extend_lock: if self._graph.version > self._current_version: # 生成graph_def對象,它是graph的序列化表示 graph_def, self._current_version = self._graph._as_graph_def( from_version=self._current_version, add_shapes=self._add_shapes) # 經過TF_ExtendGraph將序列化後的graph,也就是graph_def傳遞給後端 with errors.raise_exception_on_not_ok_status() as status: tf_session.TF_ExtendGraph(self._session, graph_def.SerializeToString(), status) self._opened = True
邏輯仍是十分複雜的,主要有一下幾步
咱們分別來看extend和run。
5.2.1 extend添加節點到graph中
TF_ExtendGraph()會調用到c_api中,這個邏輯一樣經過swig工具自動生成。下面看c_api.cc中的TF_ExtendGraph()方法
// 增長節點到graph中,proto爲序列化後的graph void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto, size_t proto_len, TF_Status* status) { GraphDef g; // 先將proto反序列化,獲得client傳遞的graph,放入g中 if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) { status->status = InvalidArgument("Invalid GraphDef"); return; } // 再調用session的extend方法。根據建立的不一樣session類型,多態調用不一樣方法。 status->status = s->session->Extend(g); }
後端系統根據生成的Session類型,多態的調用Extend方法。若是是本地session,則調用DirectSession的Extend()方法。若是是分佈式session,則調用GrpcSession的相關方法。下面來看GrpcSession的Extend方法。
Status GrpcSession::Extend(const GraphDef& graph) { CallOptions call_options; call_options.SetTimeout(options_.config.operation_timeout_in_ms()); return ExtendImpl(&call_options, graph); } Status GrpcSession::ExtendImpl(CallOptions* call_options, const GraphDef& graph) { bool handle_is_empty; { mutex_lock l(mu_); handle_is_empty = handle_.empty(); } if (handle_is_empty) { // 若是graph句柄爲空,則代表graph尚未建立好,此時extend就等同於create return Create(graph); } mutex_lock l(mu_); ExtendSessionRequest req; req.set_session_handle(handle_); *req.mutable_graph_def() = graph; req.set_current_graph_version(current_graph_version_); ExtendSessionResponse resp; // 調用底層實現,來添加節點到graph中 Status s = master_->ExtendSession(call_options, &req, &resp); if (s.ok()) { current_graph_version_ = resp.new_graph_version(); } return s; }
Extend()方法中要注意的一點是,若是是首次執行Extend(), 則要先調用Create()方法進行graph的註冊。不然纔是執行添加節點到graph中。
5.2.2 run執行圖的計算
一樣,Python經過swig自動生成的代碼,來實現對C API的調用。C層實如今c_api.cc的TF_Run()中。
// session.run()的C層實現 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, // Input tensors,輸入的數據tensor const char** c_input_names, TF_Tensor** c_inputs, int ninputs, // Output tensors,運行計算後輸出的數據tensor const char** c_output_names, TF_Tensor** c_outputs, int noutputs, // Target nodes,要運行的節點 const char** c_target_oper_names, int ntargets, TF_Buffer* run_metadata, TF_Status* status) { // 省略一段代碼 TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names, c_outputs, target_oper_names, run_metadata, status); } // 真正的實現了session.run() static void TF_Run_Helper() { RunMetadata run_metadata_proto; // 調用不一樣的session實現類的run方法,來執行 result = session->Run(run_options_proto, input_pairs, output_tensor_names, target_oper_names, &outputs, &run_metadata_proto); // 省略代碼 }
最終會調用建立的session來執行run方法。DirectSession和GrpcSession的Run()方法會有所不一樣。後面很複雜,就不接着分析了。
5.3 關閉session
經過session.close()來關閉session,釋放相關資源,防止內存泄漏。
class BaseSession(SessionInterface): def close(self): tf_session.TF_CloseSession(self._session, status)
會調用到C API的TF_CloseSession()方法。
void TF_CloseSession(TF_Session* s, TF_Status* status) { status->status = s->session->Close(); }
最終根據建立的session,多態的調用其Close()方法。一樣分爲DirectSession和GrpcSession兩種。
::tensorflow::Status DirectSession::Close() { cancellation_manager_->StartCancel(); { mutex_lock l(closed_lock_); if (closed_) return ::tensorflow::Status::OK(); closed_ = true; } // 註銷session if (factory_ != nullptr) factory_->Deregister(this); return ::tensorflow::Status::OK(); }
DirectSessionFactory中的Deregister()方法以下
void Deregister(const DirectSession* session) { mutex_lock l(sessions_lock_); // 釋放相關資源 sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session), sessions_.end()); }
5.4 銷燬session
session的銷燬是由Python的GC自動執行的。python經過引用計數方法來判斷是否回收對象。當對象的引用計數爲0,且虛擬機觸發了GC時,會調用對象的__del__()
方法來銷燬對象。引用計數法有個很致命的問題,就是沒法解決循環引用問題,故會存在內存泄漏。Java虛擬機採用了調用鏈分析的方式來決定哪些對象會被回收。
class BaseSession(SessionInterface): def __del__(self): # 先close,防止用戶沒有調用close() try: self.close() # 再調用c api的TF_DeleteSession來銷燬session if self._session is not None: try: status = c_api_util.ScopedTFStatus() if self._created_with_new_api: tf_session.TF_DeleteSession(self._session, status)
c_api.cc中的相關邏輯以下
void TF_DeleteSession(TF_Session* s, TF_Status* status) { status->status = Status::OK(); TF_Graph* const graph = s->graph; if (graph != nullptr) { graph->mu.lock(); graph->sessions.erase(s); // 若是graph的引用計數爲0,也就是graph沒有被任何session持有,則考慮銷燬graph對象 const bool del = graph->delete_requested && graph->sessions.empty(); graph->mu.unlock(); // 銷燬graph對象 if (del) delete graph; } // 銷燬session和TF_Session delete s->session; delete s; }
TF_DeleteSession()會判斷graph的引用計數是否爲0,若是爲0,則會銷燬graph。而後銷燬session和TF_Session對象。經過Session實現類的析構函數,來銷燬session,釋放線程池Executor,資源管理器ResourceManager等資源。
DirectSession::~DirectSession() { for (auto& it : partial_runs_) { it.second.reset(nullptr); } // 釋放線程池Executor for (auto& it : executors_) { it.second.reset(); } for (auto d : device_mgr_->ListDevices()) { d->op_segment()->RemoveHold(session_handle_); } // 釋放ResourceManager for (auto d : device_mgr_->ListDevices()) { d->ClearResourceMgr(); } // 釋放CancellationManager實例 functions_.clear(); delete cancellation_manager_; // 釋放ThreadPool for (const auto& p_and_owned : thread_pools_) { if (p_and_owned.second) delete p_and_owned.first; } execution_state_.reset(nullptr); flib_def_.reset(nullptr); }
Session是TensorFlow的client和master鏈接的橋樑,client任何運算也是經過session來run。它是client端最重要的對象。在Python層和C++層,均有不一樣的session實現。session生命週期會經歷四個階段,create run close和del。四個階段均由Python前端開始,最終調用到C層後端實現。由此也能夠看到,TensorFlow框架的先後端分離和模塊化設計是多麼的精巧。
本文爲雲棲社區原創內容,未經容許不得轉載。