(原)堆疊hourglass網絡

轉載請註明出處:

http://www.javashuo.com/article/p-zfyewxqz-e.html

論文:

https://arxiv.org/abs/1603.06937

官方torch代碼(沒具體看):

https://github.com/princeton-vl/pose-hg-demo

第三方pytorch代碼(位於models/StackedHourGlass.py):

https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation

1. 簡介

該論文利用多尺度特徵來識別姿態,如下圖所示。每一個子網絡稱爲hourglass Network,是一個沙漏型的結構;多個這種結構堆疊起來,稱做stacked hourglass。堆疊的方式使每一個模塊都能在整個圖像上重新估計姿態和特徵。如下圖所示,輸入圖像經過全卷積網絡(FCN)後得到特徵,然後經過多個堆疊的hourglass,得到最終的熱圖。

Hourglass如下圖所示。其中每一個方塊均爲下下圖所示的殘差模塊。

Hourglass採用了中間監督(Intermediate Supervision)。每一個hourglass均會輸出熱圖(圖中藍色部分)。訓練階段,將這些熱圖分別與真實熱圖計算MSE損失並求和,得到總損失;推斷階段,使用的是最後一個hourglass的熱圖。

2. stacked hourglass

堆疊hourglass結構如下圖所示(nChannels=256,nStack=2,nModules=2,numReductions=4, nJoints=17):

代碼如下:

 1 class StackedHourGlass(nn.Module):
 2     """docstring for StackedHourGlass"""
 3     def __init__(self, nChannels, nStack, nModules, numReductions, nJoints):
 4         super(StackedHourGlass, self).__init__()
 5         self.nChannels = nChannels
 6         self.nStack = nStack
 7         self.nModules = nModules
 8         self.numReductions = numReductions
 9         self.nJoints = nJoints
10 
11         self.start = M.BnReluConv(3, 64, kernelSize = 7, stride = 2, padding = 3)  # BN+ReLU+conv
12 
13         self.res1 = M.Residual(64, 128) # 輸入和輸出不等,輸入經過1*1conv結果和3*(BN+ReLU+conv)求和
14         self.mp = nn.MaxPool2d(2, 2)
15         self.res2 = M.Residual(128, 128) # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
16         self.res3 = M.Residual(128, self.nChannels) # 輸入和輸出相等,爲x+3*(BN+ReLU+conv);不然輸入經過1*1conv結果和3*(BN+ReLU+conv)求和。
17 
18         _hourglass, _Residual, _lin1, _chantojoints, _lin2, _jointstochan = [],[],[],[],[],[]
19 
20         for _ in range(self.nStack):  # 堆疊個數
21             _hourglass.append(Hourglass(self.nChannels, self.numReductions, self.nModules))
22             _ResidualModules = []
23             for _ in range(self.nModules):
24                 _ResidualModules.append(M.Residual(self.nChannels, self.nChannels))   # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
25             _ResidualModules = nn.Sequential(*_ResidualModules)
26             _Residual.append(_ResidualModules)   # self.nModules 個 3*(BN+ReLU+conv)
27             _lin1.append(M.BnReluConv(self.nChannels, self.nChannels))       # BN+ReLU+conv
28             _chantojoints.append(nn.Conv2d(self.nChannels, self.nJoints,1))  # 1*1 conv,維度變換
29             _lin2.append(nn.Conv2d(self.nChannels, self.nChannels,1))        # 1*1 conv,維度不變
30             _jointstochan.append(nn.Conv2d(self.nJoints,self.nChannels,1))   # 1*1 conv,維度變換
31 
32         self.hourglass = nn.ModuleList(_hourglass)
33         self.Residual = nn.ModuleList(_Residual)
34         self.lin1 = nn.ModuleList(_lin1)
35         self.chantojoints = nn.ModuleList(_chantojoints)
36         self.lin2 = nn.ModuleList(_lin2)
37         self.jointstochan = nn.ModuleList(_jointstochan)
38 
39     def forward(self, x):
40         x = self.start(x)
41         x = self.res1(x)
42         x = self.mp(x)
43         x = self.res2(x)
44         x = self.res3(x)
45         out = []
46 
47         for i in range(self.nStack):
48             x1 = self.hourglass[i](x)
49             x1 = self.Residual[i](x1)
50             x1 = self.lin1[i](x1)
51             out.append(self.chantojoints[i](x1))
52             x1 = self.lin2[i](x1)
53             x = x + x1 + self.jointstochan[i](out[i])   # 特徵求和
54 
55         return (out)
View Code

