tf.gather()、tf.gather_nd()、tf.batch_gather()、tf.where()和tf.slice()

1.tf.gather

tf.gather(params, indices, validate_indices=None, name=None, axis=0)  
功能:根據提供的 indicesaxis這個軸上對 params進行索引,拼接成一個新的張量。
參數:
  1. params:須要被索引的張量
  2. indices:必須爲整數類型,如int32,int64等,注意檢查不要越界了,由於若是越界了,若是使用的CPU,則會報錯,若是在GPU上進行操做的,那麼相應的輸出值將會被置爲0,而不會報錯,所以認真檢查是否越界。
  3. name:返回張量名稱
返回維度: params.shape[:axis] + indices.shape + params.shape[axis + 1:]
舉例:
import tensorflow as tf
temp4=tf.reshape(tf.range(0,20)+tf.constant(1,shape=[20]),[2,2,5])
temp4:
[[[ 1 2 3 4 5]
[ 6 7 8 9 10]]
 
[[11 12 13 14 15]
[16 17 18 19 20]]]
(1)當indices是向量時,輸出的形狀和輸入形狀相同,不改變
temp5=tf.gather(temp4,[0,1],axis=0) #indices是向量
temp5:
[[[ 1 2 3 4 5] [ 6 7 8 9 10]] [[11 12 13 14 15] [16 17 18 19 20]]]

temp7=tf.gather(temp4,[1,4],axis=2)
# (2,2,5)[:2]+(2,)+(2,2,5)[3:]=(2,2,2)
temp7:
[[[ 2 5]
[ 7 10]]
 
[[12 15]
[17 20]]]
(2)當indices是數值時,輸出的形狀比輸入的形狀少一維
temp6=tf.gather(temp4,1,axis=1) #indices是數值
# (2,2,5)[:1]+()+(2,2,5)[2:]=(2,5)
temp:
[[ 6 7 8 9 10] [16 17 18 19 20]]
(3)當indices是多維時
temp8=tf.gather(temp4,[[0,1],[3,4]],axis=2) #indices是多維的
# (2,2,5)[:2]+(2,2)+(2,2,5)[3:]=(2,2,2,2)
temp8:
[[[[ 1 2]
[ 4 5]]
 
[[ 6 7]
[ 9 10]]]
 
[[[11 12]
[14 15]]
 
[[16 17]
[19 20]]]]

bert源碼:python

flat_input_ids = tf.reshape(input_ids, [-1]) #【batch_size*seq_length*input_num】
if use_one_hot_embeddings:
  one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
  output = tf.matmul(one_hot_input_ids, embedding_table)
else:
  output = tf.gather(embedding_table, flat_input_ids)

2.tf.gather_nd

tf.gather_nd(
  params,
  indices,
  name=None,
  batch_dims=0)

功能:相似於tf.gather,不事後者只能在一個維度上進行索引,而前者能夠在多個維度上進行索引,數組

參數:機器學習

  1. params:待索引輸入張量
  2. indices:索引,int32,int64,indices將切片定義爲params的前N個維度,其中N = indices.shape [-1]
    1. 一般要求indices.shape[-1] <= params.rank(能夠用np.ndim(params)查看)
    2. 若是等號成立是在索引具體元素
    3. 若是等號不成立是在沿params的indices.shape[-1]軸進行切片
  3. name=None:操做的名稱(可選)

返回維度: indices.shape[:-1] + params.shape[indices.shape[-1]:],前面的indices.shape[:-1]表明索引後的指定形狀函數

舉例:學習

indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
# (2,2)[:-1]+(2,2)[(2,2)[-1]:]=(2,)
output = ['a', 'd']
表示將params對應第一行第一列的'a'和第二行第二列的'd'取出來

indices = [[1], [0]]
params = [['a', 'b'], ['c', 'd']]
# (2,1)[:-1]+(2,2)[(2,1)[-1]:]=(2,)+(2,)=(2,2)
output = [['c', 'd'], ['a', 'b']]
表示將params對應第二行和第一行取出來 

'''
功能:T是一個二維tensor,咱們想要根據另一個二維tensor value的最後一維最大元素的下標選出tensor T中
最後一維最大的元素,組成一個新的一維的tensor,那麼就能夠首先選出最後一維度的下標[1,2,3],
而後將其擴展成[[0,1],[1,2],[2,3]],而後使用這個函數選擇便可。
'''
import tensorflow as tf
sess = tf.InteractiveSession()
values = tf.constant([[0, 0, 0, 1],
                      [0, 1, 0, 0],
                      [0, 0, 1, 0]])
T = tf.constant([[0,1,2,3],
                 [4,5,6,7],
                 [8,9,10,11]])

