TensorFlow中判斷語句和循環語句的用法

因爲tensorflow使用的是graph的計算概念,在沒有涉及控制數據流向的時候編程和普通編程語言的編程差異不大,可是涉及到控制數據流向的操做時,就要特別當心,否則很容易出錯。這也是TensorFlow比較反直覺的地方。編程

在TensorFlow中,tf.cond()相似於c語言中的if...else...,用來控制數據流向,可是僅僅相似而已,其中差異仍是挺大的。關於tf.cond()函數的具體操做,我參考了tf的說明文檔。session

format:tf.cond(pred, fn1, fn2, name=None)編程語言

Return :either fn1() or fn2() based on the boolean predicate pred.(注意這裏,也就是說'fnq'和‘fn2’是兩個函數)ide

arguments:fn1 and fn2 both return lists of output tensors. fn1 and fn2 must have the same non-zero number and type of outputs('fnq'和‘fn2’返回的是非零的且類型相同的輸出)函數

官方例子:
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
上面例子執行這樣的操做,若是x<y則result這個操做是tf.add(x,z),反之則是tf.square(y)。這一點上,確實很像邏輯控制中的if...else...,可是官方說明裏也提到
Since z is needed for at least one branch of the cond,branch of the cond, the tf.mul operation is always executed, unconditionally.oop

由於z在cond函數中的至少一個分支被用到,因此
z = tf.multiply(a, b)
老是被無條件執行,這個確實很反直覺,跟我想象中的不太同樣,按通常的邏輯不該該是不用到就不執行麼?,而後查閱官方文檔,我感覺到了來之官方文檔深深的鄙視0.0
Although this behavior is consistent with the dataflow model of TensorFlow,it has occasionally surprised some users who expected a lazier semantics.this

翻譯過來應該是:儘管這樣的操做與TensorFlow的數據流模型一致,可是偶爾仍是會令那些指望慵懶語法的用戶吃驚。(應該是這麼翻譯的吧,淦,我就那個懶人0.0)翻譯

好吧,我就大概記錄一下我本身的理解(若是錯了,歡迎拍磚)。由於TensorFlow是基於圖的計算,數據以流的形式存在,因此只要構建好了圖,有數據源,那麼應該都會 數據流過,因此在執行tf.cond以前,兩個數據流一個是tf.add()中的x,z,一個是tf.square(y)中的y,而tf.cond()就決定了是數據流x,z從tf.add()流過,仍是數據流y從tf.square()流過。這裏這個tf.cond也就像個控制水流的閥門,水流管道x,z,y在這個閥門交匯,而tf.cond決定了誰將流向後面的管道,可是無論哪個水流流向下一個管道,在閥門做用以前,水流應該都是要到達閥門的。(囉囉嗦嗦了一大堆,仍是不太理解)code

栗子:orm

import tensorflow as tf  
a=tf.constant(2)      
b=tf.constant(3)      
x=tf.constant(4)      
y=tf.constant(5)      
z = tf.multiply(a, b)      
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))      
with tf.Session() as session:      
    print(result.eval())

tensorflow for循環 while循環 實例

import tensorflow as tf

n1 = tf.constant(2)
n2 = tf.constant(3)
n3 = tf.constant(4)

def cond1(i, a, b):
    return i < n1

def cond2(i, a, b):
    return i < n2

def cond3(i, a, b):
    return i < n3

def body(i, a, b):
    return i + 1, b, a + b

i1, a1, b1 = tf.while_loop(cond1, body, (2, 1, 1))
i2, a2, b2 = tf.while_loop(cond2, body, (2, 1, 1))
i3, a3, b3 = tf.while_loop(cond3, body, (2, 1, 1))
sess = tf.Session()

print(sess.run(i1))
print(sess.run(a1))
print(sess.run(b1))
print("-")
print(sess.run(i2))
print(sess.run(a2))
print(sess.run(b2))
print("-")
print(sess.run(i3))
print(sess.run(a3))
print(sess.run(b3))
相關文章
相關標籤/搜索