Control Flow in Tensorflow TF中的控制流解析

#寫在前面 本文翻譯自Tensorflow團隊的文章Tensorflow Control Flow Implementation,部份內容加入了筆者本身的理解,若有不妥之處還望各位指教。異步

目錄

  • 概覽
  • 控制流核心概念
  • 控制流結構的編譯
  • 條件表達式
  • while循環
  • 實現
  • 分佈式條件表達式
  • 分佈式while循環
  • 自動微分

概覽

本文將會介紹當前在Tensorflow中控制流操做的設計和實現。這是一篇基於原始設計的描述性文檔,設計的細節還請參考源代碼。分佈式

本文將要講述的內容是:函數

  • 介紹Tensorflow爲了處理控制流加入的5個核心的操做;
  • 展現高層的控制流是如何經過5個基礎操做融入數據流圖的;
  • 解釋加入了控制流的數據流圖是怎樣被Tensorflow運行時執行的,包括融合了多種設備(CPU,GPU,TPU)的分佈式執行方式;
  • 描述了對控制流結構如何自動求導;

控制流核心概念

Tensorflow中控制流的基礎設計理念是,經過引入少許的簡單基礎操做,爲多樣的Tensorflow應用提供豐富的控制流表達。咱們指望這些操做靈活且富有表現力,可以做爲高層的領域專用語言(DSL,Domain Specific Language)的編譯目標。它們須要很方便的嵌入Tensorflow目前的數據流模型中,而且能夠方便的進行並行的、分佈式的執行以及自動求導。本節將介紹這5種控制流相關的基本操做。它們與Dennis和Arvind在數據流機(dataflow machines)中引入的控制流操做很像。使用Switch和Merge可使咱們事先條件控制,將這5種基礎操做組合起來,可使咱們實現while循環。oop

圖1

在Tensorflow中,每個op都會在一個執行幀(execution frame)中被執行,控制流操做負責建立和管理這些執行幀。好比,對於while循環,Tensorflow的運行時會建立一個執行幀,而後將全部屬於該while循環的操做放在這個執行幀中執行。不一樣執行幀中的操做能夠並行執行,只要它們之間沒有數據依賴。url

Switch:一個Switch操做根據控制輸入p的布爾值,將一個輸入張量d推動到某一個輸出(二選一)。只有到Switch操做的兩個輸入都準備好以後,它纔會執行。spa

MergeMerge操做將它的其中一個輸入推向輸出。當一個Merge操做的任意一個輸入準備好以後,Merge操做就會執行。在多個輸入都準備好的狀況下,Merge操做的輸出不肯定。.net

Enter(name)Enter操做將它的輸入推向名爲name的執行幀。Enter操做其實是把一個執行幀的張量推向它的子執行幀。同一個子執行幀上可能會有多個Enter操做,它們將不一樣的張量推向子執行幀。當輸入準備好以後,Enter操做就會執行。一個新的執行幀在它的第一個Enter操做執行以後開始執行。翻譯

ExitExit操做,將一個張量從一個子執行幀推向它的父執行幀。它的做用是將張量從子執行幀返回給父執行幀。一個子執行幀可能有多個Exit操做指向父執行幀,每一個操做都會異步的將一個張量返回給父執行幀。當它的輸入準備好以後,Exit操做開始執行。設計

NextIterationNextIteration操做將一個張量從當前執行幀的一輪迭代傳遞到下一輪迭代。Tensorflow的運行時在執行幀內部保存了一個迭代輪數。任何一個在執行幀中執行的操做都有惟一的一個迭代輪數的屬性,它能夠幫助咱們分辨一個迭代運算中不一樣的執行輪次。注意在一個執行幀中可能會有多個NextIteration操做。當執行幀的第N輪執行的第一個NextIteration操做開始執行時,Tensorflow的運行時開始執行第N+1輪的迭代。當更多的張量經過了NextIteration操做進入新的執行輪次時,新執行輪次中更多的操做就會開始運行。當輸入準備完成以後,NextIteration操做開始執行。code

