tf.gather(params, indices, validate_indices=None, name=None, axis=0)
indices
在
axis
這個軸上對
params
進行索引,拼接成一個新的張量。
CPU
,則會報錯,若是在GPU
上進行操做的,那麼相應的輸出值將會被置爲0,而不會報錯,所以認真檢查是否越界。import tensorflow as tf temp4=tf.reshape(tf.range(0,20)+tf.constant(1,shape=[20]),[2,2,5])
temp4: [[[ 1 2 3 4 5] [ 6 7 8 9 10]] [[11 12 13 14 15] [16 17 18 19 20]]]
temp5=tf.gather(temp4,[0,1],axis=0) #indices是向量
temp5:
[[[ 1 2 3 4 5] [ 6 7 8 9 10]] [[11 12 13 14 15] [16 17 18 19 20]]]
temp7=tf.gather(temp4,[1,4],axis=2)
# (2,2,5)[:2]+(2,)+(2,2,5)[3:]=(2,2,2)
temp7: [[[ 2 5] [ 7 10]] [[12 15] [17 20]]]
temp6=tf.gather(temp4,1,axis=1) #indices是數值
# (2,2,5)[:1]+()+(2,2,5)[2:]=(2,5)
temp:
[[ 6 7 8 9 10] [16 17 18 19 20]]
temp8=tf.gather(temp4,[[0,1],[3,4]],axis=2) #indices是多維的
# (2,2,5)[:2]+(2,2)+(2,2,5)[3:]=(2,2,2,2)
temp8:
[[[[ 1 2] [ 4 5]] [[ 6 7] [ 9 10]]] [[[11 12] [14 15]] [[16 17] [19 20]]]]
bert源碼:python
flat_input_ids = tf.reshape(input_ids, [-1]) #【batch_size*seq_length*input_num】 if use_one_hot_embeddings: one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) output = tf.matmul(one_hot_input_ids, embedding_table) else: output = tf.gather(embedding_table, flat_input_ids)
tf.gather_nd( params, indices, name=None, batch_dims=0)
功能:相似於tf.gather
,不事後者只能在一個維度上進行索引,而前者能夠在多個維度上進行索引,數組
參數:機器學習
返回維度: indices.shape[:-1] + params.shape[indices.shape[-1]:],前面的indices.shape[:-1]表明索引後的指定形狀函數
舉例:學習
indices = [[0, 0], [1, 1]] params = [['a', 'b'], ['c', 'd']] # (2,2)[:-1]+(2,2)[(2,2)[-1]:]=(2,) output = ['a', 'd'] 表示將params對應第一行第一列的'a'和第二行第二列的'd'取出來 indices = [[1], [0]] params = [['a', 'b'], ['c', 'd']] # (2,1)[:-1]+(2,2)[(2,1)[-1]:]=(2,)+(2,)=(2,2) output = [['c', 'd'], ['a', 'b']] 表示將params對應第二行和第一行取出來 ''' 功能:T是一個二維tensor,咱們想要根據另一個二維tensor value的最後一維最大元素的下標選出tensor T中 最後一維最大的元素,組成一個新的一維的tensor,那麼就能夠首先選出最後一維度的下標[1,2,3], 而後將其擴展成[[0,1],[1,2],[2,3]],而後使用這個函數選擇便可。 ''' import tensorflow as tf sess = tf.InteractiveSession() values = tf.constant([[0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0]]) T = tf.constant([[0,1,2,3], [4,5,6,7], [8,9,10,11]]) max_indices = tf.argmax(values, axis=1) # 行 print('max_indices',max_indices.eval()) # [3 1 2] # If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0]. print(tf.stack((tf.range(T.get_shape()[0],dtype=max_indices.dtype),max_indices),axis=1).eval()) print(tf.range(T.get_shape()[0]).eval()) result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0], dtype=max_indices.dtype), max_indices), axis=1)) print(result.eval())
做用:支持對張量的批量索引.注意由於是批處理,因此indices要有和params相同的第0個維度。spa
import tensorflow as tf tensor_a = tf.Variable([[1,2,3],[4,5,6],[7,8,9]]) tensor_b = tf.Variable([[0],[1],[2]],dtype=tf.int32) tensor_c = tf.Variable([[0],[0],[0]],dtype=tf.int32) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print('gather') print(sess.run(tf.gather(tensor_a,tensor_b))) print(sess.run(tf.gather(tensor_a,tensor_c))) print('gather_nd') print(sess.run(tf.gather_nd(tensor_a, tensor_b))) print(sess.run(tf.gather_nd(tensor_a, tensor_c))) print('batch_gather') print(sess.run(tf.batch_gather(tensor_a, tensor_b))) print(sess.run(tf.batch_gather(tensor_a, tensor_c)))
tf.where(condition, x=None, y=None, name=None)
做用: 返回condition爲True的元素座標(x=y=None).net
返回維度: (num_true, dim_size(condition)),其中dim_size爲condition的維度。code
import tensorflow as tf a = [[1,2,3],[4,5,6]] b = [[1,0,3],[1,5,1]] condition1 = [[True,False,False], [False,True,True]] condition2 = [[True,False,False], [False,True,False]] with tf.Session() as sess: print(sess.run(tf.where(condition1))) print(sess.run(tf.where(condition2)))
import tensorflow as tf x = [[1,2,3],[4,5,6]] y = [[7,8,9],[10,11,12]] condition3 = [[True,False,False], [False,True,True]] condition4 = [[True,False,False], [True,True,False]] with tf.Session() as sess: print(sess.run(tf.where(condition3,x,y))) print(sess.run(tf.where(condition4,x,y)))
tf.slice(inputs, begin, size, name)
做用:用來進行切片操做,實如今python
中的a[:,2:3,5:6]
相似的操做,從列表、數組、張量等對象中抽取一部分數據對象
import tensorflow as tf t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]) tf.slice(t, [1, 0, 0], [1, 1, 3]) # [[[3, 3, 3]]] tf.slice(t, [1, 0, 0], [1, 2, 3]) # [[[3, 3, 3], # [4, 4, 4]]] tf.slice(t, [1, 0, 0], [2, 1, 3]) # [[[3, 3, 3]], # [[5, 5, 5]]]
bert源碼:blog
# 這裏position embedding是可學習的參數,[max_position_embeddings, width] # 可是一般實際輸入序列沒有達到max_position_embeddings # 因此爲了提升訓練速度,使用tf.slice取出句子長度的embedding # full_position_embeddings:[max_position_embeddings, width] position_embeddings = tf.slice(full_position_embeddings, [0, 0],[seq_length, -1])
參考文獻:
【1】tf.gather, tf.gather_nd和tf.slice_機器學習雜貨鋪1號店-CSDN博客
【2】tf.where/tf.gather/tf.gather_nd - 知乎
【3】tenflow 入門 tf.where()用法_ustbbsy的博客-CSDN博客
【4】tf.gather tf.gather_nd 和 tf.batch_gather 使用方法_張冰洋的天空-CSDN博客