【即插即用】Triplet Attention機制讓Channel和Spatial交互更加豐富(附開源代碼)

點擊上方【AI人工智能初學者】,選擇【星標】公衆號,期待您個人相遇與進步

  • 論文下載地址和代碼開源地址:https://github.com/LandskapeAI/triplet-attention
    https://arxiv.org/abs/2010.03045

在本文中研究了輕量且有效的注意力機制,並提出了Triplet Attention,該注意力機制是一種經過使用Triplet Branch結構捕獲跨維度交互來計算注意力權重的新方法。對於輸入張量,Triplet Attention經過旋轉操做和殘差變換創建維度間的依存關係,並以可忽略的計算開銷對通道和空間信息進行編碼。該方法既簡單又有效,而且能夠輕鬆地插入經典Backbone中。node

一、簡介和相關方法

最近許多工做提出使用Channel Attention或Spatial Attention,或二者結合起來提升神經網絡的性能。這些Attention機制經過創建Channel之間的依賴關係或加權空間注意Mask有能力改善由標準CNN生成的特徵表示。學習注意力權重背後是讓網絡有能力學習關注哪裏,並進一步關注目標對象。這裏列舉一些具備表明的工做:
一、SENet(Squeeze and Excite module)
二、CBAM(Convolutional Block Attention Module)
三、BAM(Bottleneck Attention Module)
四、Grad-CAM
五、Grad-CAM++
六、 -Nets(Double Attention Networks)
七、NL(Non-Local blocks)
八、GSoP-Net(Global Second order Pooling Networks)
九、GC-Net(Global Context Networks)
十、CC-Net(Criss-Cross Networks)
十一、SPNet
等等方法(這些方法都值得你們去學習和調研,說不定會給你的項目帶來意想不到的效果)。
以上大多數方法都有明顯的缺點(Cross-dimension),Triplet Attention解決了這些缺點。Triplet Attention模塊旨在捕捉Cross-dimension交互,從而可以在一個合理的計算開銷內(與上述方法相比能夠忽略不計)提供顯著的性能收益。git

二、本文方法

2.一、分析

本文的目標是研究如何在不涉及任何維數下降的狀況下創建廉價但有效的通道注意力模型。Triplet Attention不像CBAM和SENet須要必定數量的可學習參數來創建通道間的依賴關係,本文提出了一個幾乎無參數的注意機制來建模通道注意和空間注意,即Triplet Attention。github

2.二、Triplet Attention

所提出的Triplet Attention見下圖所示。顧名思義,Triplet Attention由3個平行的Branch組成,其中兩個負責捕獲通道C和空間H或W之間的跨維交互。最後一個Branch相似於CBAM,用於構建Spatial Attention。最終3個Branch的輸出使用平均進行聚合。web

一、Cross-Dimension Interaction

傳統的計算通道注意力的方法涉及計算一個權值,而後使用權值統一縮放這些特徵圖。可是在考慮這種方法時,有一個重要的缺失。一般,爲了計算這些通道的權值,輸入張量在空間上經過全局平均池化分解爲一個像素。這致使了空間信息的大量丟失,所以在單像素通道上計算注意力時,通道維數和空間維數之間的相互依賴性也不存在。微信

雖而後期提出基於Spatial和Channel的CBAM模型緩解了空間相互依賴的問題,可是依然存在一個問題,即,通道注意和空間注意是分離的,計算是相互獨立的。基於創建空間注意力的方法,本文提出了跨維度交互做用(cross dimension interaction)的概念,經過捕捉空間維度和輸入張量通道維度之間的交互做用,解決了這一問題。網絡

這裏是經過三個分支分別捕捉輸入張量的(C, H),(C, W)和(H, W)維間的依賴關係來引入Triplet Attention中的跨維交互做用。架構

二、Z-pool