3. hourglass

hourglass在numReductions>1時遞歸調用自身,結構如下:

代碼如下:

 1 class Hourglass(nn.Module):
 2     """docstring for Hourglass"""
 3     def __init__(self, nChannels = 256, numReductions = 4, nModules = 2, poolKernel = (2,2), poolStride = (2,2), upSampleKernel = 2):
 4         super(Hourglass, self).__init__()
 5         self.numReductions = numReductions
 6         self.nModules = nModules
 7         self.nChannels = nChannels
 8         self.poolKernel = poolKernel
 9         self.poolStride = poolStride
10         self.upSampleKernel = upSampleKernel
11 
12         """For the skip connection, a residual module (or sequence of residuaql modules)  """
13         _skip = []
14         for _ in range(self.nModules):
15             _skip.append(M.Residual(self.nChannels, self.nChannels))  # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
16         self.skip = nn.Sequential(*_skip)
17 
18         """First pooling to go to smaller dimension then pass input through
19         Residual Module or sequence of Modules then  and subsequent cases:
20             either pass through Hourglass of numReductions-1 or pass through M.Residual Module or sequence of Modules """
21         self.mp = nn.MaxPool2d(self.poolKernel, self.poolStride)
22 
23         _afterpool = []
24         for _ in range(self.nModules):
25             _afterpool.append(M.Residual(self.nChannels, self.nChannels))  # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
26         self.afterpool = nn.Sequential(*_afterpool)
27 
28         if (numReductions > 1):
29             self.hg = Hourglass(self.nChannels, self.numReductions-1, self.nModules, self.poolKernel, self.poolStride)  # 嵌套調用自己
30         else:
31             _num1res = []
32             for _ in range(self.nModules):
33                 _num1res.append(M.Residual(self.nChannels,self.nChannels))  # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
34             self.num1res = nn.Sequential(*_num1res)  # doesnt seem that important ?
35 
36         """ Now another M.Residual Module or sequence of M.Residual Modules  """
37         _lowres = []
38         for _ in range(self.nModules):
39             _lowres.append(M.Residual(self.nChannels,self.nChannels))   # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
40         self.lowres = nn.Sequential(*_lowres)
41 
42         """ Upsampling Layer (Can we change this??????) As per Newell's paper upsamping recommended  """
43         self.up = myUpsample()#nn.Upsample(scale_factor = self.upSampleKernel)   # 將高和寬擴充爲原來2倍,實現上採樣
44 
45 
46     def forward(self, x):
47         out1 = x
48         out1 = self.skip(out1)          # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
49         out2 = x
50         out2 = self.mp(out2)            # 降維
51         out2 = self.afterpool(out2)     # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
52         if self.numReductions>1:
53             out2 = self.hg(out2)        # 嵌套調用自己
54         else:
55             out2 = self.num1res(out2)   # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
56         out2 = self.lowres(out2)        # 輸入和輸出相等,爲x+3*(BN+ReLU+conv)
57         out2 = self.up(out2)            # 升維
58 
59         return out2 + out1              # 求和
View Code

4. 上採樣myUpsample

上採樣代碼如下:

class myUpsample(nn.Module):
    """Nearest-neighbour upsampling by an integer factor.

    Every pixel of an (N, C, H, W) input is replicated into a ``scale x scale``
    block, which gives an (N, C, H*scale, W*scale) output.

    Args:
        scale: integer upsampling factor. The default of 2 matches the
            original hard-coded behaviour.
    """

    def __init__(self, scale=2):
        super(myUpsample, self).__init__()
        self.scale = scale

    def forward(self, x):
        n, c, h, w = x.shape
        s = self.scale
        # (N,C,H,W) -> (N,C,H,1,W,1) -> expand (no copy) to (N,C,H,s,W,s)
        # -> reshape to (N,C,H*s,W*s).
        return x[:, :, :, None, :, None].expand(-1, -1, -1, s, -1, s).reshape(n, c, h * s, w * s)
