【CV中的Attention機制】融合Non-Local和SENet的GCNet

前言: 以前已經介紹過SENet和Non Local Neural Network(NLNet),二者都是有效的注意力模塊。做者發現NLNet中attention maps在不一樣位置的響應幾乎一致,並結合SENet後,提出了Global Context block,用於全局上下文建模,在主流的benchmarks中的結果優於SENet和NLNet。python

GCNet論文名稱爲:《GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond》,是由清華大學提出的一個注意力模型,與SE block、Non Local block相似,提出了GC block。爲了克服NL block計算量過大的缺點,提出了一個Simplified NL block,因爲其與SE block結構的類似性,因而在其基礎上結合SE改進獲得GC block。git

SENet中提出的SE block是使用全局上下文對不一樣通道進行權值重標定,對通道依賴進行調整。可是採用這種方法,並無充分利用全局上下文信息。github

捕獲長距離依賴關係的目標是對視覺場景進行全局理解,對不少計算機視覺任務都有效,好比圖片分類、視頻分類、目標檢測、語義分割等。而NLNet就是經過自注意力機制來對長距離依賴關係進行建模。學習

做者對NLNet進行試驗,選擇COCO數據集中的6幅圖,對於不一樣的查詢點(query point)分別對Attention maps進行可視化,獲得如下結果:測試

能夠看出,對於不一樣的查詢點,其attention map是幾乎一致的,這說明NLNet學習到的是獨立於查詢的依賴(query-independent dependency),這說明雖然NLNet想要對每個位置進行特定的全局上下文計算,可是可視化結果以及實驗數據證實,全局上下文不受位置依賴。優化

基於以上發現,做者但願可以減小沒必要要的計算量,下降計算,並結合SENet設計,提出了GCNet融合了二者的優勢,既可以有用NLNet的全局上下文建模能力,又可以像SENet同樣輕量。ui

做者首先針對NLNet的問題,提出了一個Simplified NLNet, 極大地減小了計算量。spa

NLNet 中的Non-Local block能夠表示爲: $$ z_i=x_i+W_z\sum^{N_p}{j=1}\frac{f(x_i,x_j)}{C(x)}(W_v×x_j) $$ 輸入的feature map定義爲$x={x_i}^{N_p}{i=1}$, $N_p$是位置數量。$x和z$是NL block輸入和輸出。$i$是位置索引,$j$枚舉全部可能位置。$f(x_i,x_j)$表示位置$i和j$的關係,$C(x)$是歸一化因子。$W_z和W_v是線性轉換矩陣。設計

NLNet中提出了四個類似度計算模型,其效果是大概類似的。做者以Embedded Gaussian爲基礎進行改進,能夠表達爲: $$ W_{ij}=\frac{exp(W_qx_i,W_kx_j)}{\sum_{m}exp(W_qx_i,W_kx_m)} $$ 簡化後版本的Simplified NLNet想要經過計算一個全局注意力便可,能夠表達爲: $$ z_i=x_i+W_v\sum^{N_p}{j=1}\frac{exp(W_kx_j)}{\sum^{N_p}{m=1}exp(W_kx_m)}x_j $$ 這裏的$W_v、W_q、W_k$都是$1\times1$卷積,具體實現能夠參考上圖。code

簡化後的NLNet雖然計算量下去了,可是準確率並無提高,因此做者觀察到SENet與當前的模塊有必定的類似性,因此結合了SENet模塊,提出了GCNet。

能夠看出,GCNet在上下文信息建模這個地方使用了Simplified NL block中的機制,能夠充分利用全局上下文信息,同時在Transform階段借鑑了SE block。

GC block在ResNet中的使用位置是每兩個Stage之間的鏈接部分,下邊是GC block的官方實現(基於mmdetection進行修改):

代碼實現:

import torch
from torch import nn

class ContextBlock(nn.Module):
    def __init__(self,inplanes,ratio,pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        valid_fusion_types = ['channel_add', 'channel_mul']

        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None


    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        return out

if __name__ == "__main__":
    in_tensor = torch.ones((12, 64, 128, 128))

    cb = ContextBlock(inplanes=64, ratio=1./16.,pooling_type='att')
    
    out_tensor = cb(in_tensor)

    print(in_tensor.shape)
    print(out_tensor.shape)

對這個模塊進行了測試,須要說明的是,若是ratio × inplanes < 1, 將會出問題,這與通道個數有關,通道的個數是沒法小於1的。

實驗部分

做者基於mmdetection進行修改,添加了GC block,如下是消融實驗。

  • 從block設計來說,能夠看出Simplified NL與NL幾乎一直,可是參數量要小一些。而每一個階段都使用GC block的狀況下能比baseline提升2-3%。
  • 從添加位置進行試驗,在residual block中添加在add操做以後效果最好。
  • 從添加的不一樣階段來看,施加在三個階段效果最優,能比baseline高1-3%。

  • Bottleneck設計方面,測試使用縮放、ReLU、LayerNorm等組合,發現使用簡化版的NLNet並使用1×1卷積做爲transform的時候效果最好,可是其計算量太大。

  • 縮放因子設計:發現ratio=1/4的時候效果最好。

  • 池化和特徵融合設計:分別使用average pooling和attention pooling與add、scale方法進行組合實驗,發現attention pooling+add的方法效果最好。

此處對ImageNet數據集進行了實驗,提高大概在1%之內。

在動做識別數據集Kinetics中,也取得了1%左右的提高。

總結:GCNet結合了SENet和Non Local的優勢,在知足計算量相對較小的同時,優化了全局上下文建模能力,以後進行了詳盡的消融實驗證實了其在目標檢測、圖像分類、動做識別等視覺任務中的有效性。這篇論文值得多讀幾遍。


參考:

論文地址:https://arxiv.org/abs/1904.11492

官方實現代碼:https://github.com/xvjiarui/GCNet

文章中核心代碼:https://github.com/pprp/SimpleCVReproduction/tree/master/attention/GCBlock

相關文章
相關標籤/搜索