Spatially Adaptive Residual Networks for Efficient Image and Video Deblurring

paper python

本文是印度馬德拉斯理工學院的研究員提出的一種基於空間自適應殘差網絡的圖像/視頻去模糊方法。網絡

嚴重模糊圖像復原要求網絡具備極大感覺野,現有網絡每每採用加深網絡層數、加大卷積核尺寸或者多尺度方式提高感覺野,然而這些方法會早知模型大小的提高以及推理耗時提高。做者提出一種組合形變卷積與自注意力機制的去模糊網絡,進一步,集成時序遞歸模塊能夠將其擴展到視頻去模糊。該網絡能夠模擬空間可變模糊移除而無需多尺度與大卷積核。最後做者經過實驗定性與定量進行分析:在速度、精度以及模型大小方面均取得了SOTA性能。架構

Abstract

​ 針對已有去模糊方法存在的兩個侷限性:(1) 空間不變卷積核,對於動態場景去模糊而言並不是最優方案,嚴重限制了去模糊精度;(2) 經過網絡深度與卷積核尺寸提高擴大感覺野,這會致使模型變大、推理耗時增長。框架

​ 爲此,基於形變卷積與自注意力機制,做者提出一種高效果的端到端的去模糊框架。它與其餘SOTA方法的性能對比見下圖。ide

​ 該方法的優勢包含如下幾點:函數

  • 全卷積且參數高效,僅需一次性前向過程;
  • 可輕易集成其餘架構與損失函數;
  • 網絡估計的變換是動態的,於是能夠自適應處理測試圖像。

Method

​ 上圖給出了做者所提出的SARN網絡架構示意圖,其中編碼子網絡將輸入圖像逐漸變換爲分辨率更小、通道更多的特徵圖,在此基礎上繼續執行空間注意力模塊與形變殘差模塊,最後送入到解碼模塊中,經過一系列的殘差模塊與反捲積對其進行重建。注:上圖中n=32。性能

Deformable Residual Module

​ 傳統的CNN在固定的網格上進行採樣,這限制了其模擬未知幾何變換的能力。STN將空間學習引入到CNN中,然而這種變換比較耗時且爲全局圖像變換,並不適合於局部圖像幾何變換。做者採用形變卷積,它以一種有效的方法學習局部幾何變換。形變卷積首先學習稠密偏移圖進行特徵重採樣,而後再進行卷積操做,該過程見上圖中的從Input FeatureOutput Feature的過程。做者在形變卷積基礎上引入參考模塊,稱之爲形變殘差模塊。更多關於形變卷積的介紹與分析建議參考原文Deformable Convolutional Networks學習

Self-Attention Module

​ 近期的去模糊方法着重於多尺度處理,這種處理方式能夠獲取不一樣尺度的運動模糊,提高網絡的感覺野。儘管這種「自粗而精」的處理策略能夠處理不一樣程度的模糊,可是它沒法從全局角度利用模糊區域之間的相關性,而這對於復原任務也很重要。爲此,做者提出採用:在不一樣空間分辨率利用注意力機制學習非局部關聯性。測試

​ 用於模擬長範圍依賴關係的注意力機制已在多個領域(跨語言與視覺應用)取得了成功。做者採用非局部注意力進行不一樣場景區域之間的關聯性學習並用於提高圖像復原質量。優化

​ 上圖給出了做者所提出的SAM模塊示意圖。它有以下兩點優點:

  • 它克服了感覺野有限的侷限性;
  • 它隱含的提供了一種能夠傳播相對信息的通路。

上述優點使得它適合於處理去模糊,這是由於:因模糊致使的場景-邊緣之間每每是相關的。

​ 以上圖爲例,給定輸入特徵A\in R^{C \times H \times W},首先,將其送入兩個1\times1卷積獲得兩個新的特徵B和C,其中\{B, C\} \in R^{\hat{C} \times H \times W};而後,將其進行reshape爲R^{\hat{C} \times N};其次,對B和C進行矩陣乘操做並執行softmax獲得空間注意力特徵S \in R^{N \times N}(s_{ji}能夠度量i位置與j位置的影響關係),計算方式以下公式所示。最後,將特徵A經由另外一個1\times1卷積獲得特徵D\in R^{C\times H \times W},並reshape爲R^{C \times N},並將其S進行矩陣乘操做獲得加強版特徵,將其與特徵A相加獲得最終的特徵E\in R^{C\times H \times W}

s_{ji} = \frac{\mathcal{exp}(B_i \cdot C_j)}{\sum_{i=1}^N \mathcal{exp}(B_i \cdot C_j)} \notag

​ 經由上述操做獲得的特徵E包含全部位置特徵的加權組合以及原始特徵。所以它具備全局上下文信息,並按照空間注意力進行上下文信息選擇性集成,促使類似特徵加強,不相關特徵削弱。

​ 做者還發現:將SAM至於DRM以前能夠取得更好的性能。猜想緣由爲:早期的特徵加強有助於提高網絡的非局部性。

Video Deblurring

​ 圖像去模糊一種很天然的擴展是視頻去模糊,做者採用LSTM進行先後幀特徵集成,該過程能夠描述爲:

\begin{split}
f^i &= Net_E(B^i, I^{i-1})  \\
h^i, g^i &= ConvLSTM(h^{i-1}, f^i; \theta_{LSTM})  \\
I^i &= Net_D(g^i; \theta_D)
\end{split}

在視頻去模糊中,它以5幀做爲輸入,輸出中間幀的去模糊效果圖。

Experiments

