pytorch經常使用函數總結(持續更新)

pytorch經常使用函數總結

torch.max(input,dim)

求取指定維度上的最大值,,返回輸入張量給定維度上每行的最大值,並同時返回每一個最大值的位置索引。好比:python

demo.shape
Out[7]: torch.Size([10, 3, 10, 10])
torch.max(demo,1)[0].shape
Out[8]: torch.Size([10, 10, 10])

torch.max(demo,1)[0]這其中的[0]取得就是返回的最大值,torch.max(demo,1)[1]就是返回的最大值對應的位置索引。例子以下:函數

a
Out[8]: 
tensor([[1., 2., 3.],
        [4., 5., 6.]])
a.max(1)
Out[9]: 
torch.return_types.max(
values=tensor([3., 6.]),
indices=tensor([2, 2]))

class torch.nn.ParameterList(parameters=None)

submodules保存在一個list中。spa

ParameterList能夠像通常的Python list同樣被索引。並且ParameterList中包含的parameters已經被正確的註冊,對全部的module method可見。.net

參數說明:code

  • modules (list, optional) – a list of nn.Parameter

例子:blog

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x

torch.cat()函數

cat是concatnate的意思:拼接,聯繫在一塊兒。排序

先說cat( )的普通用法索引

若是咱們有兩個tensor是A和B,想把他們拼接在一塊兒,須要以下操做:內存

C = torch.cat( (A,B),0 )  #按維數0拼接(豎着拼)

C = torch.cat( (A,B),1 )  #按維數1拼接(橫着拼)

至關於將tensor按照指定維度進行拼接,好比A的shape爲128*64*32*32,B的shape爲 128*32*64*64,那麼按照 torch.cat( (A,B),1)拼接的以後的形狀爲 128*96*64*64get

注意:

兩個tensor要想進行拼接,必須保證除了指定拼接的維度之外其餘的維度形狀必須相同,好比上面的例子,拼接A和B時,A的形狀爲128*64*32*32,B的形狀爲128*32*64*64,只有第二個維度的維數數值不一樣,其餘的維度的維數都是相同的,因此拼接時可按維度1進行拼接(注意,維度的下標是從0開始的,好比 A 的形狀對應的維度下標爲:\(128_0*64_1*32_2*32_3\)

contiguous()函數的使用

contiguous通常與transpose,permute,view搭配使用:使用transpose或permute進行維度變換後,調用contiguous,而後方可以使用view對維度進行變形(如:tensor_var.contiguous().view() ),示例以下:

x = torch.Tensor(2,3)
y = x.permute(1,0)         # permute:二維tensor的維度變換,此處功能至關於轉置transpose
y.view(-1)                 # 報錯,view使用前需調用contiguous()函數
y = x.permute(1,0).contiguous()
y.view(-1)                 # OK

具體緣由有兩種說法:

1 transpose、permute等維度變換操做後,tensor在內存中再也不是連續存儲的,而view操做要求tensor的內存連續存儲,因此須要contiguous來返回一個contiguous copy;

2 維度變換後的變量是以前變量的淺拷貝,指向同一區域,即view操做會連帶原來的變量一同變形,這是不合法的,因此也會報錯;---- 這個解釋有部分道理,也即contiguous返回了tensor的深拷貝contiguous copy數據;

原文連接:https://zhuanlan.zhihu.com/p/64376950

tensor.repeat()函數

該函數傳入的參數個數很多於tensor的維數,其中每一個參數表明的是對該維度重複多少次,也就至關於複製的倍數,結合例子更好理解,以下:

>>> import torch
>>> 
>>> a = torch.randn(33, 55)
>>> a.size()
torch.Size([33, 55])
>>> 
>>> a.repeat(1, 1).size()
torch.Size([33, 55])
>>> 
>>> a.repeat(2,1).size()
torch.Size([66, 55])
>>> 
>>> a.repeat(1,2).size()
torch.Size([33, 110])
>>>
>>> a.repeat(1,1,1).size()
torch.Size([1, 33, 55])
>>>
>>> a.repeat(2,1,1).size()
torch.Size([2, 33, 55])
>>>
>>> a.repeat(1,2,1).size()
torch.Size([1, 66, 55])
>>>
>>> a.repeat(1,1,2).size()
torch.Size([1, 33, 110])
>>>
>>> a.repeat(1,1,1,1).size()
torch.Size([1, 1, 33, 55])
>>> 
>>> # repeat()的參數的個數,不能少於被操做的張量的維度的個數,
>>> # 下面是一些錯誤示例
>>> a.repeat(2).size()  # 1D < 2D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b = torch.randn(5,6,7)
>>> b.size() # 3D
torch.Size([5, 6, 7])
>>> 
>>> b.repeat(2).size() # 1D < 3D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1).size() # 2D < 3D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1,1).size() # 3D = 3D, okay
torch.Size([10, 6, 7])
>>>

參考博客:http://www.javashuo.com/article/p-mhfzbqgs-nc.html

torch.masked_select()函數

a = torch.Tensor([[4,5,7], [3,9,8],[2,3,4]])
b = torch.Tensor([[1,1,0], [0,0,1],[1,0,1]]).type(torch.ByteTensor)
c = torch.masked_select(a,b)
print(c)

img

用法:torch.masked_select(x, mask),mask必須轉化成torch.ByteTensor類型。

torch.sort

torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)

對輸入張量input沿着指定維按升序排序。若是不給定dim,則默認爲輸入的最後一維。若是指定參數descendingTrue,則按降序排序

返回元組 (sorted_tensor, sorted_indices) , sorted_indices 爲原始輸入中的下標。

參數:

  • input (Tensor) – 要對比的張量
  • dim (int, optional) – 沿着此維排序
  • descending (bool, optional) – 布爾值,控制升降排序
  • out (tuple, optional) – 輸出張量。必須爲ByteTensor或者與第一個參數tensor相同類型。

例子:

>>> x = torch.randn(3, 4)
>>> sorted, indices = torch.sort(x)
>>> sorted

-1.6747  0.0610  0.1190  1.4137
-1.4782  0.7159  1.0341  1.3678
-0.3324 -0.0782  0.3518  0.4763
[torch.FloatTensor of size 3x4]

>>> indices

 0  1  3  2
 2  1  0  3
 3  1  0  2
[torch.LongTensor of size 3x4]

>>> sorted, indices = torch.sort(x, 0)
>>> sorted

-1.6747 -0.0782 -1.4782 -0.3324
 0.3518  0.0610  0.4763  0.1190
 1.0341  0.7159  1.4137  1.3678
[torch.FloatTensor of size 3x4]

>>> indices

 0  2  1  2
 2  0  2  0
 1  1  0  1
[torch.LongTensor of size 3x4]
相關文章
相關標籤/搜索