打通多個視覺任務的全能Backbone:HRNet

HRNet是微軟亞洲研究院的王井東老師領導的團隊完成的,打通圖像分類、圖像分割、目標檢測、人臉對齊、姿態識別、風格遷移、Image Inpainting、超分、optical flow、Depth estimation、邊緣檢測等網絡結構。python

王老師在ValseWebinar《物體和關鍵點檢測》中親自講解了HRNet,講解地很是透徹。如下文章主要參考了王老師在演講中的解讀,配合論文+代碼部分,來爲各位讀者介紹這個全能的Backbone-HRNet。git

1. 引入

網絡結構設計思路

在人體姿態識別這類的任務中,須要生成一個高分辨率的heatmap來進行關鍵點檢測。這就與通常的網絡結構好比VGGNet的要求不一樣,由於VGGNet最終獲得的feature map分辨率很低,損失了空間結構。github

傳統的解決思路

獲取高分辨率的方式大部分都是如上圖所示,採用的是先降分辨率,而後再升分辨率的方法。U-Net、SegNet、DeconvNet、Hourglass本質上都是這種結構。網絡

雖然看上去不一樣,可是本質是一致的

2. 核心

普通網絡都是這種結構,不一樣分辨率之間是進行了串聯app

不斷降分辨率

王井東老師則是將不一樣分辨率的feature map進行並聯:ide

並聯不一樣分辨率feature map

在並聯的基礎上,添加不一樣分辨率feature map之間的交互(fusion)。函數

具體fusion的方法以下圖所示:性能

  • 同分辨率的層直接複製。
  • 須要升分辨率的使用bilinear upsample + 1x1卷積將channel數統一。
  • 須要降分辨率的使用strided 3x3 卷積。
  • 三個feature map融合的方式是相加。

至於爲什麼要用strided 3x3卷積,這是由於卷積在降維的時候會出現信息損失,使用strided 3x3卷積是爲了經過學習的方式,下降信息的損耗。因此這裏沒有用maxpool或者組合池化。學習

HR示意圖

另外在讀HRNet的時候會有一個問題,有四個分支的到底如何使用這幾個分支呢?論文中也給出了幾種方式做爲最終的特徵選擇。測試

三種特徵融合方法

(a)圖展現的是HRNetV1的特徵選擇,只使用分辨率最高的特徵圖。

(b)圖展現的是HRNetV2的特徵選擇,將全部分辨率的特徵圖(小的特徵圖進行upsample)進行concate,主要用於語義分割和麪部關鍵點檢測。

(c)圖展現的是HRNetV2p的特徵選擇,在HRNetV2的基礎上,使用了一個特徵金字塔,主要用於目標檢測網絡。

再補充一個(d)圖

HRNetV2分類網絡後的特徵選擇

(d)圖展現的也是HRNetV2,採用上圖的融合方式,主要用於訓練分類網絡。

總結一下HRNet創新點

  • 將高低分辨率之間的連接由串聯改成並聯。
  • 在整個網絡結構中都保持了高分辨率的表徵(最上邊那個通路)。
  • 在高低分辨率中引入了交互來提升模型性能。

3. 效果

3.1 消融實驗

  1. 對交互方法進行消融實驗,證實了當前跨分辨率的融合的有效性。

交互方法的消融實現

  1. 證實高分辨率feature map的表徵能力

1x表明不進行降維,2x表明分辨率變爲原來一半,4x表明分辨率變爲原來四分之一。W3二、W48中的3二、48表明卷積的寬度或者通道數。

3.2 姿態識別任務上的表現

以上的姿態識別採用的是top-down的方法。

COCO驗證集的結果

在參數和計算量不增長的狀況下,要比其餘同類網絡效果好不少。

COCO測試集上的結果

在19年2月28日時的PoseTrack Leaderboard,HRNet佔領兩個項目的第一名。

PoseTrack Leaderboard

