AutoGraph是TF提供的一個很是具備前景的工具, 它可以將一部分python語法的代碼轉譯成高效的圖表示代碼. 因爲從TF 2.0開始, TF將會默認使用動態圖(eager execution), 所以利用AutoGraph, 在理想狀況下, 能讓咱們實現用動態圖寫(方便, 靈活), 用靜態圖跑(高效, 穩定).html
可是! 在使用的過程當中, 如無心外確定是會有意外的, 這篇文章就是指出一些AutoGraph和tf.function的奇怪的行爲, 讓你更愉快地使用它們.python
本文假設讀者具備必定的Python和TensorFlow的使用經驗.git
對tf1.X有經驗的讀者應該不會對讓咱們又愛又恨的計算圖(tf.Graph
)和執行會話(tf.Session
)感到陌生, 一個常規的流程以下:github
y=tf.matmul(a, x) + b
)tf.Session
tf.Session
tf.Session.run
來執行計算圖的節點, 被執行的節點會反向追蹤全部依賴的須要執行的節點並執行計算.如下是上述過程的一個代碼例子:apache
g = tf.Graph() #初始化計算圖 with g.as_default(): # 設置爲默認計算圖 a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) b = tf.Variable(12.) y = tf.matmul(a, x) + b # 描述計算圖 init_op = tf.global_variables_initializer() # 待執行節點 with tf.Session() as sess: # 配置會話 sess.run(init_op) # 執行節點 print(sess.run(y)) # 輸出結果
在TF 2.0中, 因爲默認爲動態圖, 計算會直接被執行, 也就是說, 咱們不須要緩存
tf.control_dependencies
來聲明節點的非直接依賴咱們能夠像寫普通python代碼(or pytorch)同樣, 寫了就執行:session
a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) b = tf.Variable(12.) y = tf.matmul(a, x) + b print(y.numpy())
通常來講, eager代碼會比執行相同操做的靜態圖代碼的效率低, 由於不少計算圖優化的方法只能用在數據流圖上.函數
若是想在TF 2.0上構建傳統的計算圖, 咱們就須要用到tf.function
.工具
TF 2.0的其中一個重要改變就是去除tf.Session
(此處應有掌聲). 這個改變會迫使用戶用更好的方式來組織代碼: 不用再用讓人糾結的tf.Session
來執行代碼, 就是一個個python函數, 加上一個簡單的裝飾器.學習
在TF 2.0裏面, 若是須要構建計算圖, 咱們只須要給python函數加上@tf.function
的裝飾器.
上文提到靜態圖的執行效率更高, 可是加速並非必定的. 通常來講, 計算圖越複雜, 加速效果越明顯. 對於複雜的計算圖, 好比訓練深度學習模型, 得到的加速是巨大的. (譯者注: 我的感受仍是要結合實際來看, 若是某一部分的計算既有複雜的計算圖, 而計算圖的複雜性又帶來了額外的 內存消耗
或者計算量, 那麼加速會比較明顯, 可是不少時候, 好比通常的CNN模型, 主要計算量並不在於圖的複雜性, 而在於卷積、矩陣乘法等操做, 加速並不會很明顯. 此處想法有待驗證)
這個自動將python代碼轉成圖表示代碼的工具就叫作AutoGraph.
在TF 2.0中, 若是一個函數被@tf.function
裝飾了, 那麼AutoGraph將會被自動調用, 從而將python函數轉換成可執行的圖表示.
在第一次調用被@tf.function
裝飾的函數時, 下列事情將會發生:
tf.
API只會定義一個生成tf.Tensor
輸出的節點while
→tf.while
,for
→tf.while
,if
→tf.cond
,assert
→tf.assert
...)tf.control_dependencies
,以便在執行第i+1
行時確保第i
行已經被執行. 至此計算圖已經肯定map [id] = graph
下一節將會具體闡述如何將TF 1.X代碼塊分別改寫到eager和計算圖版本.
要使用tf.function
, 第一步須要先將TF 1.X的設計計算圖的代碼放進python函數裏面.
def f(): a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) b = tf.Variable(12.) y = tf.matmul(a, x) + b return y
應爲TF 2.0默認是eager的, 咱們能夠直接執行該函數(不須要tf.Session
):
print(f().numpy())
咱們就會獲得輸出:
[[22. 22.] [23. 13.]]
咱們能夠直接用@tf.function
來裝飾函數f
, 咱們在原來f
的基礎上加上宇宙第一的debug大法: print
來更好地看看究竟發生了什麼.
@tf.function def f(): a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) b = tf.Variable(12.) y = tf.matmul(a, x) + b print("PRINT: ", y) tf.print("TF-PRINT: ", y) return y f()
因此發生了什麼呢?
@tf.function
將函數f
包進了tensorflow.python.eager.def_function.Function
這個對象, 函數f
被賦予到了這個對象的.python_function
屬性.f()
被執行的時候, 計算圖會同時被構建, 可是計算不會執行, 所以咱們會獲得如下結果, tf.
的操做不會被執行:PRINT: Tensor("add:0", shape=(2, 2), dtype=float32)
ValueError: tf.function-decorated function tried to create variables on non-first call.
在 RFC: Functions, not Session裏面有個很是明確的指示:
State (liketf.Variable
objects) are only created the first time the function f is called. 狀態(好比tf.Variable
) 只會在函數被第一次調用時建立.
可是 Alexandre Passos指出, 在函數轉換成圖表示時, 咱們沒有辦法肯定tf.function
調用了多少次函數, 所以咱們在第一次調用函數f
時, 在圖構建的過程當中, 可能會被執行了屢次, 這就致使了上述錯誤.
形成這個錯誤的根源在於一樣的命令在動態圖和靜態圖中的不一致性. 在動態圖中, tf.Variable
時一個普通的python變量, 超出了其做用域範圍就會被銷燬. 而在靜態圖中, tf.Variable
則是計算圖中一個持續存在的節點, 不受python的做用域的影響. 所以, 這是使用tf.function
的第一個教訓:
將一個在動態圖中可行的函數轉換成靜態圖須要用靜態圖的方式思考該函數是否可行
那麼咱們能夠怎樣去規避這個錯誤呢?
tf.Variable
做爲函數的參數傳入tf.Variable
tf.Variable
做爲類屬性來調用這裏指方法2和方法3. 顯然的, 咱們推薦使用方法3:
class F(): def __init__(self): self._b = None @tf.function def __call__(self): a = tf.constant([[10, 10], [11., 1.]]) x = tf.constant([[1., 0.], [0., 1.]]) if self._b is None: self._b = tf.Variable(12.) y = tf.matmul(a, x) + self._b print("PRINT: ", y) tf.print("TF-PRINT: ", y) return y f = F() f()
咱們以後會看到, 咱們並不能隨意地用tf.function
來轉化eager的代碼並達到加速的目的, 咱們須要想象一下轉化是怎麼完成的, 在轉python的代碼到圖操做的時候究竟發生了什麼, 這些轉化包含了什麼黑魔法. 這裏的例子比較簡單, 咱們會在接下來的文章中更深刻的探討.
@tf.function def f(b): a = tf.constant([[10,10],[11.,1.]]) x = tf.constant([[1.,0.],[0.,1.]]) y = tf.matmul(a, x) + b print("PRINT: ", y) tf.print("TF-PRINT: ", y) return y b = tf.Variable(12.) f(b)
上述函數會獲得咱們想要的結果, 另外, 做爲參數被傳入的變量可以在函數中直接更新, 而更新後的值會在函數外也適用. 下面的代碼會打印出1,2,3
a = tf.Variable(0) @tf.function def g(x): x.assign_add(1) return x print(g(a)) print(g(a)) print(g(a))
@tf.function
裝飾器來將python代碼轉成圖表示代碼tf.Variable
在以後的部分咱們會更加深刻地探討輸入參數類型對效率的影響, 以及python操做的轉換細節.
聲明: 本文翻譯自Paolo Galeone的博客, 已取得做者的贊成, 如需轉載本文請聯繫本人
Disclaimer: This is a translation of the article Analyzing tf.function to discover AutoGraph strengths and subtleties by Paolo Galeone.