在TensorFlow 2.0中,默認狀況下啓用了急切執行。 對於用戶而言直觀且靈活(運行一次性操做更容易,更快),但這可能會犧牲性能和可部署性。node
要得到最佳性能並使模型可在任何地方部署,能夠優先使用tf.function從程序中構建圖。 由於有AutoGraph,可使用tf.function構建高效性能的Python代碼,但仍有一些陷阱須要警戒。python
今天咱們就來介紹一下tensorflow2.0中的TF fuction和AutoGraph。編程
下面的輔助程序代碼,用於演示可能遇到的各類錯誤。緩存
import contextlib安全
# 構建包含上下文管理器的函數,使其能夠在with中使用app
@contextlib.contextmanagerless
def assert_raises(error_class):dom
try:異步
yield分佈式
except error_class as e:
print('Caught expected exception \n {}: {}'.format(error_class, e))
except Exception as e:
print('Got unexpected exception \n {}: {}'.format(type(e), e))
else:
raise Exception('Expected {} to be raised but no error was raised!'.format(
error_class))
tf.function
一個tf.function定義就像是一個核心TensorFlow操做:能夠急切地執行它; 也能夠在靜態圖中使用它; 且它具備梯度。
# 相似一個tensorflow操做
@tf.function
def add(a, b):
return a+b
add(tf.ones([2,2]), tf.ones([2,2]))
array([[2., 2.],
[2., 2.]], dtype=float32)>
# tf.function操做能夠計算梯度
@tf.function
def add(a, b):
return a+b
v = tf.Variable(2.0)
with tf.GradientTape() as tape:
res = add(v, 1.0)
tape.gradient(res, v)
# 能夠內嵌調用tf.function
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
array([[3., 3.],
[3., 3.],
[3., 3.]], dtype=float32)>
跟蹤和多態
Python的動態類型意味着可使用各類參數類型調用函數,Python將在每一個場景中執行不一樣的操做。
另外一方面,TensorFlow圖須要靜態dtypes和形狀尺寸。tf.function經過在必要時回溯函數來生成正確的圖結構來彌補這一差距。大多數使用的tf.function源於這種迴歸行爲。
咱們可使用不一樣類型的參數調用函數來查看正在發生的事情。
# 函數的多態
@tf.function
def double(a):
print('追蹤變量:',a)
return a + a
print('結果:',double(tf.constant(1)))
print()
print('結果:',double(tf.constant(1.1)))
print()
print('結果:',double(tf.constant('c')))
print()
追蹤變量: Tensor("a:0", shape=(), dtype=int32)
結果: tf.Tensor(2, shape=(), dtype=int32)
追蹤變量: Tensor("a:0", shape=(), dtype=float32)
結果: tf.Tensor(2.2, shape=(), dtype=float32)
追蹤變量: Tensor("a:0", shape=(), dtype=string)
結果: tf.Tensor(b'cc', shape=(), dtype=string)
控制參數類型:
建立一個新的tf.function。tf.function確保單獨的對象不共享追蹤。
使用該get_concrete_function方法獲取特定追蹤
指定input_signature什麼時候調用tf.function以確保僅構建一個功能圖。
print('構建許可的追蹤')
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("執行追蹤函數")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("使用不合法參數")
with assert_raises(tf.errors.InvalidArgumentError):
double_strings(tf.constant(1))
構建許可的追蹤
追蹤變量: Tensor("a:0", dtype=string)
執行追蹤函數
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
使用不合法參數
Caught expected exception
: cannot compute __inference_double_98 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_98]
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(tf.equal(x % 2, 0), x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# 只能輸入1維向量
with assert_raises(ValueError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception
: Python inputs incompatible with input_signature: inputs ((
array([[1, 2],
[3, 4]], dtype=int32)>,)), input_signature ((TensorSpec(shape=(None,), dtype=tf.int32, name=None),))
何時回溯?
多態tf.function經過跟蹤生成具體函數的緩存。緩存鍵其實是從函數args和kwargs生成的鍵的元組。爲tf.Tensor參數生成的關鍵是其形狀和類型。爲Python原語生成的密鑰是它的值。對於全部其餘Python類型,鍵都基於對象,id()以便爲每一個類的實例獨立跟蹤方法。未來,TensorFlow能夠爲Python對象添加更復雜的緩存,能夠安全地轉換爲張量。
使用Python參數仍是Tensor參數?
一般,Python的參數被用來控制超參數和圖的結構-例如,num_layers=10或training=True或nonlinearity=‘relu’。所以,若是Python參數發生變化,那麼必須回溯圖。
可是,Python參數可能不會用於控制圖構造。在這些狀況下,Python值的變化可能會觸發沒必要要的回溯。舉例來講,這個訓練循環,AutoGraph將動態展開。儘管存在多條跡線,但生成的圖其實是相同的,所以這有點低效。
def train_one_step():
pass
@tf.function
def train(num_steps):
print("追蹤: num_steps = {}".format(num_steps))
for _ in tf.range(num_steps):
train_one_step()
train(num_steps=10)
train(num_steps=20)
追蹤: num_steps = 10
追蹤: num_steps = 20
# 使用tensor,同類型不會重複追蹤
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
追蹤: num_steps = Tensor("num_steps:0", shape=(), dtype=int32)
# 使用tensor,類型不一樣纔會有新的追蹤,(前一個單元格已追蹤int型,因此該處不追蹤)
train(num_steps=tf.constant(10, dtype=tf.int32))
train(num_steps=tf.constant(20.6))
追蹤: num_steps = Tensor("num_steps:0", shape=(), dtype=float32)
反作用 tf.function
一般,Python反作用(如打印或變異對象)僅在跟蹤期間發生。怎麼能可靠地觸發反作用tf.function呢?
通常的經驗法則是僅使用Python反作用來調試跟蹤。可是,TensorFlow操做相似於tf.Variable.assign,tf.print和tf.summary是確保TensorFlow運行時,在每次調用時跟蹤和執行代碼的最佳方法。一般使用功能樣式將產生最佳結果。
tf.function函數中的print()被用於跟蹤,因此要調試輸出每次調用(反作用),就須要tf.function()
@tf.function
def f(x):
print("追蹤:", x)
tf.print('執行:', x)
f(1)
f(1)
f(2)
追蹤: 1
執行: 1
執行: 1
追蹤: 2
執行: 2
若是想在每次調用期間執行Python代碼tf.function,可使用tf.py_function。tf.py_function缺點是它不便攜和高效,也不能在分佈式(多GPU,TPU)設置中很好地工做。此外,因爲tf.py_function必須鏈接到圖,它將全部輸入/輸出轉換爲張量。
external_list = []
def side_effect(x):
print('Python side effect')
external_list.append(x)
@tf.function
def f(x):
tf.py_function(side_effect, inp=[x], Tout=[])
f(1)
f(1)
f(1)
print(external_list)
WARNING: Logging before flag parsing goes to stderr.
W0609 06:41:05.048375 139792217777920 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32
W0609 06:41:05.053524 139792217777920 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32
W0609 06:41:05.056409 139792226170624 backprop.py:842] The dtype of the watched tensor must be floating (e.g. tf.float32), got tf.int32
Python side effect
Python side effect
Python side effect
[, , ]
謹防Python狀態
許多Python功能(如生成器和迭代器)依賴於Python運行時來跟蹤狀態。 一般,雖然這些構造在Eager模式下按預期工做,但因爲跟蹤行爲,tf.function內部可能會發生許多意外狀況。
external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
external_var.assign_add(next(iterator))
tf.print('external_var:', external_var)
iterator = iter([0,1,2,3])
buggy_consume_next(iterator)
# 後面沒有正常迭代,輸出的都是第一個
buggy_consume_next(iterator)
buggy_consume_next(iterator)
external_var: 0
external_var: 0
external_var: 0
若是在tf.function中生成並徹底使用了迭代器,那麼它應該能夠正常工做。可是,整個迭代器可能正在被跟蹤,這可能致使一個巨大的圖。若是正在訓練一個表示爲Python列表的大型內存數據集,那麼這會生成一個很是大的圖,而且tf.function不太可能產生加速。
若是要迭代Python數據,最安全的方法是將其包裝在tf.data.Dataset中並使用該for x in y慣用法。AutoGraph特別支持tf.data.Dataset 時安全地轉換循環。
def measure_graph_size(f, *args):
g = f.get_concrete_function(*args).graph
print("{}({}) 的圖中包含了 {} 個節點".format(
f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
loss = tf.constant(0)
for x, y in dataset:
loss += tf.abs(y - x) # Some dummy computation.
return loss
small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1)]) 的圖中包含了 8 個節點
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) 的圖中包含了 32 個節點
train(, ), types: (tf.int32, tf.int32)>) 的圖中包含了 4 個節點
train(, ), types: (tf.int32, tf.int32)>) 的圖中包含了 4 個節點
在數據集中包裝Python / Numpy數據時,請注意tf.data.Dataset.from_generator與tf.data.Dataset.from_tensors。前者將數據保存在Python中並經過tf.py_function它獲取性能影響,然後者將數據的副本捆綁爲圖中的一個大tf.constant()節點,這可能會對內存產生影響。
經過TFRecordDataset / CsvDataset / etc從文件中讀取數據。是最有效的數據處理方式,由於TensorFlow自己能夠管理數據的異步加載和預取,而沒必要涉及Python。
自動控制依賴項
在通常數據流圖上,做爲編程模型的函數的一個很是吸引人的特性是函數能夠爲運行時提供有關代碼的預期行爲的更多信息。
例如,當編寫具備多個讀取和寫入相同變量的代碼時,數據流圖可能不會天然地編碼最初預期的操做順序。在tf.function,咱們經過引用原始Python代碼中的語句的執行順序來解決執行順序中的歧義。這樣,有序狀態操做的排序tf.function複製了Eager模式的語義。
這意味着不須要添加手動控制依賴項; tf.function足夠聰明,能夠爲代碼添加最小的必要和充分的控制依賴關係,以便正確運行。
# 按順序自動執行
a = tf.Variable(1.0)
b = tf.Variable(2.0)
@tf.function
def f(x, y):
a.assign(y * b)
b.assign_add(x * a)
return a + b
f(1.0, 2.0)
變量
咱們可使用相同的想法來利用代碼的預期執行順序,使變量建立和利用變得很是容易tf.function。可是有一個很是重要的警告,即便用變量,能夠編寫在急切模式和圖形模式下表現不一樣的代碼。
具體來講,每次調用建立一個新變量時都會發生這種狀況。因爲跟蹤語義,tf.function每次調用都會重用相同的變量,可是eager模式會在每次調用時建立一個新變量。爲防止出現此錯誤,tf.function若是檢測到危險變量建立行爲,則會引起錯誤。
@tf.function
def f(x):
# tf.function會重複調用相同變量,而eager每次都會建立新的變量
v = tf.Variable(1.0)
v.assign_add(x)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception
: in converted code:
:4 f *
v = tf.Variable(1.0)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__
return cls._variable_v2_call(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
shape=shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:60 getter
return captured_getter(captured_previous, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:364 invalid_creator_scope
"tf.function-decorated function tried to create "
ValueError: tf.function-decorated function tried to create variables on non-first call.
不會報錯的方法是
v = tf.Variable(1.0) # 把變量拿到tf.function外面
@tf.function
def f(x):
return v.assign_add(x)
print(f(1.0)) # 2.0
print(f(2.0)) # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
也能夠在tf.function中建立變量,只要能夠保證這些變量僅在第一次執行函數時建立。
class C: pass
obj = C(); obj.v = None
@tf.function
def g(x):
if obj.v is None:
obj.v = tf.Variable(1.0)
return obj.v.assign_add(x)
print(g(1.0)) # 2.0
print(g(2.0)) # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
變量初始值設定項能夠依賴於函數參數和其餘變量的值。 咱們可使用與生成控制依賴關係相同的方法找出正確的初始化順序。
state = []
@tf.function
def fn(x):
if not state:
state.append(tf.Variable(2.0 * x))
state.append(tf.Variable(state[0] * 3.0))
return state[0] * x * state[1]
print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))
tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(36.0, shape=(), dtype=float32)
使用AutoGraph
該簽名庫徹底集成tf.function,它將改寫條件和循環依賴於張量在圖形動態運行。
tf.cond而且tf.while_loop繼續使用tf.function,可是當以命令式樣式編寫時,具備控制流的代碼一般更容易編寫和理解。
# 簡單的循環
@tf.function
def f(x):
# 直接用python中的while寫循環
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
f(tf.random.uniform([5]))
[0.829342961 0.858322263 0.900950909 0.851897 0.530384183]
[0.680123031 0.695392191 0.716760576 0.692059278 0.485674709]
[0.591599405 0.601434886 0.614898741 0.599303305 0.450776756]
[0.53104496 0.538069844 0.547566235 0.536553681 0.422537297]
[0.486179501 0.491525501 0.498693913 0.490374774 0.399065822]
[0.451178908 0.455426365 0.461089343 0.454513818 0.379149348]
[0.422867566 0.426349223 0.430971652 0.425602287 0.361968517]
[0.399343461 0.402265817 0.406133026 0.401639521 0.346946776]
[0.379387051 0.381885976 0.385184318 0.381350905 0.333665]
[0.362175018 0.36434418 0.367201209 0.363880038 0.321810097]
[0.347128421 0.349034756 0.351541221 0.348627061 0.311142713]
[0.333826423 0.335519224 0.337741673 0.335157365 0.30147627]
[0.321954757 0.323471278 0.325459719 0.323147237 0.292663]
[0.311273336 0.312642276 0.314435244 0.312349856 0.284584]
[0.301595032 0.302838922 0.304466605 0.302573323 0.277142316]
[0.292771578 0.293908447 0.295394808 0.293665737 0.270258158]
[0.284683794 0.285728157 0.287092626 0.285505235 0.263865024]
[0.277234435 0.278198302 0.279456645 0.277992576 0.257907033]
[0.270343572 0.271236718 0.272402078 0.271046132 0.25233686]
[0.263944477 0.264775217 0.265858531 0.264597982 0.247114092]
[0.257981181 0.258756459 0.259766966 0.258591145 0.242203966]
[0.252406299 0.253132015 0.254077554 0.252977312 0.237576365]
[0.24717927 0.247860536 0.248747766 0.247715324 0.233205199]
[0.242265314 0.242906466 0.24374117 0.242769822 0.229067564]
[0.237634286 0.238239139 0.239026278 0.238110229 0.225143358]
[0.233259991 0.233831868 0.234575793 0.233709976 0.221414775]
[0.229119495 0.229661271 0.230365857 0.229545817 0.217866093]
[0.225192651 0.22570689 0.22637549 0.225597292 0.214483246]
[0.221461684 0.221950635 0.222586185 0.221846417 0.211253688]
[0.217910782 0.218376443 0.218981609 0.218277216 0.208166167]
[0.214525893 0.214970052 0.215547174 0.214875415 0.205210552]
[0.211294428 0.211718708 0.212269917 0.211628318 0.202377662]
[0.208205134 0.208611 0.209138155 0.20852454 0.199659243]
[0.205247864 0.205636591 0.206141427 0.2055538 0.197047815]
[0.20241344 0.202786222 0.203270242 0.202706844 0.194536477]
array([0.19969359, 0.2000515 , 0.2005161 , 0.19997531, 0.192119 ],
dtype=float32)>
print(f)
能夠檢查代碼簽名生成。 但感受就像閱讀彙編語言同樣。
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
print(tf.autograph.to_code(f))
def tf__f(x):
do_return = False
retval_ = ag__.UndefinedReturnValue()
def loop_test(x_1):
return ag__.converted_call('reduce_sum', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None) > 1
def loop_body(x_1):
ag__.converted_call('print', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None)
x_1 = ag__.converted_call('tanh', tf, ag__.ConversionOptions(recursive=True, force_conversion=False, optional_features=(), internal_convert_user_code=True), (x_1,), None)
return x_1,
x, = ag__.while_stmt(loop_test, loop_body, (x,))
do_return = True
retval_ = x
cond = ag__.is_undefined_return(retval_)
def get_state():
return ()
def set_state(_):
pass
def if_true():
retval_ = None
return retval_
def if_false():
return retval_
retval_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state)
return retval_
AutoGraph:條件
AutoGraph會將if語句轉換爲等效的tf.cond調用。
若是條件是Tensor,則進行此替換。不然,在跟蹤期間執行條件。
# 測試
def test_tf_cond(f, *args):
# 獲取圖
g = f.get_concrete_function(*args).graph
if any(node.name=='cond' for node in g.as_graph_def().node):
print("{}({}) 使用 tf.cond.".format(
f.__name__, ', '.join(map(str, args))))
else:
print("{}({}) 正常執行.".format(
f.__name__, ', '.join(map(str, args))))
只有條件爲tensor,纔會使用tf.cond
@tf.function
def hyperparam_cond(x, training=True):
if training:
x = tf.nn.dropout(x, rate=0.5)
return x
@tf.function
def maybe_tensor_cond(x):
if x < 0:
x = -x
return x
test_tf_cond(hyperparam_cond, tf.ones([1], dtype=tf.float32))
test_tf_cond(maybe_tensor_cond, tf.constant(-1)) # 條件爲tensor
test_tf_cond(maybe_tensor_cond, -1)
hyperparam_cond(tf.Tensor([1.], shape=(1,), dtype=float32)) 正常執行.
maybe_tensor_cond(tf.Tensor(-1, shape=(), dtype=int32)) 使用 tf.cond.
maybe_tensor_cond(-1) 正常執行.
tf.cond有一些細微之處。 - 它的工做原理是跟蹤條件的兩邊,而後根據條件在運行時選擇適當的分支。跟蹤雙方可能致使意外執行Python代碼 - 它要求若是一個分支建立下游使用的張量,另外一個分支也必須建立該張量。
@tf.function
def f():
x = tf.constant(0)
if tf.constant(True):
x = x + 1
tf.print('執行,x:', x)
print("Tracing `then` branch")
else:
x = x - 1
tf.print('執行,x:', x) # 沒有執行
print("Tracing `else` branch") # 該分支雖然不執行但也被追蹤
return x
f()
Tracing `then` branch
Tracing `else` branch
執行,x: 1
兩個分支必須都定義x
@tf.function
def f():
if tf.constant(True):
x = tf.ones([3, 3])
return x
# 兩個分支必須都定義x, 不然會拋出異常
with assert_raises(ValueError):
f()
Caught expected exception
: in converted code:
:3 f *
if tf.constant(True):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:439 if_stmt
return tf_if_stmt(cond, body, orelse, get_state, set_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:456 tf_if_stmt
outputs, final_state = control_flow_ops.cond(cond, body, orelse)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:507 new_func
return func(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:1147 cond
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/cond_v2.py:86 cond_v2
op_return_value=pred)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:716 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:486 wrapper
outputs = func()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:512 wrapper
tuple(s.symbol_name for s in undefined)))
ValueError: The following symbols must also be initialized in the else branch: ('x',). Alternatively, you may initialize them before the if statement.
AutoGraph和循環
AutoGraph有一些簡單的轉換循環規則。
for:若是iterable是張量,則轉換
while:若是while條件取決於張量,則轉換
若是循環被轉換,它將被動態展開tf.while_loop,或者在a的特殊狀況下for x in tf.data.Dataset轉換爲tf.data.Dataset.reduce。
若是未轉換循環,則將靜態展開
# 測試
def test_dynamically_unrolled(f, *args):
g = f.get_concrete_function(*args).graph
if any(node.name == 'while' for node in g.as_graph_def().node):
print("{}({}) uses tf.while_loop.".format(
f.__name__, ', '.join(map(str, args))))
elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
print("{}({}) uses tf.data.Dataset.reduce.".format(
f.__name__, ', '.join(map(str, args))))
else:
print("{}({}) gets unrolled.".format(
f.__name__, ', '.join(map(str, args))))
@tf.function
def for_in_range():
x = 0
for i in range(5):
x += i
return x
@tf.function
def for_in_tfrange():
x = tf.constant(0, dtype=tf.int32)
for i in tf.range(5): # 生成迭代的張量
x += i
return x
@tf.function
def for_in_tfdataset():
x = tf.constant(0, dtype=tf.int64)
for i in tf.data.Dataset.range(5):
x += i
return x
test_dynamically_unrolled(for_in_range)
test_dynamically_unrolled(for_in_tfrange)
test_dynamically_unrolled(for_in_tfdataset)
for_in_range() gets unrolled.
for_in_tfrange() uses tf.while_loop.
for_in_tfdataset() uses tf.data.Dataset.reduce.
@tf.function
def while_py_cond():
x = 5
while x > 0:
x -= 1
return x
@tf.function
def while_tf_cond():
x = tf.constant(5)
while x > 0: # while中的x爲張量
x -= 1
return x
test_dynamically_unrolled(while_py_cond)
test_dynamically_unrolled(while_tf_cond)
while_py_cond() gets unrolled.
while_tf_cond() uses tf.while_loop.
若是有一個break或早期的return子句依賴於張量,那麼頂級條件或者iterable也應該是一個張量。
@tf.function
def buggy_while_py_true_tf_break(x):
while True:
if tf.equal(x, 0):
break
x -= 1
return x
@tf.function
def while_tf_true_tf_break(x):
while tf.constant(True): # 有break,頂級條件必須爲張量
if tf.equal(x, 0):
break
x -= 1
return x
with assert_raises(TypeError):
test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)
test_dynamically_unrolled(while_tf_true_tf_break, 5)
Caught expected exception
: in converted code:
:3 buggy_while_py_true_tf_break *
while True:無錫人流醫院哪家好 http://www.ytsg029.com/
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:313 while_stmt
return _py_while_stmt(test, body, init_state, opts)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:401 _py_while_stmt
while test(*state):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:698 __bool__
raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
while_tf_true_tf_break(5) uses tf.while_loop.
@tf.function
def buggy_py_for_tf_break():
x = 0
for i in range(5):
if tf.equal(i, 3):
break
x += i
return x
@tf.function
def tf_for_tf_break():
x = 0
for i in tf.range(5): # 有break,頂級迭代器必須爲張量
if tf.equal(i, 3):
break
x += i
return x
with assert_raises(TypeError):
test_dynamically_unrolled(buggy_py_for_tf_break)
test_dynamically_unrolled(tf_for_tf_break)
Caught expected exception
: in converted code:
:4 buggy_py_for_tf_break *
for i in range(5):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:110 for_stmt
return _py_for_stmt(iter_, extra_test, body, init_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:117 _py_for_stmt
if extra_test is not None and not extra_test(*state):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:698 __bool__
raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
tf_for_tf_break() uses tf.while_loop.
爲了累積動態展開循環的結果,須要使用tf.TensorArray。
# 實現一個動態rnn
batch_size = 32
seq_len = 3
feature_size=4
# rnn步,輸入與狀態疊加
def rnn_step(inputs, state):
return inputs + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2]) # 每一個時間維度,都是整個batch數據喂入
max_seq_len = input_data.shape[0]
# 保存循環中的狀態,必須使用tf.TensorArray
states = tf.TensorArray(tf.float32, size=max_seq_len)
state = initial_state
# 迭代時間步
for i in tf.range(max_seq_len):
state = rnn_step(input_data[i], state)
states = states.write(i, state)
# 把 batch_size從新換到前面
return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
tf.random.uniform([batch_size, seq_len, feature_size]),
tf.zeros([batch_size, feature_size]))
array([[[0.42647886, 0.73600817, 0.10211909, 0.89989746],
[0.772506 , 1.6853498 , 0.48793948, 1.4499462 ],
[1.1096102 , 2.3388233 , 0.5920907 , 1.588302 ]],
...
[[0.15579033, 0.4594922 , 0.17970431, 0.19183934],
[0.19597077, 0.5362154 , 0.19988954, 0.38290274],
[0.7524748 , 1.0519221 , 0.76595306, 0.5257962 ]]], dtype=float32)>
與此同時tf.cond,tf.while_loop還帶有一些細微之處。 - 因爲循環能夠執行0次,所以必須在循環上方初始化在while_loop下游使用的全部張量 - 全部循環變量的形狀/ dtypes必須與每次迭代保持一致
@tf.function
def buggy_loop_var_uninitialized():
for i in tf.range(3):
x = i # 必須在循環上方初始化好x
return x
@tf.function
def f():
x = tf.constant(0)
for i in tf.range(3):
x = i
return x
with assert_raises(ValueError):
buggy_loop_var_uninitialized()
f()
Caught expected exception
: in converted code:
:3 buggy_loop_var_uninitialized *
for i in tf.range(3):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:95 for_stmt
return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:125 _known_len_tf_for_stmt
_disallow_undefs_into_loop(*init_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:50 _disallow_undefs_into_loop
tuple(s.symbol_name for s in undefined)))
ValueError: TensorFlow requires that the following symbols must be defined before the loop: ('x',)
循環時 變量的類型不能改變
@tf.function
def buggy_loop_type_changes():
x = tf.constant(0, dtype=tf.float32)
for i in tf.range(3): # Yields tensors of type tf.int32...
x = i
return x
with assert_raises(tf.errors.InvalidArgumentError):
buggy_loop_type_changes()
Caught expected exception
: Input 1 of node while/merge/_10 was passed int32 from while/next_iteration/_28:0 incompatible with expected float. [Op:__inference_buggy_loop_type_changes_2119]
循環時變量形狀也不能改變
@tf.function
def buggy_concat():
x = tf.ones([0, 10])
for i in tf.range(5):
x = tf.concat([x, tf.ones([1, 10])], axis=0) # 循環時變量形狀不能改變
return x
with assert_raises(ValueError):
buggy_concat()
@tf.function
def concat_with_padding():
x = tf.zeros([5, 10])
for i in tf.range(5):
x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)
x.set_shape([5, 10])
return x
concat_with_padding()
Caught expected exception
: in converted code:
:4 buggy_concat *
for i in tf.range(5):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:95 for_stmt
return _known_len_tf_for_stmt(iter_, extra_test, body, init_state)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:156 _known_len_tf_for_stmt
opts=dict(maximum_iterations=n))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/operators/control_flow.py:327 _tf_while_stmt
retval = control_flow_ops.while_loop(test, body, init_state, **opts)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2646 while_loop
return_same_structure=return_same_structure)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:213 while_loop
len_orig_loop_vars], expand_composites=True))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:869 _check_shapes_compat
"specify a less-specific shape." % (input_t.name, shape, t.shape))
ValueError: Input tensor 'ones:0' enters the loop with shape (0, 10), but has shape (1, 10) after one iteration. To allow the shape to vary across iterations, use the `shape_invariants` argument of tf.while_loop to specify a less-specific shape.
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>