Please credit the source when reposting:
http://www.javashuo.com/article/p-zfyewxqz-e.html
Paper:
https://arxiv.org/abs/1603.06937
Official Torch code (not examined in detail):
https://github.com/princeton-vl/pose-hg-demo
Third-party PyTorch code (the model is in models/StackedHourGlass.py):
https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation
The paper uses multi-scale features to estimate human pose. As shown in the figure below, each sub-network is called an hourglass network, an hourglass-shaped structure; several of them stacked end to end form the stacked hourglass. Stacking lets each module re-estimate the pose and the features over the whole image. As shown in the figure below, the input image first passes through a fully convolutional network (FCN) to extract features, which then go through the stacked hourglasses to produce the final heatmaps.
The hourglass module is shown in the figure below; each box in it is the residual module shown in the figure after that.
The network uses intermediate supervision: every hourglass produces its own heatmaps (blue). During training, an MSE loss is computed between each set of these heatmaps and the ground-truth heatmaps, and the losses are summed to give the total loss; at inference time only the heatmaps from the last hourglass are used.
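A minimal training-loss sketch (my own, not from the repo), assuming the model returns one heatmap tensor per stacked hourglass and that target holds the ground-truth heatmaps:

import torch.nn as nn

criterion = nn.MSELoss()

def stacked_hourglass_loss(outputs, target):
    # outputs: list of nStack tensors, each (N, nJoints, H, W); target: (N, nJoints, H, W)
    # training: sum the MSE of every intermediate heatmap against the ground truth
    return sum(criterion(o, target) for o in outputs)

# inference would use only outputs[-1], the heatmaps of the last hourglass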
The stacked hourglass structure is shown in the figure below (nChannels=256, nStack=2, nModules=2, numReductions=4, nJoints=17):
The code is as follows:
class StackedHourGlass(nn.Module):
    """docstring for StackedHourGlass"""
    def __init__(self, nChannels, nStack, nModules, numReductions, nJoints):
        super(StackedHourGlass, self).__init__()
        self.nChannels = nChannels
        self.nStack = nStack
        self.nModules = nModules
        self.numReductions = numReductions
        self.nJoints = nJoints

        self.start = M.BnReluConv(3, 64, kernelSize = 7, stride = 2, padding = 3)  # BN+ReLU+conv

        self.res1 = M.Residual(64, 128)              # in/out channels differ: 1x1-conv'd input summed with 3*(BN+ReLU+conv)
        self.mp = nn.MaxPool2d(2, 2)
        self.res2 = M.Residual(128, 128)             # in/out channels equal: x + 3*(BN+ReLU+conv)
        self.res3 = M.Residual(128, self.nChannels)  # equal channels: x + 3*(BN+ReLU+conv); otherwise 1x1-conv'd input summed with 3*(BN+ReLU+conv)

        _hourglass, _Residual, _lin1, _chantojoints, _lin2, _jointstochan = [], [], [], [], [], []

        for _ in range(self.nStack):  # number of stacked hourglasses
            _hourglass.append(Hourglass(self.nChannels, self.numReductions, self.nModules))
            _ResidualModules = []
            for _ in range(self.nModules):
                _ResidualModules.append(M.Residual(self.nChannels, self.nChannels))  # in/out channels equal: x + 3*(BN+ReLU+conv)
            _ResidualModules = nn.Sequential(*_ResidualModules)
            _Residual.append(_ResidualModules)                           # self.nModules residual modules
            _lin1.append(M.BnReluConv(self.nChannels, self.nChannels))   # BN+ReLU+conv
            _chantojoints.append(nn.Conv2d(self.nChannels, self.nJoints, 1))  # 1x1 conv, maps features to joint heatmaps
            _lin2.append(nn.Conv2d(self.nChannels, self.nChannels, 1))        # 1x1 conv, channel count unchanged
            _jointstochan.append(nn.Conv2d(self.nJoints, self.nChannels, 1))  # 1x1 conv, maps joint heatmaps back to features

        self.hourglass = nn.ModuleList(_hourglass)
        self.Residual = nn.ModuleList(_Residual)
        self.lin1 = nn.ModuleList(_lin1)
        self.chantojoints = nn.ModuleList(_chantojoints)
        self.lin2 = nn.ModuleList(_lin2)
        self.jointstochan = nn.ModuleList(_jointstochan)

    def forward(self, x):
        x = self.start(x)
        x = self.res1(x)
        x = self.mp(x)
        x = self.res2(x)
        x = self.res3(x)
        out = []

        for i in range(self.nStack):
            x1 = self.hourglass[i](x)
            x1 = self.Residual[i](x1)
            x1 = self.lin1[i](x1)
            out.append(self.chantojoints[i](x1))
            x1 = self.lin2[i](x1)
            x = x + x1 + self.jointstochan[i](out[i])  # sum of the previous features, the new features and the remapped heatmaps

        return out
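A quick sanity check of my own (assuming the classes above are importable and that the helper modules live in a module imported as M, as in the repo): with a 256x256 input, the stride-2 start conv plus the max pool bring the resolution down to 64x64, so every stack outputs (N, nJoints, 64, 64) heatmaps.

import torch

net = StackedHourGlass(nChannels=256, nStack=2, nModules=2, numReductions=4, nJoints=17)
x = torch.randn(1, 3, 256, 256)   # one RGB image
heatmaps = net(x)                 # list with nStack elements
print(len(heatmaps))              # 2
print(heatmaps[-1].shape)         # torch.Size([1, 17, 64, 64])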
When numReductions > 1, the hourglass module calls itself recursively; its structure is shown below:
The code is as follows:
class Hourglass(nn.Module):
    """docstring for Hourglass"""
    def __init__(self, nChannels = 256, numReductions = 4, nModules = 2, poolKernel = (2,2), poolStride = (2,2), upSampleKernel = 2):
        super(Hourglass, self).__init__()
        self.numReductions = numReductions
        self.nModules = nModules
        self.nChannels = nChannels
        self.poolKernel = poolKernel
        self.poolStride = poolStride
        self.upSampleKernel = upSampleKernel

        """For the skip connection, a residual module (or sequence of residual modules)"""
        _skip = []
        for _ in range(self.nModules):
            _skip.append(M.Residual(self.nChannels, self.nChannels))  # in/out channels equal: x + 3*(BN+ReLU+conv)
        self.skip = nn.Sequential(*_skip)

        """First pool down to a smaller resolution and pass the input through a residual module
        (or sequence of modules); then either recurse into an Hourglass of depth numReductions-1
        or pass through another residual module (or sequence of modules)."""
        self.mp = nn.MaxPool2d(self.poolKernel, self.poolStride)

        _afterpool = []
        for _ in range(self.nModules):
            _afterpool.append(M.Residual(self.nChannels, self.nChannels))  # in/out channels equal: x + 3*(BN+ReLU+conv)
        self.afterpool = nn.Sequential(*_afterpool)

        if (numReductions > 1):
            self.hg = Hourglass(self.nChannels, self.numReductions-1, self.nModules, self.poolKernel, self.poolStride)  # recursive call to itself
        else:
            _num1res = []
            for _ in range(self.nModules):
                _num1res.append(M.Residual(self.nChannels, self.nChannels))  # in/out channels equal: x + 3*(BN+ReLU+conv)
            self.num1res = nn.Sequential(*_num1res)  # residual modules at the lowest resolution

        """Now another M.Residual module (or sequence of M.Residual modules)"""
        _lowres = []
        for _ in range(self.nModules):
            _lowres.append(M.Residual(self.nChannels, self.nChannels))  # in/out channels equal: x + 3*(BN+ReLU+conv)
        self.lowres = nn.Sequential(*_lowres)

        """Upsampling layer; nearest-neighbour upsampling as recommended in Newell's paper"""
        self.up = myUpsample()  # nn.Upsample(scale_factor = self.upSampleKernel); doubles the height and width

    def forward(self, x):
        out1 = x
        out1 = self.skip(out1)       # skip branch at the current resolution
        out2 = x
        out2 = self.mp(out2)         # downsample
        out2 = self.afterpool(out2)
        if self.numReductions > 1:
            out2 = self.hg(out2)     # recurse into the nested hourglass
        else:
            out2 = self.num1res(out2)
        out2 = self.lowres(out2)
        out2 = self.up(out2)         # upsample back to the input resolution

        return out2 + out1           # element-wise sum with the skip branch
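A small check of my own (assuming M.Residual and myUpsample are defined as shown below): because every downsampling step inside the hourglass is matched by an upsampling step, the output has exactly the same shape as the input.

import torch

hg = Hourglass(nChannels=256, numReductions=4, nModules=2)
feat = torch.randn(1, 256, 64, 64)
print(hg(feat).shape)  # torch.Size([1, 256, 64, 64]) -- same shape as the input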
The upsampling code is as follows:
class myUpsample(nn.Module):
    def __init__(self):
        super(myUpsample, self).__init__()

    def forward(self, x):  # doubles the height and width (nearest-neighbour 2x upsampling)
        return x[:, :, :, None, :, None].expand(-1, -1, -1, 2, -1, 2).reshape(x.size(0), x.size(1), x.size(2)*2, x.size(3)*2)
Here x is a tensor of shape (N, C, H, W). Indexing it as x[:, :, :, None, :, None] gives shape (N, C, H, 1, W, 1); after expand it becomes (N, C, H, 2, W, 2), and the final reshape yields (N, C, 2H, 2W). Each pixel is thus replicated twice horizontally and twice vertically into four identical pixels, i.e. nearest-neighbour 2x upsampling.
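A small verification sketch of my own: the expand/reshape trick should give the same result as nearest-neighbour interpolation with scale factor 2.

import torch
import torch.nn.functional as F

x = torch.arange(4.).reshape(1, 1, 2, 2)
up = myUpsample()
print(up(x))  # each value repeated in a 2x2 block
print(torch.equal(up(x), F.interpolate(x, scale_factor=2, mode='nearest')))  # True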
The residual module structure is as follows:
The code is as follows:
class Residual(nn.Module):
    """docstring for Residual"""
    # if in/out channels are equal: x + 3*(BN+ReLU+conv); otherwise the 1x1-conv'd input is summed with 3*(BN+ReLU+conv)
    def __init__(self, inChannels, outChannels):
        super(Residual, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.cb = ConvBlock(inChannels, outChannels)    # 3 * (BN+ReLU+conv): reduce channels, keep them, then expand them
        self.skip = SkipLayer(inChannels, outChannels)  # identity if in/out channels are equal, otherwise a 1x1 conv

    def forward(self, x):
        out = 0
        out = out + self.cb(x)
        out = out + self.skip(x)
        return out
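A usage sketch of my own (assuming the SkipLayer, BnReluConv and ConvBlock classes listed below): the residual module only changes the channel count and always preserves the spatial resolution.

import torch

res = Residual(64, 128)
x = torch.randn(1, 64, 64, 64)
print(res(x).shape)  # torch.Size([1, 128, 64, 64]) -- channels change, spatial size does not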
The SkipLayer code is as follows:
class SkipLayer(nn.Module):
    """docstring for SkipLayer"""
    # identity if in/out channels are equal, otherwise a 1x1 conv
    def __init__(self, inChannels, outChannels):
        super(SkipLayer, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        if (self.inChannels == self.outChannels):
            self.conv = None
        else:
            self.conv = nn.Conv2d(self.inChannels, self.outChannels, 1)

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        return x
The BnReluConv code is as follows:

class BnReluConv(nn.Module):
    """docstring for BnReluConv"""  # BN+ReLU+conv
    def __init__(self, inChannels, outChannels, kernelSize = 1, stride = 1, padding = 0):
        super(BnReluConv, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.kernelSize = kernelSize
        self.stride = stride
        self.padding = padding

        self.bn = nn.BatchNorm2d(self.inChannels)
        self.conv = nn.Conv2d(self.inChannels, self.outChannels, self.kernelSize, self.stride, self.padding)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        return x
The ConvBlock code is as follows:

class ConvBlock(nn.Module):
    """docstring for ConvBlock"""  # 3 * (BN+ReLU+conv): the first reduces channels, the second keeps them, the third expands them
    def __init__(self, inChannels, outChannels):
        super(ConvBlock, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.outChannelsby2 = outChannels // 2

        self.cbr1 = BnReluConv(self.inChannels, self.outChannelsby2, 1, 1, 0)      # BN+ReLU+1x1 conv
        self.cbr2 = BnReluConv(self.outChannelsby2, self.outChannelsby2, 3, 1, 1)  # BN+ReLU+3x3 conv
        self.cbr3 = BnReluConv(self.outChannelsby2, self.outChannels, 1, 1, 0)     # BN+ReLU+1x1 conv

    def forward(self, x):
        x = self.cbr1(x)
        x = self.cbr2(x)
        x = self.cbr3(x)
        return x
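A bottleneck sketch of my own: ConvBlock squeezes the features down to outChannels // 2 channels in the middle before expanding back, which keeps the residual module cheap.

import torch

cb = ConvBlock(256, 256)
x = torch.randn(1, 256, 64, 64)
print(cb.cbr1(x).shape)  # torch.Size([1, 128, 64, 64]) -- bottleneck with outChannels // 2 channels
print(cb(x).shape)       # torch.Size([1, 256, 64, 64])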