RuntimeError: size mismatch, m1:[1152 x 1] ,m2:[576 x 192] ,at /opt/conda/conda-bld/pytorch_1524

今天跑代碼的時候遇到了這個錯誤:
RuntimeError: size mismatch, m1:[1152 x 1] ,m2:[576 x 192] ,at /opt/conda/conda-bld/pytorch_1524
在這裏插入圖片描述
調試以後發現是以下有問題:
源代碼在這裏:
這是class裏init中相應的部分
在這裏插入圖片描述
def forward的相應的問題
在這裏插入圖片描述
問題就在於y1 = self.fc_1(y)這裏,fc_1也是同fc同樣的全鏈接層,而全鏈接層輸入的尺寸只有兩維!
從圖上能夠看到,y輸出的尺寸是[2,576,1,1]
在y1 = self.fc_1(y)的時候,由於y是四維的,因此會自動變成兩維的,具體怎麼變呢就是:
假設y是[2,576,1,1],輸入全鏈接層fc_1以後,pytorch會自動將[2,576,1,1]壓成[2X576,1X1],這就是二維的了。
因此應該改爲以下:
在這裏插入圖片描述
在pytorch中,全鏈接層是經過torch.nn.linear()這個函數實現的,輸入的參數只有feature的channels(就是途中的in_features),而size則根據前面的量來自適應的,因此不少人會自動的認爲不須要注意輸入的shape,自適應便可。這時全鏈接層就會幫你把你超過二維的shape給自動調整成二維的,這時就會報錯啦。因此要本身調整fc的輸入shape,只要batchSize和Channel數便可。
在這裏插入圖片描述函數

 

本文同步分享在 博客「於小勇」(CSDN)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。.net

相關文章
相關標籤/搜索