tensorflow踩坑:按tensor中元素進行比較

問題描述

DL中咱們可能會根據tensor中元素的值進行不一樣的操做(好比loss階段會根據grandtruth或outputs中元素的大小進行不一樣的loss操做),這時就要對tensor中的元素進行判斷。在python中能夠用for + if語句進行判斷。但TF中輸入是Tensor,for和if語句失效。python

tf.where說明

  • 格式:tf.where(condition, x=None, y=None, name=None)less

  • 參數:
    condition: 一個元素爲bool型的tensor。元素內容爲false,或true。
    x: 一個和condition有相同shape的tensor,若是x是一個高維的tensor,x的第一維size必須和condition同樣。
    y: 和x有同樣shape的tensorcode

  • 返回:
    一個和x,y有一樣shape的tensorelement

  • 功能:
    遍歷condition Tensor中的元素,若是該元素爲true,則output Tensor中對應位置的元素來自x Tensor中對應位置的元素;不然output Tensor中對應位置的元素來自Y tensor中對應位置的元素。文檔

使用例子

好比當tensor中元素x大於等於5時,對應輸出tensor中元素y=x * 2;不然 y= x * 3的計算get

import os
import sys
import tensorflow as tf
import numpy as np
#
# y = x * 2 (x >= 5)
# y = x * 3 (x < 5)
#
a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])
tmp = tf.constant([0, 0, 0, 0, 0, 0, 0, 0, 0]) 
condition = tf.less(a, 5)
smaller = tf.where(condition, a, tmp)
bigger = tf.where(condition, tmp, a)
compute_smaller = smaller * 3
compute_bigger = bigger * 2
result = compute_smaller + compute_bigger
with tf.Session() as sess:
    print(sess.run(result))
#
# 結果: [ 3  6  9 12 10 12 14 16 18]

上述過程也能夠改爲以下it

a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])
condition = tf.less(a, 5)
result = tf.where(condition, a * 3, a * 2)
with tf.Session() as sess:
    print(sess.run(result))

整個過程當中核心部分是condition條件的設置,該條件能夠用tf提供的less,equal等操做實現(詳細查看tf文檔)。
tmp變量爲引入的一個臨時變量,目的是爲了保證where按條件選擇後輸出的Tensor大小不變(tmp的0元素在乘法中是無心義計算,用該方法保證未背選擇的元素在smaller和bigger中不參與計算)io

詳細的condition 和 tmp須要根據實際的計算進行不一樣的設置import

參考

https://stackoverflow.com/questions/42689342/compare-two-tensors-elementwise-tensorflow
https://stackoverflow.com/questions/37912161/how-can-i-compute-element-wise-conditionals-on-batches-in-tensorflow變量

相關文章
相關標籤/搜索