View Code

其中x爲形狀(N, C, H, W)的張量;x[:, :, :, None, :, None]插入兩個長度爲1的維度,得到(N, C, H, 1, W, 1)的張量;expand之後變成(N, C, H, 2, W, 2)(expand本身不複製數據);最終reshape之後變成(N, C, 2H, 2W)的張量。這樣每個像素在水平和垂直方向各擴充2倍,變成4個像素(4個像素值相同),即最近鄰上採樣。

5. 殘差模塊

殘差模塊結構如下:

代碼如下:

 1 class Residual(nn.Module):
 2         """docstring for Residual"""  # 輸入和輸出相等,爲x+3*(BN+ReLU+conv);不然輸入經過1*1conv結果和3*(BN+ReLU+conv)求和
 3         def __init__(self, inChannels, outChannels):
 4                 super(Residual, self).__init__()
 5                 self.inChannels = inChannels
 6                 self.outChannels = outChannels
 7                 self.cb = ConvBlock(inChannels, outChannels)      # 3 * (BN+ReLU+conv) 其中第一組降維,第二組不變,第三組升維
 8                 self.skip = SkipLayer(inChannels, outChannels)    # 輸入和輸出通道相等,則輸出=輸入,不然爲1*1 conv
 9 
10         def forward(self, x):
11                 out = 0
12                 out = out + self.cb(x)
13                 out = out + self.skip(x)
14                 return out
View Code

其中skiplayer代碼以下:

 1 class SkipLayer(nn.Module):
 2         """docstring for SkipLayer"""  # 輸入和輸出通道相等,則輸出=輸入,不然爲1*1 conv
 3         def __init__(self, inChannels, outChannels):
 4                 super(SkipLayer, self).__init__()
 5                 self.inChannels = inChannels
 6                 self.outChannels = outChannels
 7                 if (self.inChannels == self.outChannels):
 8                         self.conv = None
 9                 else:
10                         self.conv = nn.Conv2d(self.inChannels, self.outChannels, 1)
11 
12         def forward(self, x):
13                 if self.conv is not None:
14                         x = self.conv(x)
15                 return x
View Code

6. conv

 1 class BnReluConv(nn.Module):
 2         """docstring for BnReluConv"""    # BN+ReLU+conv
 3         def __init__(self, inChannels, outChannels, kernelSize = 1, stride = 1, padding = 0):
 4                 super(BnReluConv, self).__init__()
 5                 self.inChannels = inChannels
 6                 self.outChannels = outChannels
 7                 self.kernelSize = kernelSize
 8                 self.stride = stride
 9                 self.padding = padding
10 
11                 self.bn = nn.BatchNorm2d(self.inChannels)
12                 self.conv = nn.Conv2d(self.inChannels, self.outChannels, self.kernelSize, self.stride, self.padding)
13                 self.relu = nn.ReLU()
14 
15         def forward(self, x):
16                 x = self.bn(x)
17                 x = self.relu(x)
18                 x = self.conv(x)
19                 return x
View Code

7. ConvBlock

 1 class ConvBlock(nn.Module):
 2         """docstring for ConvBlock"""  # 3 * (BN+ReLU+conv) 其中第一組降維,第二組不變,第三組升維
 3         def __init__(self, inChannels, outChannels):
 4                 super(ConvBlock, self).__init__()
 5                 self.inChannels = inChannels
 6                 self.outChannels = outChannels
 7                 self.outChannelsby2 = outChannels//2
 8 
 9                 self.cbr1 = BnReluConv(self.inChannels, self.outChannelsby2, 1, 1, 0)        # BN+ReLU+conv
10                 self.cbr2 = BnReluConv(self.outChannelsby2, self.outChannelsby2, 3, 1, 1)    # BN+ReLU+conv
11                 self.cbr3 = BnReluConv(self.outChannelsby2, self.outChannels, 1, 1, 0)       # BN+ReLU+conv
12 
13         def forward(self, x):
14                 x = self.cbr1(x)
15                 x = self.cbr2(x)
16                 x = self.cbr3(x)
17                 return x
View Code
相關文章
相關標籤/搜索