控制流結構的編譯

有了這5種基礎的操做,高級的程序部件,例如條件表達式和whiile循環就能夠被編譯進入數據流圖,而後被Tensorflow的運行時執行。下面咱們來看一下條件表達式和while循環是如何在Tensorflow內部實現的。

條件表達式

如下是構建條件表達式cond(pred, fn1, fn2)的數據流圖的高層僞代碼。爲了簡化,咱們忽略了實際使用中的細節,讀者能夠在control_flow_ops.py中找到實現細節:

//構建true分支圖
context_t = CondContext(pred, branch=1)
res_t = context_t.Call(fn1)

//構建false分支圖
context_t = CondContext(pred, branch=0)
res_f = context_f.Call(fn2)

//爲輸出添加Merge節點
merges = [Merge([f,t]) for (f,t) in zip(res_f, res_t)]
return merges

對於條件表達式的每個分支,咱們建立了一個新的控制流上下文,而且在上下文中調用了圖構建的函數(fn1或者fn2)。條件上下文容許咱們獲取任意的外部張量(不在上下文中建立的),而且插入一個合適的Switch操做來保證它會進入一個分支。這就保證了,只有當這個分支被選擇時,它對應的操做纔會被執行。因爲Tensorflow是異步執行的,外部的張量可能在不一樣的時間到達,所以咱們爲每個外部張量準備了一個Switch操做來最大化並行度。

每一個分支都返回了張量的列表(res_t或者res_f),所以咱們又添加了一個Merge操做來對結果進行合併,這樣只要任何一個分支執行成功了,就能獲得輸出(前面講到,對於Merge操做,只要其中一個輸入準備好了,就會產生輸出)。

讓咱們來看一個簡單的例子:

圖2

tf.cond(x<y, lambda: tf.add(x,z), lambda: tf.square(y))

在生成的數據流圖中,Switch操做的插入是爲了控制x,y,z張量的流動。在true/false分支,只有Switch操做的true/false的輸出纔會被使用。因爲Add操做的輸入來自Switch操做的true分支,所以只有x小於y時,Add操做纔會被執行。一樣的,只有x大於等於y時,Square操做纔會被執行。最終Merge操做發送Add或者Square的結果。若是條件表達式有多個結果,那麼將會有多個Merge操做,每一個結果對應一個Merge操做。

固然,利用Switch和Merge操做實現條件表達式還有不少方法,咱們選擇當前的實現,主要是由於它更容易進行自動求導。

while循環

如下是構建數據流圖中while循環的高層僞代碼:

while_context = WhileContext()
while_context.Enter()

//爲每個循環變量添加Enter節點
enter_vars = [Enter(x, frame_name) for x in loop_vars]

//添加Merge節點,注意input[1]將會在後面被迭代
merge_vars = [Merge([x,x]) for x in enter_vars]

//構建循環條件子圖
pred_result = pred(*merge_vars)

//添加Switch節點
switch_vars = [Switch(x, pred_result) for x in merge_vars]

//構建循環體子圖
body_result = body(*[x[1] for x in switch_vars])

//添加NextIteration節點
next_vars = [NextIteration(x) for x in body_result]

//構建循環
for m,v in zip(merge_vars, next_vars):
    m.op._update_input(1,v)

//添加Exit節點
exit_vars = [Exit(x[0]) for x in switch_vars]
while_context.Exit()
return exit_vars

整個while循環圖建立在while循環的控制流上下文中。整個思路比較簡單。

從循環變量開始,咱們爲它們分別添加一個Enter操做和一個Merge操做。咱們使用它們的結果(merge_vars)來構建判斷子圖,從而計算循環終止條件。

