首先先放下github地址:https://github.com/acm5656/ssd_pytorchhtml
而後放上參考的代碼的github地址:https://github.com/amdegroot/ssd.pytorchios
爲何要使用pytorch復現呢,由於好多大佬的代碼對於萌新真的不友好,看半天看不懂,因此筆者本着學習和練手的目的,嘗試復現下,並分享出來幫助其餘萌新學習,大佬有興趣看了後能夠提些建議~git
而後對ssd原理感興趣的同窗能夠參考個人這篇博客https://www.cnblogs.com/cmai/p/10076050.html,主要對SSD模型進行了講解。在這就主要講解代碼實現上的內容了,就再也不講原理了。github
首先看下項目目錄:編程
VOCdevkit:存放訓練數據網絡
weights :存放權重文件app
Config.py :默認的一些配置函數
Test.py :測試單張照片的識別工具
Train.py :訓練的py文件學習
augmentation.py:data augmentation的py文件,主要功能是擴大訓練數據
detection.py:對識別的結果的數據進行部分篩選,傳送給Test.py文件,供其調用使用
l2norm.py:進行l2正則化
loss_function.py:計算損失函數
ssd_net_vgg.py:ssd模型的實現
utils.py:工具類
voc0712.py:重寫dataset類,提取voc的數據並規則化
模型搭建
模型搭建在ssd_net_vgg.py中,這個類只須要將一點,即vgg的網絡須要注意,必須採用筆者的方式搭建,不然pre-train的model加載出錯,具體的緣由不在這裏闡述。
模型的實現過程,將loc和conf的提取分開進行了,這個不影響正常的使用,只是在計算損失函數時,可以方便編程而已。
default box計算
代碼在utils.py文件下,代碼以下:
def default_prior_box(): mean_layer = [] for k,f in enumerate(Config.feature_map): mean = [] for i,j in product(range(f),repeat=2): f_k = Config.image_size/Config.steps[k] cx = (j+0.5)/f_k cy = (i+0.5)/f_k s_k = Config.sk[k]/Config.image_size mean += [cx,cy,s_k,s_k] s_k_prime = sqrt(s_k * Config.sk[k+1]/Config.image_size) mean += [cx,cy,s_k_prime,s_k_prime] for ar in Config.aspect_ratios[k]: mean += [cx, cy, s_k * sqrt(ar), s_k/sqrt(ar)] mean += [cx, cy, s_k / sqrt(ar), s_k * sqrt(ar)] if Config.use_cuda: mean = torch.Tensor(mean).cuda().view(Config.feature_map[k], Config.feature_map[k], -1).contiguous() else: mean = torch.Tensor(mean).view( Config.feature_map[k],Config.feature_map[k],-1).contiguous() mean.clamp_(max=1, min=0) mean_layer.append(mean) return mean_layer
該函數則是生成box,與論文中的數量對應,最後的輸出是6個list,每一個list對應一個特徵層輸出的default box數,具體數量參考上一篇ssd論文解讀的博客。計算公式同參考上篇博客。
Loss函數計算
loss函數的功能實如今loss_function.py中,具體核心代碼以下:
class LossFun(nn.Module): def __init__(self): super(LossFun,self).__init__() def forward(self, prediction,targets,priors_boxes): loc_data , conf_data = prediction loc_data = torch.cat([o.view(o.size(0),-1,4) for o in loc_data] ,1) conf_data = torch.cat([o.view(o.size(0),-1,21) for o in conf_data],1) priors_boxes = torch.cat([o.view(-1,4) for o in priors_boxes],0) if Config.use_cuda: loc_data = loc_data.cuda() conf_data = conf_data.cuda() priors_boxes = priors_boxes.cuda() # batch_size batch_num = loc_data.size(0) # default_box數量 box_num = loc_data.size(1) # 存儲targets根據每個prior_box變換後的數據 target_loc = torch.Tensor(batch_num,box_num,4) target_loc.requires_grad_(requires_grad=False) # 存儲每個default_box預測的種類 target_conf = torch.LongTensor(batch_num,box_num) target_conf.requires_grad_(requires_grad=False) if Config.use_cuda: target_loc = target_loc.cuda() target_conf = target_conf.cuda() # 由於一次batch可能有多個圖,每次循環計算出一個圖中的box,即8732個box的loc和conf,存放在target_loc和target_conf中 for batch_id in range(batch_num): target_truths = targets[batch_id][:,:-1].data target_labels = targets[batch_id][:,-1].data if Config.use_cuda: target_truths = target_truths.cuda() target_labels = target_labels.cuda() # 計算box函數,即公式中loc損失函數的計算公式 utils.match(0.5,target_truths,priors_boxes,target_labels,target_loc,target_conf,batch_id) pos = target_conf > 0 pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) # 至關於論文中L1損失函數乘xij的操做 pre_loc_xij = loc_data[pos_idx].view(-1,4) tar_loc_xij = target_loc[pos_idx].view(-1,4) # 將計算好的loc和預測進行smooth_li損失函數 loss_loc = F.smooth_l1_loss(pre_loc_xij,tar_loc_xij,size_average=False) batch_conf = conf_data.view(-1,21) # 參照論文中conf計算方式,求出ci loss_c = utils.log_sum_exp(batch_conf) - batch_conf.gather(1, target_conf.view(-1, 1)) loss_c = loss_c.view(batch_num, -1) # 將正樣本設定爲0 loss_c[pos] = 0 # 將剩下的負樣本排序,選出目標數量的負樣本 _, loss_idx = loss_c.sort(1, descending=True) _, idx_rank = loss_idx.sort(1) num_pos = pos.long().sum(1, keepdim=True) num_neg = torch.clamp(3*num_pos, max=pos.size(1)-1) # 提取出正負樣本 neg = idx_rank < num_neg.expand_as(idx_rank) pos_idx = pos.unsqueeze(2).expand_as(conf_data) neg_idx = neg.unsqueeze(2).expand_as(conf_data) conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, 21) targets_weighted = target_conf[(pos+neg).gt(0)] loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) N = num_pos.data.sum().double() loss_l = loss_loc.double() loss_c = loss_c.double() loss_l /= N loss_c /= N return loss_l, loss_c
其中較爲複雜的是match函數,其具體的代碼以下:
def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): """計算default box和實際位置的jaccard比,計算出每一個box的最大jaccard比的種類和每一個種類的最大jaccard比的box Args: threshold: (float) jaccard比的閾值. truths: (tensor) 實際位置. priors: (tensor) default box variances: (tensor) 這個數據含義暫時不清楚,筆者測試過,若是不使用一樣能夠訓練. labels: (tensor) 一個圖片實際包含的類別數. loc_t: (tensor) 須要存儲每一個box不一樣類別中的最大jaccard比. conf_t: (tensor) 存儲每一個box的最大jaccard比的類別. idx: (int) 當前的批次 """ # 計算jaccard比 overlaps = jaccard( truths, # 轉換priors,轉換爲x_min,y_min,x_max和y_max point_form(priors) ) # [1,num_objects] best prior for each ground truth # 實際包含的類別對應box中jaccarb最大的box和對應的索引值,即每一個類別最優box best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) # [1,num_priors] best ground truth for each prior # 每個box,在實際類別中最大的jaccard比的類別,即每一個box最優類別 best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) best_truth_idx.squeeze_(0) best_truth_overlap.squeeze_(0) best_prior_idx.squeeze_(1) best_prior_overlap.squeeze_(1) # 將每一個類別中的最大box設置爲2,確保不影響後邊操做 best_truth_overlap.index_fill_(0, best_prior_idx, 2) # 計算每個box的最優類別,和每一個類別的最優loc for j in range(best_prior_idx.size(0)): best_truth_idx[best_prior_idx[j]] = j matches = truths[best_truth_idx] # Shape: [num_priors,4] conf = labels[best_truth_idx] + 1 # Shape: [num_priors] conf[best_truth_overlap < threshold] = 0 # label as background # 實現loc的轉換,具體的轉換公式參照論文中的loc的loss函數的計算公式 loc = encode(matches, priors, variances) loc_t[idx] = loc # [num_priors,4] encoded offsets to learn conf_t[idx] = conf # [num_priors] top class label for each prior
代碼已經添加了比較詳細的註釋了,所以再也不作過多的解釋了。
我的認爲比較難的部分代碼就是上述的幾塊,但願讀者有時間能夠debug調試測試一下,再配合註釋,應該可以理解具體的內容,代碼中data augumentation 部分沒有作詳細的解釋,這部分筆者也沒搞得太明白,只是知道其功能是對數據集進行了擴大,即擴大圖像尺寸或者裁剪其中一部份內容等功能。
注:
這個代碼有一個bug,訓練的時候loss值有必定的機率會變爲nan,我的在訓練時候的經驗是在Config.py文件中,要修改batch_size大小,越大出現的機率越小,緣由應該是部分訓練集特徵比較分散,致使預測結果得分相差較大,在計算損失函數有一個計算e的次方,致使溢出,這是我的見解,不清楚是否正確。
以上是我的的理解,若是幫到你了,但願可以在github上star一下,謝謝啦。