pytorch-tensor處理速查表(cat stack squeeze unsqueeze permute等)

1 torch.catcode

torch.cat((A, B), dim)

將兩個tensor在指定維度進行拼接class

A = torch.zeros(2,3)
    B = torch.zeros(2,3)
    C = torch.cat((A,B), 0) ## shape [4,3]
    D = torch.cat((A,B), 1) ## shape [2,6]

2 torch.stack擴展

torch.stack((A, B), dim)

增長新的維度進行堆疊im

A = torch.zeros(1,3)
B = torch.zeros(1,3)
C = torch.stack((A,B), 0)  ## [2, 1, 3]
D = torch.stack((A,B), 1)  ## [1, 2, 3]
E = torch.stack((A,B), 2)  ## [1, 3, 2]

3 torch.permute數據

A = A.permute(0, 2, 3, 1)

調整tensor的維度順序,至關於更靈活的transpose移動

A = torch.zeros(32, 3, 18, 18)  ## [32, 3, 18, 18]
B = A.permute(0, 2, 3, 1)          ##[32, 18, 18, 3]

4 tensor.contiguous view只能用在contiguous的tensor上。若是在view以前用了transpose, permute等,須要用contiguous()來返回一個contiguous copy。 eg:di

v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

5 tensor.squeezeview

A = A.squeeze(dim)

去掉tensor的維度爲1的維度,該維度能夠經過參數dim指定,也能夠不加參數,默認找到維度爲1的維度而後去掉vi

A = torch.zeros(1, 18, 18)  ## [1, 18, 18]
B = A.squeeze(0)               ## [18, 18]

6 tensor.unsqueezecopy

A = A.unsqueee(dim)

在tensor中增長一個新的指定維度,新維度放在指定位置 原來維度序列向兩邊移動

A = torch.zeros(2, 3, 4)   ## [2, 3, 4]
B = A.unsqueeze(0)    ## [1, 2, 3, 4]
C = A.unsqueeze(1)    ## [2, 1, 3, 4]      
D = A.unsqueeze(2)    ## [2, 3, 1, 4]
E = A.unsqueeze(3)    ## [2, 3, 4, 1]

7 tensor.expand

A = A.expand()

在指定維度上擴展數據, 該指定維度長度爲1,不然報錯。(此時擴展僅是建立新的視圖,並不進行數據複製)

A = torch.zeros(2, 3, 1) ## [2, 3, 1]
B = A.expand(2, 3, 3)   ## [2, 3,  3]

8 tensor.clone() clone() 獲得的tensor不只拷貝了原始的value,並且會計算梯度傳播信息

b = a.clone()

9 tensor.copy_(src_tensor) 只拷貝src_tensor的數據到dst_tensor上,並返回self

a = torch.ones([3,4])
b = torch.zeros([3,4])
b.copy_(a)

10 生成特定尺度、特定數值的tensor

a = torch.Tensor(3,5).fill_(0)
a = torch.full((3,5), 0, dtype=torch.IntTensor)
相關文章
相關標籤/搜索