最近在刷開源的Pytorch版動手學深度學習,裏面談到幾個高級選擇函數,如index_select,masked_select,gather等。這些函數大多很容易理解,可是對於gather函數,確實有些難理解,官方文檔開始也看得一臉懵,感受不太直觀。下面談談我對這幾個函數的一些理解。數組
對於numpy和pytorch,其數組在作維度運算上剛開始可能會給人一種直觀上的誤解,以numpy求矩陣某個維度的最大值爲例(pytorch的理解也是同樣的)函數
import numpy as np a = np.arange(1, 13).reshape(3, 4) """ result: a = [[1, 2, 3, 4], [5, 6, 7, 8,], [9, 10, 11, 12]] """ # 對a維度0求最大值 a.max(axis = 0) """ result: [9, 10, 11, 12] """ # 對a維度1求最大值 a.max(axis = 1) """ result: [4, 8, 12] """
若是對a矩陣在維度0上找最大值,根據咱們直觀上的經驗應該是[4, 8, 12]。即從[1, 2, 3, 4]找到4,從[5, 6, 7, 8]找到8,從[9, 10, 11, 12]找到12。可是從上面結果來看,numpy運算卻給了咱們直觀上認爲是列最大值的結果[9, 10, 11, 12]。
實際numpy(pytorch)運算應該理解爲往給定的維度進行移動運算。仍是以維度0爲例,維度0上有3個向量,分別爲[1, 2, 3, 4],[5, 6, 7, 8]和[9, 10, 11, 12]。往維度0移動,即[1, 2, 3, 4]和[5, 6, 7, 8]逐元素計算最大值,獲得[5, 6, 7, 8],再和[9, 10, 11, 12]運算獲得結果[9, 10, 11, 12]。學習
pytorch和numpy中許多函數都涉及維度運算,gather
也不例外,可是它相對於其餘函數更難理解。依然先來看一個例子code
import torch a = torch.arange(1, 16).reshape(5, 3) """ result: a = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]] """ # 定義兩個index b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]]) c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]]) # axis=0 output1 = a.gather(0, b) """ result: [[1, 5, 9], [7, 11, 15], [1, 8, 15]] """ # axis=1 output2 = a.gather(1, c) """ result: [[2, 3, 1, 3, 2], [5, 6, 5, 4, 4]] """
上面的例子看起來可能有點複雜,咱們來一步步的分析它,先從gather維度爲0開始講起。blog
a.gather(0, b)
分爲3個部分,a
是須要被提取元素的矩陣,0
表明的是提取的維度爲0,b
是提取元素的索引
0
除了表明往維度0的方向提取元素外,還有一個特權---提取結果output能夠在這個維度上的長度與a不一樣。打個比方,a如今的shape爲(5, 3),那麼提取結果output1的shape能夠是(1,3),(2, 3),甚至(n, 3)。具體維度0的長度到底爲多少由b來決定。0
的特權,致使了給定的b張量除了維度0外,其餘的維度大小必須和a同樣。其中張量b
實際上包含如下兩個信息
其餘的高級選擇函數都比較容易理解,這裏簡單的提一下。torch.index_select主要是根據傳入的tensor來往給定的axis方向來選取張量索引
import torch a = torch.arange(9).reshape(3, 3) torch.index_select(a, 0, torch.tensor([0, 2])) """ result: [[0, 1, 2], [6, 7, 8]] """
實際上就是經過掩碼條件來選擇元素,像torch.masked_select(x, x>0.5),其實是和x[x>0.5]等價的,最後返回的是一維張量文檔
import torch a = torch.rand(5, 3) # 結果和a[a > 0.5]等價 torch.masked_select(a, a>0.5)
找到非零元素的index深度學習
import torch a = torch.eye(3) torch.nonzero(a) """ result: 對應着非零元素的index [[0, 0], [1, 1], [2, 2]] """