在添加了Switch操做以後,咱們使用Switch操做的true分支來構建循環體子圖。循環體的結果須要進入下一輪迭代,所以咱們添加了一個NextIteration操做,而且將其輸出指向Merge操做的第二個輸入,這樣就造成了循環,容許咱們在執行圖是不斷的運行一樣的一組操做。

Switch操做的false輸出是整個while循環的輸出,所以咱們在它後面加入了Exit操做,來返回運算結果。與條件表達式相似,while循環的上下文被用來追蹤在pred和lambda中使用的外部張量。這些外部張量被看作是循環常數,咱們自動爲每個外部張量插入了一個Enter操做,使它在while循環的上下文內部可以被訪問。嵌套的循環須要添加嵌套的Enter操做。

一樣的,讓咱們看一個簡單的例子:

圖3

tf.while_loop(lambda i:i<10, lambda i: tf.add(i,1),[0])

如上圖所示,咱們只有一個循環變量。若是有多個循環變量,咱們須要添加多個Enter,Merge,Switch,NextIteration和Exit操做。這使得跨循環和跨迭代輪次的執行成爲可能。你可能注意到咱們省略了常量的表示方法,若是你想要理解更深層次的細節,請查看源代碼。

這種對於條件表達式和while循環的支持,使得咱們能夠表達任意嵌套的條件和循環。例如,一個循環體內可能嵌套着另一個循環體。TF保證每一個循環被賦予了一個惟一的幀名稱。

實現

Tensorflow的運行時負責對數據流圖進行執行。下面咱們先來對此作一個快速的概覽。

爲了在多臺設備上運行,TF自動將計算操做分配到不一樣的設備上。基於設備分配,TF自動的將數據流圖劃分紅子圖,每臺設備有一個子圖對應。當數據流圖的一條邊被圖分割切段時(邊兩側的節點分配在兩臺不一樣的設備上),咱們自動的插入一對send和recv節點,以便在設備間傳輸數據。一對send和recv節點經過一個惟一的鍵實現通訊,recv節點主動的從send節點拉取數據。例如,如下就是將原圖分割到兩臺設備後的結果。TF對於分割沒有添加任何限制,只要一個節點可以在一臺設備上進行運算,就能夠被分配到這臺設備。

圖4

若是一個子圖被分配到一個設備上運行,那麼這個設備將會使用隸屬於它的執行器來執行這個子圖。執行器從source節點開始,依次執行已經準備好的節點。除了Merge節點以外,對於任何一個其餘節點來講,只要它的輸入準備好了,這個節點就能夠開始執行了。注意一張子圖中全部的recv節點都被認爲是source節點。

若是沒有控制流,圖執行的過程會很是的直接:每一個節點僅被執行一次,而且當全部節點都執行結束以後,整個圖的執行就完成了。控制流的引入帶來了必定的複雜性。有了控制流,一個節點可能被執行任意次(甚至包括0次)。執行器須要管理對於同一個節點的多個同時存在的執行實例,而且決定計算圖合適執行結束。

爲了追蹤計算中產生的張量,執行器中的張量被使用一個形如(value, is_dead, tag)的元組來標識,value是張量值,is_dead是一個布爾值,用來標識這個張量是否在一個未執行的條件分支上,tag是這個張量的惟一標識(產生張量的節點的執行實例)。本質上,tag定義了執行的上下文,在同一個執行上下文下,一個操做最多被執行一次。tag是send/recv之間通訊的鍵的一部分,用來辨識同一對send/recv節點的不一樣執行。

執行器遵循了以下的執行規則(注意,某個節點的全部輸入都必須包含一樣的tag)

Switch(p,d) = (r1,r2)
r1 = (value(d), p || is_dead(d),tag(d))
r2 = (value(d), !p || is_dead(d),tag(d))

Merge(d1,d2) = r
r = if is_dead(d1) then d2 else d1

Enter(d, frame_name) = r
value(r) = value(d)
is_dead(r) = is_dead(d)
tag(r) = tag(d)/frame_name/0

