DL中咱們可能會根據tensor中元素的值進行不一樣的操做(好比loss階段會根據grandtruth或outputs中元素的大小進行不一樣的loss操做),這時就要對tensor中的元素進行判斷。在python中能夠用for + if語句進行判斷。但TF中輸入是Tensor,for和if語句失效。python
格式:tf.where(condition, x=None, y=None, name=None)less
參數:
condition: 一個元素爲bool型的tensor。元素內容爲false,或true。
x: 一個和condition有相同shape的tensor,若是x是一個高維的tensor,x的第一維size必須和condition同樣。
y: 和x有同樣shape的tensorcode
返回:
一個和x,y有一樣shape的tensorelement
功能:
遍歷condition Tensor中的元素,若是該元素爲true,則output Tensor中對應位置的元素來自x Tensor中對應位置的元素;不然output Tensor中對應位置的元素來自Y tensor中對應位置的元素。文檔
好比當tensor中元素x大於等於5時,對應輸出tensor中元素y=x * 2;不然 y= x * 3的計算get
import os import sys import tensorflow as tf import numpy as np # # y = x * 2 (x >= 5) # y = x * 3 (x < 5) # a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9]) tmp = tf.constant([0, 0, 0, 0, 0, 0, 0, 0, 0]) condition = tf.less(a, 5) smaller = tf.where(condition, a, tmp) bigger = tf.where(condition, tmp, a) compute_smaller = smaller * 3 compute_bigger = bigger * 2 result = compute_smaller + compute_bigger with tf.Session() as sess: print(sess.run(result)) # # 結果: [ 3 6 9 12 10 12 14 16 18]
上述過程也能夠改爲以下it
a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9]) condition = tf.less(a, 5) result = tf.where(condition, a * 3, a * 2) with tf.Session() as sess: print(sess.run(result))
整個過程當中核心部分是condition條件的設置,該條件能夠用tf提供的less,equal等操做實現(詳細查看tf文檔)。
tmp變量爲引入的一個臨時變量,目的是爲了保證where按條件選擇後輸出的Tensor大小不變(tmp的0元素在乘法中是無心義計算,用該方法保證未背選擇的元素在smaller和bigger中不參與計算)io
詳細的condition 和 tmp須要根據實際的計算進行不一樣的設置import
https://stackoverflow.com/questions/42689342/compare-two-tensors-elementwise-tensorflow
https://stackoverflow.com/questions/37912161/how-can-i-compute-element-wise-conditionals-on-batches-in-tensorflow變量