​ 在訓練過程當中,相關參數配置以下:

  • 對於圖像去模糊任務,訓練數據爲GoPro,優化器爲Adam,學習率爲0.0001,BatchSize=4,訓練迭代次數爲1百萬.
  • 對於視頻去模糊任務,優化器Adam,學習率0.0001,BatchSize=4,迭代次數3百萬。

​ 下圖給出在GoPro數據集上相關去模糊方法的性能與視覺效果對比。更多實驗結果與分析建議參考原文,這裏再也不贅述。

​ 下面給出了在視頻去模糊任務上的性能與視覺效果對比。更多實驗結果與分析建議參考原文,這裏再也不贅述。

Concolusion

​ 做者結合形變卷積、自注意力機制提出一種有效的圖像/視頻去模糊方法。其中形變卷積殘差模塊能夠解決局部模糊的局部信息偏移問題;而自注意力機制則能夠對不一樣模糊區域創建關聯性,從而提高特徵性能。自注意力機制與形變卷積都可提高網絡的感覺野,同時具備高效性。最後做者經過實驗驗證了所提方法的SOTA性能。

參考代碼

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.ops import DeformConvPack

# GPU 
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# DeformConv copy from mmdetection.
class DeformResModule(nn.Module):
    def __init__(self, inc, ksize):
        super(DeformResModule, self).__init__()
        pad = (ksize-1)//2
        self.dconv = DeformConvPack(inc,inc,ksize,1,padding=pad)
    def forward(self, x):
        res = self.dconv(x)
        return res + x

class ResBlock(nn.Module):
    def __init__(self, inc, ksize):
        super(ResBlock, self).__init__()
        padding = (ksize-1)//2
        self.conv1 = nn.Conv2d(inc, inc, ksize, 1, padding)
        self.conv2 = nn.Conv2d(inc, inc, ksize, 1, padding)
    def forward(self, x):
        res = self.conv2(F.relu(self.conv1(x)))
        return res + x
        
class SAM(nn.Module):
    def __init__(self, inc):
        super(SAM, self).__init__()
        self.convb = nn.Conv2d(inc, inc, 1)
        self.convc = nn.Conv2d(inc, inc, 1)
        self.convd = nn.Conv2d(inc, inc, 1)
        
    def forward(self, x):
        N, C, H, W = x.size()
        featB = self.convb(x)                         #N,C,H,W
        featC = self.convc(x)                         #N,C,H,W
        featD = self.convd(x)                         #N,C,H,W
        
        featB = featB.reshape(N, C, -1)               #N,C, HW
        featC = featC.reshape(N, C, -1)               #N,C, HW
        featC = featC.permute(0, 2, 1)                #N,HW,C
        
        featD = featD.reshape(N, C, -1)               #N,C, HW
        featD = featD.permute(0, 2, 1)                #N,HW,C
        
        featBC = torch.matmul(featC, featB)           #N,HW,HW
        featBC = featBC.softmax(-1)                   #N,HW,HW
        
        fusion = torch.matmul(featBC, featD)          #N,HW,C
        fusion = fusion.permute(0, 2, 1).contiguous() #N,C, HW
        fusion = fusion.reshape(N, C, H, W)           #N,C,H,W
        
        return x + fusion

class Net(nn.Module):
    def __init__(self, inc, outc, midc):
        super(Net, self).__init__()
        mid2 = midc*2
        mid4 = midc*4
        self.ecode1 = nn.Sequential(nn.Conv2d(inc,midc,3,1,1),
                                    nn.ReLU(),
                                    ResBlock(midc, 3),
                                    ResBlock(midc, 3),
                                    ResBlock(midc, 3))
        self.ecode2 = nn.Sequential(nn.Conv2d(midc,mid2,3,2,1),
                                    nn.ReLU(),
                                    ResBlock(mid2, 3),
                                    ResBlock(mid2, 3),
                                    ResBlock(mid2, 3))
        self.ecode3 = nn.Sequential(nn.Conv2d(mid2,mid4,3,2,1),
                                    nn.ReLU(),
                                    SAM(mid4),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3))
        
        self.dcode2 = nn.Sequential(ResBlock(mid2, 3),
                                    ResBlock(mid2, 3),
                                    ResBlock(mid2, 3))
        self.dcode1 = nn.Sequential(ResBlock(midc, 3),
                                    ResBlock(midc, 3),
                                    ResBlock(midc, 3),
                                    nn.Conv2d(midc, 3, 3, 1, 1))
                
        self.upsample1 = nn.ConvTranspose2d(mid4, mid2, 4, 2, 1)
        self.upsample2 = nn.ConvTranspose2d(mid2, midc, 4, 2, 1)
        
        self.feat1 = nn.Conv2d(midc, midc, 3, 1, 1)
        self.feat2 = nn.Conv2d(midc*2, midc*2, 3, 1, 1)
        
    def forward(self, x):
        encoder1 = self.ecode1(x)
        encoder2 = self.ecode2(encoder1)
        encoder3 = self.ecode3(encoder2)
        decoder3 = self.upsample1(encoder3)
        decoder2 = self.dcode2(decoder3 + self.feat2(encoder2))
        decoder1 = self.upsample2(decoder2)
        output   = self.dcode1(decoder1 + self.feat1(encoder1))
        
        return output
             
        
def main():
    model = Net(3, 3, 32).cuda().eval()
    
    inputs = torch.randn(4, 3, 128, 128).cuda()
    with torch.no_grad():
        output = model(inputs)
    print(output.size())
    
    
if __name__ == "__main__":
    main()
複製代碼
相關文章
相關標籤/搜索