理解pytorch幾個高級選擇函數(如gather)

1. 引言

  最近在刷開源的Pytorch版動手學深度學習,裏面談到幾個高級選擇函數,如index_select,masked_select,gather等。這些函數大多很容易理解,可是對於gather函數,確實有些難理解,官方文檔開始也看得一臉懵,感受不太直觀。下面談談我對這幾個函數的一些理解。數組

2. 維度的理解

  對於numpy和pytorch,其數組在作維度運算上剛開始可能會給人一種直觀上的誤解,以numpy求矩陣某個維度的最大值爲例(pytorch的理解也是同樣的)函數

import numpy as np
a = np.arange(1, 13).reshape(3, 4)
"""
result:
a = [[1, 2, 3, 4],
      [5, 6, 7, 8,],
      [9, 10, 11, 12]]
"""

# 對a維度0求最大值
a.max(axis = 0)
"""
result:
[9, 10, 11, 12]
"""

# 對a維度1求最大值
a.max(axis = 1)
"""
result:
[4, 8, 12]
"""

  若是對a矩陣在維度0上找最大值,根據咱們直觀上的經驗應該是[4, 8, 12]。即從[1, 2, 3, 4]找到4,從[5, 6, 7, 8]找到8,從[9, 10, 11, 12]找到12。可是從上面結果來看,numpy運算卻給了咱們直觀上認爲是列最大值的結果[9, 10, 11, 12]。
  實際numpy(pytorch)運算應該理解爲往給定的維度進行移動運算。仍是以維度0爲例,維度0上有3個向量,分別爲[1, 2, 3, 4],[5, 6, 7, 8]和[9, 10, 11, 12]。往維度0移動,即[1, 2, 3, 4]和[5, 6, 7, 8]逐元素計算最大值,獲得[5, 6, 7, 8],再和[9, 10, 11, 12]運算獲得結果[9, 10, 11, 12]。學習

維度運算圖1
  另外,對於維度爲3的數組,在numpy和pytorch中,應該把維度0理解爲通道數,維度1和維度2纔是對應高和寬。若是是3維數組對應着用於多輸入通道和單輸出通道的卷積核(維度爲U x V x D),那麼4維數組就對應着用於多輸入通道和多輸出通道的卷積核(維度爲U x V x D x P),此時,維度0則爲多通道卷積核數量的方向,維度1爲通道數,維度2和3纔是分別對應高和寬。
維度運算圖2

3. gather函數

pytorch和numpy中許多函數都涉及維度運算,gather也不例外,可是它相對於其餘函數更難理解。依然先來看一個例子code

import torch
a = torch.arange(1, 16).reshape(5, 3)
"""
result:
a = [[1, 2, 3],
      [4, 5, 6],
      [7, 8, 9],
      [10, 11, 12],
      [13, 14, 15]]
"""

# 定義兩個index
b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]])

# axis=0
output1 = a.gather(0, b)
"""
result:
[[1, 5, 9],
[7, 11, 15],
[1, 8, 15]]
"""

# axis=1
output2 = a.gather(1, c)
"""
result:
[[2, 3, 1, 3, 2],
[5, 6, 5, 4, 4]]
"""

上面的例子看起來可能有點複雜,咱們來一步步的分析它,先從gather維度爲0開始講起。blog

  1. a.gather(0, b)分爲3個部分,a是須要被提取元素的矩陣,0表明的是提取的維度爲0,b是提取元素的索引
    • 其中規定b和a是同維張量,即a是2維張量,b也必須是2維張量
  2. 0除了表明往維度0的方向提取元素外,還有一個特權---提取結果output能夠在這個維度上的長度與a不一樣。打個比方,a如今的shape爲(5, 3),那麼提取結果output1的shape能夠是(1,3),(2, 3),甚至(n, 3)。具體維度0的長度到底爲多少由b來決定。
  3. 根據0的特權,致使了給定的b張量除了維度0外,其餘的維度大小必須和a同樣。其中張量b實際上包含如下兩個信息
    • b能夠利用除用於gather的維度(此處爲維度0)外的維度來定位出惟一一個向量,也就是a[:, ?](三維度也是同理的,有a[:, ?1, ?2]),?的取值範圍爲a同維度的index。
    • 對於上述定位出的向量,經過b中的元素來定位提取向量中的哪個元素。
    • 上面說得可能有點抽象,實際上b中的每一個元素都能在a中提取出一個元素。舉個具體點的例子,按照上面所說的,b[0, 0]能夠提取a中的一個元素。對於b[0,0],除了維度0外,能夠經過維度1來定位出惟一一個向量a[:, 0]。由於b[0, 0]的元素爲0,即提取的是a[:, 0]的第0個元素---1,並將其做爲output1[0, 0]的提取結果。
      下圖給出了維度0和維度1,gather運算的圖示
gather 2維度
對於3維或者更高維度的張量gather的原理也是同樣的
gather 2維度

4. index_select函數

其餘的高級選擇函數都比較容易理解,這裏簡單的提一下。torch.index_select主要是根據傳入的tensor來往給定的axis方向來選取張量索引

import torch
a = torch.arange(9).reshape(3, 3)
torch.index_select(a, 0, torch.tensor([0, 2]))
"""
result:
[[0, 1, 2],
[6, 7, 8]]
"""

5. masked_select函數

實際上就是經過掩碼條件來選擇元素,像torch.masked_select(x, x>0.5),其實是和x[x>0.5]等價的,最後返回的是一維張量文檔

import torch
a = torch.rand(5, 3)

# 結果和a[a > 0.5]等價
torch.masked_select(a, a>0.5)

6. nonzero函數

找到非零元素的index深度學習

import torch
a = torch.eye(3)
torch.nonzero(a)

"""
result: 對應着非零元素的index
[[0, 0],
[1, 1],
[2, 2]]
"""
相關文章
相關標籤/搜索