3.3 語義分割任務中的表現

CityScape驗證集上的結果對比

Cityscapes測試集上的對比

3.4 目標檢測任務中的表現

單模型單尺度模型對比

Mask R-CNN上結果

3.5 分類任務上的表現

ps: 王井東老師在這部分提到,分割的網絡也須要使用分類的預訓練模型,不然結果會差幾個點。

圖像分類任務中和ResNet進行對比

以上是HRNet和ResNet結果對比,同一個顏色的都是參數量大致一致的模型進行的對比,在參數兩差很少甚至更少的狀況下,HRNet可以比ResNet達到更好的效果。

4. 代碼

HRNet( https://github.com/HRNet )工做量很是大,構建了六個庫涉及語義分割、人體姿態檢測、目標檢測、圖片分類、面部關鍵點檢測、Mask R-CNN等庫。所有內容以下圖所示:

筆者對HRNet代碼構建很是感興趣,因此以HRNet-Image-Classification庫爲例,來解析一下這部分代碼。

先從簡單的入手,BasicBlock

BasicBlock結構

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

Bottleneck:

Bottleneck結構圖

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

HighResolutionModule,這是核心模塊, 主要分爲兩個組件:branches和fuse layer。

class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        '''
        調用:
        # 調用高低分辨率交互模塊, stage2 爲例
        HighResolutionModule(num_branches, # 2
                             block, # 'BASIC'
                             num_blocks, # [4, 4]
                             num_inchannels, # 上個stage的out channel
                             num_channels, # [32, 64]
                             fuse_method, # SUM
                             reset_multi_scale_output)
        '''
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            # 檢查分支數目是否合理
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        # 融合選用相加的方式
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        # 兩個核心部分,一個是branches構建,一個是融合layers構建
        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()

        self.relu = nn.ReLU(False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        # 分別檢查參數是否符合要求,看models.py中的參數,blocks參數冗餘了
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        # 構建一個分支,一個分支重複num_blocks個block
        downsample = None

        # 這裏判斷,若是通道變大(分辨率變小),則使用下采樣
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
                               momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))

        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion

        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []
        
        # 經過循環構建多分支,每一個分支屬於不一樣的分辨率
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches # 2
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            # i表明枚舉全部分支
            fuse_layer = []
            for j in range(num_branches):
                # j表明處理的當前分支
                if j > i: # 進行上採樣,使用最近鄰插值
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        nn.BatchNorm2d(num_inchannels[i],
                                       momentum=BN_MOMENTUM),
                        nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
                elif j == i:
                    # 本層不作處理
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    # 進行strided 3x3 conv下采樣,若是跨兩層,就使用兩次strided 3x3 conv
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                               momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3,
                                nn.ReLU(False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i]=self.branches[i](x[i])

        x_fuse=[]
        for i in range(len(self.fuse_layers)):
            y=x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y=y + x[j]
                else:
                    y=y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        # 將fuse之後的多個分支結果保存到list中
        return x_fuse

models.py中保存的參數, 能夠經過這些配置來改變模型的容量、分支個數、特徵融合方法:

# high_resoluton_net related params for classification
POSE_HIGH_RESOLUTION_NET = CN()
POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64
POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True

POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'

POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'

POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'

而後來看整個HRNet模型的構建, 因爲總體代碼量太大,這裏僅僅來看forward函數。

