原文連接 tensorflow中取下標的函數包括:tf.gather , tf.gather_nd 和 tf.batch_gather。python
indices必須是一維張量 主要參數:api
返回值:經過indices獲取params下標的張量。 例子:函數
import tensorflow as tf tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]]) tensor_b = tf.Variable([1,2,0],dtype=tf.int32) tensor_c = tf.Variable([0,0],dtype=tf.int32) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(tf.gather(tensor_a,tensor_b))) print(sess.run(tf.gather(tensor_a,tensor_c)))
上個例子tf.gather(tensor_a,tensor_b) 的值爲[[4,5,6],[7,8,9],[1,2,3]],tf.gather(tensor_a,tensor_b) 的值爲[[1,2,3],[1,2,3]]學習
對於tensor_a,其第1個元素爲[4,5,6],第2個元素爲[7,8,9],第0個元素爲[1,2,3],因此以[1,2,0]爲索引的返回值是[[4,5,6],[7,8,9],[1,2,3]],一樣的,以[0,0]爲索引的值爲[[1,2,3],[1,2,3]]。spa
https://www.tensorflow.org/api_docs/python/tf/gather.net
功能和參數與tf.gather相似,不一樣之處在於tf.gather_nd支持多維度索引,即indices能夠使多維張量。 例子:code
import tensorflow as tf tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]]) tensor_b = tf.Variable([[1,0],[1,1],[1,2]],dtype=tf.int32) tensor_c = tf.Variable([[0,2],[2,0]],dtype=tf.int32) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(tf.gather_nd(tensor_a,tensor_b))) print(sess.run(tf.gather_nd(tensor_a,tensor_c))) tf.gather_nd(tensor_a,tensor_b)值爲[4,5,6],tf.gather_nd(tensor_a,tensor_c)的值爲[3,7].
對於tensor_a,下標[1,0]的元素爲4,下標爲[1,1]的元素爲5,下標爲[1,2]的元素爲6,索引[1,0],[1,1],[1,2]]的返回值爲[4,5,6],一樣的,索引[[0,2],[2,0]]的返回值爲[3,7].blog
https://www.tensorflow.org/api_docs/python/tf/gather_nd索引
支持對張量的批量索引,各參數意義見(1)中描述。注意由於是批處理,因此indices要有和params相同的第0個維度。get
例子:
import tensorflow as tf tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]]) tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32) tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(tf.batch_gather(tensor_a,tensor_b))) print(sess.run(tf.batch_gather(tensor_a,tensor_c))) tf.gather_nd(tensor_a,tensor_b)值爲[1,5,9],tf.gather_nd(tensor_a,tensor_c)的值爲[1,4,7].
tensor_a的三個元素[1,2,3],[4,5,6],[7,8,9]分別對應索引元素的第一,第二和第三個值。[1,2,3]的第0個元素爲1,[4,5,6]的第1個元素爲5,[7,8,9]的第2個元素爲9,因此索引[[0],[1],[2]]的返回值爲[1,5,9],一樣地,索引[[0],[0],[0]]的返回值爲[1,4,7].
https://www.tensorflow.org/api_docs/python/tf/batch_gather
在深度學習的模型訓練中,有時候須要對一個batch的數據進行相似於tf.gather_nd的操做,但tensorflow中並無tf.batch_gather_nd之類的操做,此時須要tf.map_fn和tf.gather_nd結合來實現上述操做。