[轉]tensorflow中的gather

原文連接 tensorflow中取下標的函數包括:tf.gather , tf.gather_nd 和 tf.batch_gather。python

1.tf.gather(params,indices,validate_indices=None,name=None,axis=0)

indices必須是一維張量 主要參數:api

  • params:被索引的張量
  • indices:一維索引張量
  • name:返回張量名稱

返回值:經過indices獲取params下標的張量。 例子:函數

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([1,2,0],dtype=tf.int32)
tensor_c = tf.Variable([0,0],dtype=tf.int32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather(tensor_a,tensor_b)))
    print(sess.run(tf.gather(tensor_a,tensor_c)))

上個例子tf.gather(tensor_a,tensor_b) 的值爲[[4,5,6],[7,8,9],[1,2,3]],tf.gather(tensor_a,tensor_b) 的值爲[[1,2,3],[1,2,3]]學習

對於tensor_a,其第1個元素爲[4,5,6],第2個元素爲[7,8,9],第0個元素爲[1,2,3],因此以[1,2,0]爲索引的返回值是[[4,5,6],[7,8,9],[1,2,3]],一樣的,以[0,0]爲索引的值爲[[1,2,3],[1,2,3]]。spa

https://www.tensorflow.org/api_docs/python/tf/gather.net

2.tf.gather_nd(params,indices,name=None)

功能和參數與tf.gather相似,不一樣之處在於tf.gather_nd支持多維度索引,即indices能夠使多維張量。 例子:code

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[1,0],[1,1],[1,2]],dtype=tf.int32)
tensor_c = tf.Variable([[0,2],[2,0]],dtype=tf.int32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather_nd(tensor_a,tensor_b)))
    print(sess.run(tf.gather_nd(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值爲[4,5,6],tf.gather_nd(tensor_a,tensor_c)的值爲[3,7].

對於tensor_a,下標[1,0]的元素爲4,下標爲[1,1]的元素爲5,下標爲[1,2]的元素爲6,索引[1,0],[1,1],[1,2]]的返回值爲[4,5,6],一樣的,索引[[0,2],[2,0]]的返回值爲[3,7].blog

https://www.tensorflow.org/api_docs/python/tf/gather_nd索引

3.tf.batch_gather(params,indices,name=None)

支持對張量的批量索引,各參數意義見(1)中描述。注意由於是批處理,因此indices要有和params相同的第0個維度。get

例子:

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.batch_gather(tensor_a,tensor_b)))
    print(sess.run(tf.batch_gather(tensor_a,tensor_c)))
tf.gather_nd(tensor_a,tensor_b)值爲[1,5,9],tf.gather_nd(tensor_a,tensor_c)的值爲[1,4,7].

tensor_a的三個元素[1,2,3],[4,5,6],[7,8,9]分別對應索引元素的第一,第二和第三個值。[1,2,3]的第0個元素爲1,[4,5,6]的第1個元素爲5,[7,8,9]的第2個元素爲9,因此索引[[0],[1],[2]]的返回值爲[1,5,9],一樣地,索引[[0],[0],[0]]的返回值爲[1,4,7].

https://www.tensorflow.org/api_docs/python/tf/batch_gather

  在深度學習的模型訓練中,有時候須要對一個batch的數據進行相似於tf.gather_nd的操做,但tensorflow中並無tf.batch_gather_nd之類的操做,此時須要tf.map_fn和tf.gather_nd結合來實現上述操做。

相關文章
相關標籤/搜索