Z-pool層負責將C維度的Tensor縮減到2維,將該維上的平均聚集特徵和最大聚集特徵鏈接起來。這使得該層可以保留實際張量的豐富表示,同時縮小其深度以使進一步的計算量更輕。能夠用下式表示:app

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=11)

三、Triplet Attention

給定一個輸入張量 ,首先將其傳遞到Triplet Attention模塊中的三個分支中。編輯器

在第1個分支中,在H維度和C維度之間創建了交互:ide

爲了實現這一點,輸入張量 沿H軸逆時針旋轉90°。這個旋轉張量表示爲 的形狀爲(W×H×C),再而後通過Z-Pool後的張量 的shape爲(2×H×C),而後, 經過內核大小爲k×k的標準卷積層,再經過批處理歸一化層,提供維數(1×H×C)的中間輸出。而後,經過將張量經過sigmoid來生成的注意力權值。在最後輸出是沿着H軸進行順時針旋轉90°保持和輸入的shape一致。

在第2個分支中,在C維度和W維度之間創建了交互:

爲了實現這一點,輸入張量 沿W軸逆時針旋轉90°。這個旋轉張量表示爲 的形狀爲(H×C×W),再而後通過Z-Pool後的張量 的shape爲(2×C×W ),而後, 經過內核大小爲k×k的標準卷積層,再經過批處理歸一化層,提供維數(1×C×W)的中間輸出。而後,經過將張量經過sigmoid來生成的注意力權值。在最後輸出是沿着W軸進行順時針旋轉90°保持和輸入的shape一致。

在第3個分支中,在H維度和W維度之間創建了交互:

輸入張量 的通道經過Z-pool將變量簡化爲2。將這個形狀的簡化張量(2×H×W)簡化後經過核大小k定義的標準卷積層,而後經過批處理歸一化層。輸出經過sigmoid激活層生成形狀爲(1×H×W)的注意權值,並將其應用於輸入 ,獲得結果 。而後經過簡單的平均將3個分支產生的精細張量(C×H×W)聚合在一塊兒。

**最終輸出的Tensor:

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=Trueif bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(21, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid_(x_out) 
        return x * scale

class TripletAttention(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg''max'], no_spatial=False):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial=no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    def forward(self, x):
        x_perm1 = x.permute(0,2,1,3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0,2,1,3).contiguous()
        x_perm2 = x.permute(0,3,2,1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0,3,2,1).contiguous()
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1/3)*(x_out + x_out11 + x_out21)
        else:
            x_out = (1/2)*(x_out11 + x_out21)
        return x_out

四、Complexity Analysis

經過與其餘標準注意力機制的比較,驗證了Triplet Attention的效率,C爲該層的輸入通道數,r爲MLP在計算通道注意力時瓶頸處使用的縮減比,用於2D卷積的核大小用k表示,k<<<C。

三、實驗結果

3.一、圖像分類實驗

3.二、目標檢測實驗

3.三、消融實驗

3.四、HeatMap輸出對比

四、總結

在這項工做中提出了一個新的注意力機制Triplet Attention,它抓住了張量中各個維度特徵的重要性。Triplet Attention使用了一種有效的注意計算方法,不存在任何信息瓶頸。實驗證實,Triplet Attention提升了ResNet和MobileNet等標準神經網絡架構在ImageNet上的圖像分類和MS COCO上的目標檢測等任務上的Baseline性能,而只引入了最小的計算開銷。是一個很是不錯的即插即用的注意力模塊。

更爲詳細內容能夠參見論文中的描述。

References

[1] Rotate to Attend: Convolutional Triplet Attention Module

聲明:轉載請說明出處

掃描下方二維碼關注【AI人工智能初學者】公衆號,獲取更多實踐項目源碼和論文解讀,很是期待你個人相遇,讓咱們以夢爲馬,砥礪前行!!!

點「在看」給我一朵小黃花唄

本文分享自微信公衆號 - AI人工智能初學者(ChaucerG)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。

相關文章
相關標籤/搜索