pytorch合併與拆分

import torch內存

import numpy as npit

 

#自動擴張,或者叫作廣播,不須要拷貝數據.ast

#若是前面沒有維度,則在前面插入一個新的維度;import

#對齊時默認從後面開始對齊.numpy

#broading cast 能夠簡化運算而且減小內存拷貝.im

a = torch.randn(4,3)數據

b = torch.rand(4,3)di

c = a + bcas

c = torch.randn(1,3)

d = a+ c

 

#拼接1:

a = torch.randn(5,32,48)

b = torch.randn(4,32,48)

c = torch.cat([a,b],dim=0)#在第0維度合併

print(c.shape)#torch.Size([9, 32, 48])

print(torch.cat([a,b]).shape)#默認拼接按照0維

 

#拼接2 stack會建立新的維度,注意其形狀必須匹配

a = torch.randn(32,8)

b = torch.randn(32,8)

d = torch.stack([a,b],dim=2)#torch.Size([32, 8, 2])

print(d.shape)

 

#拆分split 長度拆分,如[1,2,3,4,5,6]指定拆分長度2,則拆分爲三個單元

a = torch.randn(2,32,8)

b = a.split(1,dim=0)#第0個維度拆分

print(type(b))

 

#按數量區分

print(a.chunk(2,dim=0))

相關文章
相關標籤/搜索