def forward(self, x):

    # 使用兩個strided 3x3conv進行快速降維
    x=self.relu(self.bn1(self.conv1(x)))
    x=self.relu(self.bn2(self.conv2(x)))

    # 構建了一串BasicBlock構成的模塊
    x=self.layer1(x)

    # 而後是多個stage,每一個stage核心是調用HighResolutionModule模塊
    x_list=[]
    for i in range(self.stage2_cfg['NUM_BRANCHES']):
        if self.transition1[i] is not None:
            x_list.append(self.transition1[i](x))
        else:
            x_list.append(x)
    y_list=self.stage2(x_list)

    x_list=[]
    for i in range(self.stage3_cfg['NUM_BRANCHES']):
        if self.transition2[i] is not None:
            x_list.append(self.transition2[i](y_list[-1]))
        else:
            x_list.append(y_list[i])
    y_list=self.stage3(x_list)

    x_list=[]
    for i in range(self.stage4_cfg['NUM_BRANCHES']):
        if self.transition3[i] is not None:
            x_list.append(self.transition3[i](y_list[-1]))
        else:
            x_list.append(y_list[i])
    y_list=self.stage4(x_list)

    # 添加分類頭,上文中有顯示,在分類問題中添加這種頭
    # 在其餘問題中換用不一樣的頭
    y=self.incre_modules[0](y_list[0])
    for i in range(len(self.downsamp_modules)):
        y=self.incre_modules[i+1](y_list[i+1]) + \
            self.downsamp_modules[i](y)
    y=self.final_layer(y)

    if torch._C._get_tracing_state():
        # 在不寫C代碼的狀況下執行forward,直接用python版本
        y=y.flatten(start_dim=2).mean(dim=2)
    else:
        y=F.avg_pool2d(y, kernel_size=y.size()
                            [2:]).view(y.size(0), -1)
    y=self.classifier(y)

    return y

5. 總結

HRNet核心方法是:在模型的整個過程當中,保存高分辨率表徵的同時使用讓不一樣分辨率的feature map進行特徵交互。

HRNet在很是多的CV領域有普遍的應用,好比ICCV2019的東北虎關鍵點識別比賽中,HRNet就起到了必定的做用。而且在分類部分的實驗證實了在同等參數量的狀況下,能夠取代ResNet進行分類。

以前看鄭安坤大佬的一篇文章CNN結構設計技巧-兼顧速度精度與工程實現中提到了一點:

senet是hrnet的一個特例,hrnet不只有通道注意力,同時也有空間注意力

-- akkaze-鄭安坤

SELayer核心實現

SELayer首先經過一個全局平均池化獲得一個一維向量,而後經過兩個全鏈接層,將信息進行壓縮和擴展,經過sigmoid之後獲得每一個通道的權值,而後用這個權值與原來的feature map相乘,進行信息上的優化。

HRNet一個結構

能夠看到上圖用紅色箭頭串起來的是否是和SELayer很類似。爲何說SENet是HRNet的一個特例,但從這個結構來說,能夠這麼看:

  • SENet沒有像HRNet這樣分辨率變爲原來的一半,分辨率直接變爲1x1,比較極端。變爲1x1向量之後,SENet中使用了兩個全鏈接網絡來學習通道的特徵分佈;可是在HRNet中,使用了幾個卷積(Residual block)來學習特徵。
  • SENet在主幹部分(高分辨率分支)沒有安排卷積進行特徵的學習;HRNet中主幹部分(高分辨率分支)安排了幾個卷積(Residual block)來學習特徵。
  • 特徵融合部分SENet和HRNet區分比較大,SENet使用的對應通道相乘的方法,HRNet則使用的是相加。之因此說SENet是通道注意力機制是由於經過全局平均池化後沒有了空間特徵,只剩通道的特徵;HRNet則能夠看做同時保留了空間特徵和通道特徵,因此說HRNet不只有通道注意力,同時也有空間注意力。

HRNet團隊有10人之多,構建了分類、分割、檢測、關鍵點檢測等庫,工做量很是大,並且作了不少紮實的實驗證實了這種思路的有效性。因此是否能夠認爲HRNet屬於SENet以後又一個更優的backbone呢?還須要本身實踐中使用這種想法和思路來驗證。

6. 參考

https://arxiv.org/pdf/1908.07919

https://www.bilibili.com/video/BV1WJ41197dh?t=508

https://github.com/HRNet

相關文章
相關標籤/搜索