max_indices = tf.argmax(values, axis=1) # 行
print('max_indices',max_indices.eval()) # [3 1 2]
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
print(tf.stack((tf.range(T.get_shape()[0],dtype=max_indices.dtype),max_indices),axis=1).eval())
print(tf.range(T.get_shape()[0]).eval())
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0],
                                   dtype=max_indices.dtype),
                                   max_indices),
                                   axis=1))
print(result.eval())

3.tf.batch_gather

做用:支持對張量的批量索引.注意由於是批處理,因此indices要有和params相同的第0個維度。spa

import tensorflow as tf
tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]])
tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32)
tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('gather')
    print(sess.run(tf.gather(tensor_a,tensor_b)))
    print(sess.run(tf.gather(tensor_a,tensor_c)))
    print('gather_nd')
    print(sess.run(tf.gather_nd(tensor_a, tensor_b)))
    print(sess.run(tf.gather_nd(tensor_a, tensor_c)))
    print('batch_gather')
    print(sess.run(tf.batch_gather(tensor_a, tensor_b)))
    print(sess.run(tf.batch_gather(tensor_a, tensor_c)))

4.tf.where

tf.where(condition, x=None, y=None, name=None)

做用: 返回condition爲True的元素座標(x=y=None).net

  1. condition:布爾型張量,True/False
  2. x:與y具備相同類型的張量,可使用條件和y進行廣播。
  3. y:與x具備相同類型的張量,能夠在條件和x的條件下進行廣播。
  4. name:操做名稱(可選)

返回維度: (num_true, dim_size(condition)),其中dim_size爲condition的維度。code

(1)tf.where(condition)

  1. condition是bool型值,True/False
  2. 返回值,是condition中元素爲True對應的索引
import tensorflow as tf
a = [[1,2,3],[4,5,6]]
b = [[1,0,3],[1,5,1]]
condition1 = [[True,False,False],
             [False,True,True]]
condition2 = [[True,False,False],
             [False,True,False]]
with tf.Session() as sess:
    print(sess.run(tf.where(condition1)))
    print(sess.run(tf.where(condition2)))

(2)tf.where(condition, x=None, y=None, name=None)

  1. condition, x, y 相同維度,condition是bool型值,True/False
  2. 返回值是對應元素,condition中元素爲True的元素替換爲x中的元素,爲False的元素替換爲y中對應元素
  3. x只負責對應替換True的元素,y只負責對應替換False的元素,x,y各有分工
  4. 因爲是替換,返回值的維度,和condition,x , y都是相等的。
import tensorflow as tf
x = [[1,2,3],[4,5,6]]
y = [[7,8,9],[10,11,12]]
condition3 = [[True,False,False],
             [False,True,True]]
condition4 = [[True,False,False],
             [True,True,False]]
with tf.Session() as sess:
    print(sess.run(tf.where(condition3,x,y)))
    print(sess.run(tf.where(condition4,x,y)))

5.tf.slice()

tf.slice(inputs, begin, size, name)

做用:用來進行切片操做,實如今python中的a[:,2:3,5:6]相似的操做,從列表、數組、張量等對象中抽取一部分數據對象

  1. begin和size是兩個多維列表,他們共同決定了要抽取的數據的開始和結束位置
  2. begin表示從inputs的哪幾個維度上的哪一個元素開始抽取 
  3. size表示在inputs的各個維度上抽取的元素個數
  4. 若begin[]或size[]中出現-1,表示抽取對應維度上的全部元素
import tensorflow as tf
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])
tf.slice(t, [1, 0, 0], [1, 1, 3])  # [[[3, 3, 3]]]
tf.slice(t, [1, 0, 0], [1, 2, 3])  # [[[3, 3, 3],
#   [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [2, 1, 3])  # [[[3, 3, 3]],
#  [[5, 5, 5]]]

bert源碼:blog

# 這裏position embedding是可學習的參數,[max_position_embeddings, width]
# 可是一般實際輸入序列沒有達到max_position_embeddings
# 因此爲了提升訓練速度,使用tf.slice取出句子長度的embedding
# full_position_embeddings:[max_position_embeddings, width]
position_embeddings = tf.slice(full_position_embeddings, [0, 0],[seq_length, -1])

 

 

 

參考文獻:

【1】tf.gather, tf.gather_nd和tf.slice_機器學習雜貨鋪1號店-CSDN博客

【2】tf.where/tf.gather/tf.gather_nd - 知乎

【3】tenflow 入門 tf.where()用法_ustbbsy的博客-CSDN博客

【4】tf.gather tf.gather_nd 和 tf.batch_gather 使用方法_張冰洋的天空-CSDN博客

相關文章
相關標籤/搜索