# PyTorch dimension transformations: view / unsqueeze / squeeze / expand / repeat / transpose / permute

import torch

import numpy as np

 

# -- Reshaping with view() --
# NOTE: view() only reinterprets the flat data buffer; the meaning of the
# original dimensions (batch/channel/height/width) is easily lost.

a = torch.rand(4, 1, 28, 28)  # (batch=4, channel=1, height=28, width=28)

# Merge channel into the image dims: 4 images of shape (28, 28).
print(a.shape, a.view(4, 28, 28))

# Flatten each image into one 784-long vector — the usual shape fed to a
# fully connected layer: (4, 1, 28, 28) -> (4, 784).
print(a.view(4, 28 * 28).shape)  # torch.Size([4, 784])

# Merge batch, channel and rows into the first dim: (4*28, 28) = (112, 28).
print(a.view(4 * 28, 28))

# Merge batch and channel only: (4, 28, 28).
print(a.view(4 * 1, 28, 28))

 

# -- Inserting dimensions with unsqueeze() --
# For a 4-dim tensor the valid insert positions are [-5, 4]:
# dim=0 inserts before the first axis, dim=1 before the second, ...,
# dim=4 appends after the last axis; negative dims count from the back.

b = a.unsqueeze(0)  # new leading dim: torch.Size([1, 4, 1, 28, 28])
print(b.shape)

c = a.unsqueeze(-1)  # new trailing dim: torch.Size([4, 1, 28, 28, 1])
print(c.shape)

"""

-5 -4 -3 -2 -1

0 1 2 3 4

4 1 28 28

 

torch.Size([1, 4, 1, 28, 28])

torch.Size([4, 1, 28, 28, 1])

torch.Size([1, 4, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 28, 1, 28])

torch.Size([4, 1, 28, 28, 1])

torch.Size([1, 4, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 1, 28, 28])

torch.Size([4, 1, 28, 1, 28])

torch.Size([4, 1, 28, 28, 1])

儘可能不使用負數

"""

for i in range(-5,5):

d = a.unsqueeze(i)

print(d.shape)

 

# Build a (1, 32, 1, 1) bias-like tensor from a flat (32,) vector so it can
# broadcast against a (4, 32, 14, 14) feature map.
b = torch.rand(32)
f = torch.rand(4, 32, 14, 14)
# (32,) -> (32, 1) -> (32, 1, 1) -> (1, 32, 1, 1)
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape)  # torch.Size([1, 32, 1, 1])


 

# -- Removing size-1 dimensions with squeeze() --
# b has shape (1, 32, 1, 1). squeeze() with no argument drops every size-1
# dim; squeeze(i) drops dim i only when its size is 1 (otherwise a no-op).
# Expected shapes for b.squeeze(i), i in [-4, 3]:
#
#   i = -4 -> (32, 1, 1)        i = 0 -> (32, 1, 1)
#   i = -3 -> (1, 32, 1, 1)     i = 1 -> (1, 32, 1, 1)   (size 32: no-op)
#   i = -2 -> (1, 32, 1)        i = 2 -> (1, 32, 1)
#   i = -1 -> (1, 32, 1)        i = 3 -> (1, 32, 1)

c = b.squeeze()  # all size-1 dims removed -> torch.Size([32])
print(c.shape)

for i in range(-4, 4):
    print(b.squeeze(i).shape)

 

# -- Growing dimensions: expand() vs repeat() --
# expand() changes only the view (no copy; only size-1 dims may grow).
# repeat() physically copies the data, so it is slower.

b = torch.rand(1, 32, 1, 1)
a = torch.rand(4, 32, 14, 14)

c = b.expand(4, 32, 14, 14)
# torch.Size([1, 32, 1, 1]) torch.Size([4, 32, 14, 14]) torch.Size([4, 32, 14, 14])
print(b.shape, a.shape, c.shape)

# repeat() arguments are COPY COUNTS per dim, not target sizes:
d = b.repeat(4, 32, 1, 1)
print(d.shape)  # torch.Size([4, 1024, 1, 1]) — not what we wanted

d = b.repeat(4, 1, 1, 1)
print(d.shape)  # torch.Size([4, 32, 1, 1]) — correct

 

# -- Matrix transpose --
a = torch.randn(4, 3)
print(a, a.t())  # .t() works on 2-dim tensors only

 

a = torch.rand(4, 3, 32, 32)

# a.transpose(1, 3) yields a non-contiguous tensor, so calling view() on it
# directly raises — contiguous() must materialize the new layout first:
#   a.transpose(1, 3).view(4, 3*32*32)  # RuntimeError

# b flattens the transposed (4, 32, 32, 3) layout and reinterprets it as
# (4, 3, 32, 32): same shape as a, but the element order is scrambled.
b = a.transpose(1, 3).contiguous().view(4, 3 * 32 * 32).view(4, 3, 32, 32)

# c restores the (4, 32, 32, 3) shape first and transposes back, so c == a.
c = a.transpose(1, 3).contiguous().view(4, 3 * 32 * 32).view(4, 32, 32, 3).transpose(1, 3)

print(b.shape, c.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4, 3, 32, 32])

# Content check: b differs from a, c matches a.
print(torch.all(torch.eq(a, b)), torch.all(torch.eq(a, c)))

# permute() reorders all dims in one call: (batch, C, H, W) -> (batch, H, W, C).
d = a.permute(0, 2, 3, 1)
print(d.shape)  # torch.Size([4, 32, 32, 3])

# (scraped page footer: "related articles" / "related tags & search")