Exit(d) = r
value(r) = value(d)
is_dead(r) = is_dead(d)
tag(r) = tag1 where tag(d)=tag1/frame_name/n

NextIteration(d) = d1
value(d1) = value(d)
is_dead(d1) = is_dead(d)
tag(d1) = tag1/frame_name/(n+1) where tag(d) = tag1/frame_name/n

Op(d1,...,dm) = (r1,...,rn)
value(ri) = Op.Compute(value(d1),...,value(dm)) if !is_dead(ri)
is_dead(ri) = any(is_dead(d1),...,is_dead(dm)), for all i
tag(ri) = tag(d1), for all i

最後一個規則適用於全部的非控制流節點。注意只有當全部的輸入都有效時,計算纔會執行。若是有一個dead輸入,咱們將會跳過計算,而將dead信號傳遞下去。對於dead信號的傳遞將有助於支持控制流的分佈式執行。

分佈式條件表達式

對於分佈式執行來講,一個條件表達式可能被分配到了不一樣的設備上,以下圖所示:

圖5

因爲每個recv節點都是source節點,而且隨時可能會開始執行,在設備B上的recv節點甚至在出於未選擇的條件分支上時也會執行。爲了讓出於未選擇的分支上的recv節點的執行合理化,咱們將is_dead標籤經過send節點跨設備傳輸到recv節點。這種信息會一直跨越設備傳輸下去。這種簡單的傳輸機制使得在分佈式環境下的條件判斷更加天然,也有助於分佈式環境下的while循環。

分佈式的while循環

在分佈式環境下,一個while循環(特別是循環體),可能被分割到不一樣的設備上。若是咱們簡單的應用分割邏輯,而後在跨設備的節點之間插入send/recv,那麼設備上的局部執行器將缺乏準確執行while循環的信息。

圖6

讓咱們經過一個例子來認識這個問題。在上述例子中,Op在循環體中,而且被分配給了設備B。一個簡單的分割可能會在Switch和Op之間插入一對send/recv節點來執行跨設備的數據傳輸。然而,這樣是沒法工做的,由於設備B並不知道recv和Op操做是處在一個循環當中的,在執行完Op一次以後,設備B上的執行器就會認爲,它的工做已經完成了(從設備B的角度看,它只須要從recv獲取數據,執行Op,而後將結果經過send發送出去,執行就結束了)。解決方案是,重寫數據流圖,在while循環體分配到的每一個設備上,添加一個控制循環狀態機(以下圖中所示)。標量0被用來做爲Enter節點的輸入。

圖7

這些控制循環爲設備上的執行器提供了足夠的信息,使得它們能夠像之前同樣獨立的執行,同時經過send/recv與其它設備通訊。注意到圖中的虛線表明了控制輸入。

(具體執行過程分爲0次執行,和大於等於1次執行兩種狀況討論,這裏就不寫了,你們能夠自行分析)

注意到執行中有很是多的並行執行。例如,在接收到P以後,設備B能夠開始下一輪迭代,或者中止執行。一個設備可能同時存在並行的多個執行輪次,而且兩個不一樣的設備還能夠同時處在同一個循環的不一樣迭代輪次上。

這種while循環的分佈式執行方式帶來的開銷是,任何一個參與的設備都必須在每個迭代輪次裏,接收來自產生P的設備傳遞過來的布爾張量。因爲執行過程是高度並行的,這種開銷能夠忽略不計了。

下圖展現了當一個while循環被分割到不一樣的設備上時是什麼樣子。每一個分割的部分都被添加了一個控制循環結構,用來控制while循環內部的recv操做。重寫以後的新圖與原圖是語義等價的。

圖8

對於嵌套的while循環,咱們按照下圖所示的方式將控制循環堆疊起來。注意若是一臺設備僅包含了外層循環的節點,咱們不會在它上面添加與內層循環有關的控制循環結構。

圖9

自動微分

待補充。

相關文章
相關標籤/搜索