通過秋招和畢業論文的折磨,提交完論文終稿的那一刻總算以爲有多餘的時間來搞本身的事情。html
研究論文作的是圖像修復相關,這裏對基於深度學習的圖像修復方面的論文和代碼進行整理,也算是研究生方向有一個比較好的結束。好啦,下面開始進入正題~網絡
全部的image inpainting的介紹在這裏: 基於深度學習的Image Inpainting(論文+代碼)函數
Context encoders for image generation
1. Encoder-decoder pipeline
網絡結構是一個簡單的編碼器-解碼器結構,中間採用Channel-wise fully-connected layer來鏈接編碼器和解碼器,網絡結構如圖。post
1.1 編碼器:採用AlexNet網絡做爲baseline,五個卷積加上池化pool5,若輸入圖像爲227x227,能夠獲得一個6x6x256的特徵圖。
學習
1.2 Channel-wise fully-connected layer:減小網絡參數,若使用全鏈接層,輸入特徵圖爲mxnxn,輸出也爲mxnxn,則須要m2n4的參數,而使用channel-wise僅須要mn4的參數,使用步長爲1的卷積來將信息在通道之間傳遞。編碼
1.3 解碼器:就是一系列的五個上卷積的操做,使其恢復到與原圖同樣的大小。lua
2. Loss function url
包含reconstruction(l2) loss和adversarial loss。spa
2.1 重建L2 loss主要是捕獲缺失區域的總體結構,可是容易在預測輸出中平均多種模式;.net
M做爲二值化的掩碼,沒看懂最外面的M是幹啥用。。
2.2 而adv loss則從多種可能的輸出模式中選擇一種,也能夠說是進行特定模式選擇,使得預測結果看起來更真實。
2.3 兩種loss結合到一塊兒,既具有結構性,也具有真實語義性。
對於任意區域的圖像修復網絡結構圖以下。
我以爲這篇論文的創新點有如下兩點:
1. 使用編碼-解碼器結構來完成圖像修復的任務,並改用channel-wise的方式鏈接,節省了必定的參數。
2. 使用聯合損失函數,結合重建l2 loss和對抗式adv loss,使得修復圖像更加真實。
代碼解讀:train.lua
--------------------------------------------------------------------------- -- Adversarial discriminator net --------------------------------------------------------------------------- local netD = nn.Sequential() if opt.conditionAdv then local netD_ctx = nn.Sequential() -- input Context: (nc) x 128 x 128, going into a convolution netD_ctx:add(SpatialConvolution(nc, ndf, 5, 5, 2, 2, 2, 2)) -- state size: (ndf) x 64 x 64 local netD_pred = nn.Sequential() -- input pred: (nc) x 64 x 64, going into a convolution netD_pred:add(SpatialConvolution(nc, ndf, 5, 5, 2, 2, 2+32, 2+32)) -- 32: to keep scaling of features same as context -- state size: (ndf) x 64 x 64 local netD_pl = nn.ParallelTable(); netD_pl:add(netD_ctx) netD_pl:add(netD_pred) netD:add(netD_pl) netD:add(nn.JoinTable(2)) netD:add(nn.LeakyReLU(0.2, true)) -- state size: (ndf * 2) x 64 x 64 netD:add(SpatialConvolution(ndf*2, ndf, 4, 4, 2, 2, 1, 1)) netD:add(SpatialBatchNormalization(ndf)):add(nn.LeakyReLU(0.2, true)) -- state size: (ndf) x 32 x 32 else -- input is (nc) x 64 x 64, going into a convolution netD:add(SpatialConvolution(nc, ndf, 4, 4, 2, 2, 1, 1)) netD:add(nn.LeakyReLU(0.2, true)) -- state size: (ndf) x 32 x 32 end
train.lua中分別獲得生成器和判別器的網絡結構,而後準備數據,進行訓練。這裏選擇判別器的網絡結構代碼分析。
網絡結構中用到了nn.ParallelTable(),向介紹下torch中nn.Sequential,nn.Concat/ConcatTable,nn.Parallel/PararelTable之間的區別。
那麼爲何生成器和判別器都須要用到nn.ParallelTable呢?即對每一個成員模塊應用與之對應的輸入(第i個模塊應用第i個輸入)
個人理解:生成器須要將輸入圖像和noise輸入到生成器中獲得預測的圖像;而判別器須要將真實的圖像和預測的圖像輸入到判別器中。