Over the past few days I have been reproducing the CTPN paper for text detection in PyTorch. This article walks through the reproduction step by step: data processing, training label generation, network construction, loss function design, and the main training loop. For the theory behind the CTPN algorithm, see the reference linked here.
For training we use two datasets, Tianchi ICPR2018 and MSRA_TD500. The Tianchi ICPR dataset consists of web images, mostly product-introduction pictures uploaded by Taobao sellers. Its labels follow the ICDAR2015 format: each text box is described by its four corners (top-left, top-right, bottom-right, bottom-left), eight values in total, written as [x1 y1 x2 y2 x3 y3 x4 y4].
The Tianchi ICPR2018 data looks like this: the fonts vary widely in shape, style and colour, and the text is often embedded in objects, which makes detection difficult:
MSRA_TD500 is a text detection and recognition dataset collected by Microsoft. Its images are mostly street scenes with fairly complex backgrounds, but the text locations are quite obvious at a glance. MSRA_TD500 uses a different label format, in which the last parameter is the rotation angle of the rectangle.
So our first step is to unify the label formats of the two datasets. My approach is to convert the MSRA labels into the ICDAR format, which makes the later model training easier. Since MSRA_TD500 uses the format [index difficulty_label x y w h angle], we need to use the rotation angle to compute the four corner coordinates of the rotated horizontal rectangle. The implementation is as follows:

"""
This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format.
MSRA_TD500 format: [index difficulty_label x y w h angle]
ICDAR2015 format: [left_top_x left_top_y right_top_X right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y]
"""
import math
import cv2
import os


# compute the 4 corner coordinates of the rotated rectangle
def get_box_img(x, y, w, h, angle):
    # center of the rectangle (x0, y0)
    x0 = x + w/2
    y0 = y + h/2
    l = math.sqrt(pow(w/2, 2) + pow(h/2, 2))  # half of the diagonal
    # if angle < 0, rotate counter-clockwise
    if angle < 0:
        a1 = -angle + math.atan(h / float(w))  # rotation angle minus the angle between the diagonal and the base
        a2 = -angle - math.atan(h / float(w))  # rotation angle plus the angle between the diagonal and the base
        pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2))
        pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1))
        # x0 plus the horizontal projection of the rotated corner, y0 minus its vertical projection;
        # for a counter-clockwise rotation the bottom-left corner clearly moves up and to the left
        pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2))
        pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1))
    else:
        a1 = angle + math.atan(h / float(w))
        a2 = angle - math.atan(h / float(w))
        pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1))
        pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2))
        pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1))
        pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2))
    return [pt1[0], pt1[1], pt2[0], pt2[1], pt3[0], pt3[1], pt4[0], pt4[1]]


def read_file(path):
    result = []
    for line in open(path):
        info = []
        data = line.split(' ')
        info.append(int(data[2]))
        info.append(int(data[3]))
        info.append(int(data[4]))
        info.append(int(data[5]))
        info.append(float(data[6]))
        info.append(data[0])
        result.append(info)
    return result


if __name__ == '__main__':
    file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/'
    save_img_path = '../dataset/OCR_dataset/ctpn/test_im/'
    save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/'
    file_list = os.listdir(file_path)
    for f in file_list:
        if '.gt' in f:
            continue
        name = f[0:8]
        txt_path = file_path + name + '.gt'
        im_path = file_path + f
        im = cv2.imread(im_path)
        coordinate = read_file(txt_path)
        # follow the ICDAR convention: images named img_xx.jpg, label files named gt_img_xx.txt
        cv2.imwrite(save_img_path + name.lower() + '.jpg', im)
        save_gt = open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w')
        for i in coordinate:
            box = get_box_img(i[0], i[1], i[2], i[3], i[4])
            box = [int(box[i]) for i in range(len(box))]
            box = [str(box[i]) for i in range(len(box))]
            save_gt.write(','.join(box))
            save_gt.write('\n')
After this format conversion, both datasets are ready. Of course we still need to split the data into a training set and a test set. My file organization habit is: train_im and test_im hold the training and test images, while train_gt and test_gt hold the training and test labels.
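For reference, the resulting layout looks roughly like this (inferred from the paths used in the scripts in this article; train_im and train_gt mirror the test folders):

dataset/OCR_dataset/ctpn/
├── train_im/   # training images, img_xx.jpg
├── train_gt/   # training labels, gt_img_xx.txt
├── test_im/    # test images
└── test_gt/    # test labels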
Because the core idea of CTPN is also based on the region proposal mechanism of Faster RCNN, the original labels have to be converted into anchor labels. The label generation code is the hardest part to write: turning one complete text box label into a series of fine-scale text box labels is genuinely tricky, and the way these anchor labels are generated also differs slightly from Faster RCNN. Here is my implementation idea:
The first step is to convert each image's bbox labels into per-anchor labels. To do this, we first divide the image into vertical strips (anchors) that are 16 pixels wide.

def generate_gt_anchor(img, box, anchor_width=16):
    """
    calculate ground truth fine-scale boxes
    :param img: input image
    :param box: ground truth box (4 points)
    :param anchor_width:
    :return: list of tuples (position, cy, h)
    """
    if not isinstance(box[0], float):
        box = [float(box[i]) for i in range(len(box))]
    result = []
    # work out how many 16-pixel-wide anchors one bbox spans, and the ids of the leftmost and rightmost anchors
    left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # the left side anchor of the text box, rounded down
    right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # the right side anchor of the text box, rounded up

    # handle extreme case, the right side anchor may exceed the image width
    if right_anchor_num * 16 + 15 > img.shape[1]:
        right_anchor_num -= 1

    # combine the left-side and the right-side x_coordinate of a text anchor into one pair
    position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1) for i in range(left_anchor_num, right_anchor_num)]

    # compute the real extent of each gt anchor, i.e. its top and bottom boundaries
    y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box)

    # finally store the horizontal index, center y coordinate and height of each anchor
    for i in range(len(position_pair)):
        position = int(position_pair[i][0] / anchor_width)  # the index of the anchor box
        h = y_bottom[i] - y_top[i] + 1  # the height of the anchor box
        cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0  # the center y of the anchor box
        result.append((position, cy, h))
    return result
The method for computing an anchor's top and bottom boundaries:

# compute the top and bottom y coordinates of each gt anchor
def cal_y_top_and_bottom(raw_img, position_pair, box):
    """
    :param raw_img:
    :param position_pair: for example: [(0, 15), (16, 31), ...]
    :param box: gt box (4 points)
    :return: top and bottom coordinates along the y-axis
    """
    img = copy.deepcopy(raw_img)
    y_top = []
    y_bottom = []
    height = img.shape[0]

    # build an image mask: set channel 0 to all black
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            img[i, j, 0] = 0

    top_flag = False
    bottom_flag = False
    # draw the text box from its 4 points, so the box appears white on channel 0
    img = other.draw_box_4pt(img, box, color=(255, 0, 0))

    for k in range(len(position_pair)):
        # traverse the gt anchors from left to right; for each anchor, scan pixels from top to bottom and stop
        # at the first white pixel (255): its y coordinate is the top boundary of this gt anchor
        for y in range(0, height - 1):
            # loop each anchor, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_top.append(y)
                    top_flag = True
                    break
            if top_flag is True:
                break

        # traverse the gt anchors from left to right; for each anchor, scan pixels from bottom to top and stop
        # at the first white pixel (255): its y coordinate is the bottom boundary of this gt anchor
        for y in range(height - 1, -1, -1):
            # loop each anchor, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_bottom.append(y)
                    bottom_flag = True
                    break
            if bottom_flag is True:
                break

        top_flag = False
        bottom_flag = False
    return y_top, y_bottom
With the label processing above, we have converted the original whole text box labels into fine-scale anchor labels. Here is what the converted labels look like:
Visualised, the anchor labels look reasonable. One thing worth pointing out, though, is that this anchor generation method is not very precise: if the edge of a text box falls just inside a new anchor column, we still assign a full 16-pixel anchor to those few pixels. This clearly makes the text box label inaccurate and can introduce up to 15 pixels of error, something worth thinking about. For now we leave it as is and carry on.
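To make the error concrete, here is a small standalone sketch (not part of the project code; the box coordinates are made up) that applies the same floor/ceil logic to a box whose right edge only just reaches into a new 16-pixel column:

import math

anchor_width = 16
# hypothetical box spanning x = 32 .. 65: its right edge covers only 2 pixels of the column 64..79
box_x_min, box_x_max = 32.0, 65.0

left_anchor_num = int(math.floor(box_x_min / anchor_width))   # 2
right_anchor_num = int(math.ceil(box_x_max / anchor_width))   # 5, so columns 2, 3 and 4 all become gt anchors

covered_width = (right_anchor_num - left_anchor_num) * anchor_width  # 48 pixels of anchors
real_width = box_x_max - box_x_min                                   # 33 pixels of actual text box
print(covered_width - real_width)                                    # 15 pixels of slack at the right edge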
During the conversion we also ran into some odd cases, for example labels that extend beyond the image boundary, as in the image below. These need special handling, such as clamping the label's x coordinates to the image width.

left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # the left side anchor of the text box, rounded down
right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # the right side anchor of the text box, rounded up
Since CTPN uses a CNN + bidirectional LSTM network structure, we implement the CTPN architecture in stages.
For the CNN part, CTPN uses VGG16 to extract low-level features.
class VGG_16(nn.Module):
    """
    VGG-16 without the pooling layer before the fc layers
    """
    def __init__(self):
        super(VGG_16, self).__init__()
        self.convolution1_1 = nn.Conv2d(3, 64, 3, padding=1)
        self.convolution1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pooling1 = nn.MaxPool2d(2, stride=2)
        self.convolution2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.convolution2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.pooling2 = nn.MaxPool2d(2, stride=2)
        self.convolution3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.convolution3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.convolution3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.pooling3 = nn.MaxPool2d(2, stride=2)
        self.convolution4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.convolution4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.convolution4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.pooling4 = nn.MaxPool2d(2, stride=2)
        self.convolution5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.convolution5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.convolution5_3 = nn.Conv2d(512, 512, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.convolution1_1(x), inplace=True)
        x = F.relu(self.convolution1_2(x), inplace=True)
        x = self.pooling1(x)
        x = F.relu(self.convolution2_1(x), inplace=True)
        x = F.relu(self.convolution2_2(x), inplace=True)
        x = self.pooling2(x)
        x = F.relu(self.convolution3_1(x), inplace=True)
        x = F.relu(self.convolution3_2(x), inplace=True)
        x = F.relu(self.convolution3_3(x), inplace=True)
        x = self.pooling3(x)
        x = F.relu(self.convolution4_1(x), inplace=True)
        x = F.relu(self.convolution4_2(x), inplace=True)
        x = F.relu(self.convolution4_3(x), inplace=True)
        x = self.pooling4(x)
        x = F.relu(self.convolution5_1(x), inplace=True)
        x = F.relu(self.convolution5_2(x), inplace=True)
        x = F.relu(self.convolution5_3(x), inplace=True)
        return x
Next we implement the bidirectional LSTM, which strengthens the learning of sequential context.
class BLSTM(nn.Module):
    def __init__(self, channel, hidden_unit, bidirectional=True):
        """
        :param channel: lstm input channel num
        :param hidden_unit: lstm hidden unit
        :param bidirectional:
        """
        super(BLSTM, self).__init__()
        self.lstm = nn.LSTM(channel, hidden_unit, bidirectional=bidirectional)

    def forward(self, x):
        """
        WARNING: The batch size of x must be 1.
        """
        x = x.transpose(1, 3)
        recurrent, _ = self.lstm(x[0])
        recurrent = recurrent[np.newaxis, :, :, :]
        recurrent = recurrent.transpose(1, 3)
        return recurrent
Here we add an intermediate layer to connect the CNN and the LSTM: it converts the feature map produced by the last VGG convolutional layer into vector form for the subsequent LSTM training.
class Im2col(nn.Module):
    def __init__(self, kernel_size, stride, padding):
        super(Im2col, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        height = x.shape[2]
        x = F.unfold(x, self.kernel_size, padding=self.padding, stride=self.stride)
        x = x.reshape((x.shape[0], x.shape[1], height, -1))
        return x
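As a quick sanity check (illustrative only; the shapes assume a 224×320 input to the VGG backbone), this is what the unfold-and-reshape does to a conv5 feature map:

import torch
import torch.nn.functional as F

feature = torch.randn(1, 512, 14, 20)                           # VGG conv5 output: (N, C, H, W)
col = F.unfold(feature, (3, 3), padding=(1, 1), stride=(1, 1))  # (1, 512*3*3, 14*20) = (1, 4608, 280)
col = col.reshape((col.shape[0], col.shape[1], 14, -1))         # (1, 4608, 14, 20)
print(col.shape)  # every spatial position now carries its 3x3x512 neighbourhood as a 4608-d vector,
                  # which matches the input channel size passed to BLSTM(3 * 3 * 512, 128) below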
Finally we assemble the three parts into the complete CTPN network: VGG16 at the bottom for feature extraction -> LSTM for sequence learning -> outputs for each anchor: score, h, y, and side_refinement.
class CTPN(nn.Module):
    def __init__(self):
        super(CTPN, self).__init__()
        self.cnn = nn.Sequential()
        self.cnn.add_module('VGG_16', VGG_16())
        self.rnn = nn.Sequential()
        self.rnn.add_module('im2col', Net.Im2col((3, 3), (1, 1), (1, 1)))
        self.rnn.add_module('blstm', BLSTM(3 * 3 * 512, 128))
        self.FC = nn.Conv2d(256, 512, 1)
        self.vertical_coordinate = nn.Conv2d(512, 2 * 10, 1)  # outputs 2k values (k=10 anchor sizes); the 2 values per anchor are its h and dy
        self.score = nn.Conv2d(512, 2 * 10, 1)  # outputs 2k scores (k=10 anchor sizes); the 2 values indicate text / non-text
        self.side_refinement = nn.Conv2d(512, 10, 1)  # outputs k values (k=10 anchor sizes); each is the anchor's horizontal offset, used to refine the horizontal edges of the text box

    def forward(self, x, val=False):
        x = self.cnn(x)
        x = self.rnn(x)
        x = self.FC(x)
        x = F.relu(x, inplace=True)
        vertical_pred = self.vertical_coordinate(x)
        score = self.score(x)
        if val:
            score = score.reshape((score.shape[0], 10, 2, score.shape[2], score.shape[3]))
            score = score.squeeze(0)
            score = score.transpose(1, 2)
            score = score.transpose(2, 3)
            score = score.reshape((-1, 2))
            # score = F.softmax(score, dim=1)
            score = score.reshape((10, vertical_pred.shape[2], -1, 2))
            vertical_pred = vertical_pred.reshape((vertical_pred.shape[0], 10, 2, vertical_pred.shape[2], vertical_pred.shape[3]))
        side_refinement = self.side_refinement(x)
        return vertical_pred, score, side_refinement
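A rough shape check of the assembled network on a dummy input (a sketch, assuming the classes above live in the same module as in the author's repo, where Im2col is referenced through Net; the BLSTM implementation above requires batch size 1):

import torch

net = CTPN()
dummy = torch.randn(1, 3, 224, 320)                 # one RGB image, height and width divisible by 16
vertical_pred, score, side_refinement = net(dummy)
print(vertical_pred.shape)    # (1, 20, 14, 20): cy and h for each of the 10 anchors at every position
print(score.shape)            # (1, 20, 14, 20): text / non-text score for each of the 10 anchors
print(side_refinement.shape)  # (1, 10, 14, 20): one horizontal offset per anchor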
The CTPN loss has three parts: the text/non-text classification loss, the vertical coordinate regression loss, and the side-refinement regression loss.
First we define some fixed parameters:
class CTPN_Loss(nn.Module):
    def __init__(self, using_cuda=False):
        super(CTPN_Loss, self).__init__()
        self.Ns = 128
        self.ratio = 0.5
        self.lambda1 = 1.0
        self.lambda2 = 1.0
        self.Ls_cls = nn.CrossEntropyLoss()
        self.Lv_reg = nn.SmoothL1Loss()
        self.Lo_reg = nn.SmoothL1Loss()
        self.using_cuda = using_cuda
First, the classification loss:
cls_loss = 0.0
if self.using_cuda:
    for p in positive_batch:
        cls_loss += self.Ls_cls(score[0, p[2] * 2: ((p[2] + 1) * 2), p[1], p[0]].unsqueeze(0),
                                torch.LongTensor([1]).cuda())
    for n in negative_batch:
        cls_loss += self.Ls_cls(score[0, n[2] * 2: ((n[2] + 1) * 2), n[1], n[0]].unsqueeze(0),
                                torch.LongTensor([0]).cuda())
else:
    for p in positive_batch:
        cls_loss += self.Ls_cls(score[0, p[2] * 2: ((p[2] + 1) * 2), p[1], p[0]].unsqueeze(0),
                                torch.LongTensor([1]))
    for n in negative_batch:
        cls_loss += self.Ls_cls(score[0, n[2] * 2: ((n[2] + 1) * 2), n[1], n[0]].unsqueeze(0),
                                torch.LongTensor([0]))
cls_loss = cls_loss / self.Ns
Then the vertical coordinate regression loss, which measures the error in y and h:
# calculate vertical coordinate regression loss
v_reg_loss = 0.0
Nv = len(vertical_reg)
if self.using_cuda:
    for v in vertical_reg:
        v_reg_loss += self.Lv_reg(vertical_pred[0, v[2] * 2: ((v[2] + 1) * 2), v[1], v[0]].unsqueeze(0),
                                  torch.FloatTensor([v[3], v[4]]).unsqueeze(0).cuda())
else:
    for v in vertical_reg:
        v_reg_loss += self.Lv_reg(vertical_pred[0, v[2] * 2: ((v[2] + 1) * 2), v[1], v[0]].unsqueeze(0),
                                  torch.FloatTensor([v[3], v[4]]).unsqueeze(0))
v_reg_loss = v_reg_loss / float(Nv)
Finally, the side refinement regression loss, which refines the precision of the horizontal edges:
# calculate side refinement regression loss
o_reg_loss = 0.0
No = len(side_refinement_reg)
if self.using_cuda:
    for s in side_refinement_reg:
        o_reg_loss += self.Lo_reg(side_refinement[0, s[2]: s[2] + 1, s[1], s[0]].unsqueeze(0),
                                  torch.FloatTensor([s[3]]).unsqueeze(0).cuda())
else:
    for s in side_refinement_reg:
        o_reg_loss += self.Lo_reg(side_refinement[0, s[2]: s[2] + 1, s[1], s[0]].unsqueeze(0),
                                  torch.FloatTensor([s[3]]).unsqueeze(0))
o_reg_loss = o_reg_loss / float(No)
And finally there is a total loss that combines the three terms used throughout training:
loss = cls_loss + v_reg_loss * self.lambda1 + o_reg_loss * self.lambda2
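This matches the overall objective in the CTPN paper, with $N_s$, $N_v$ and $N_o$ the numbers of anchors contributing to each term:

$$L = \frac{1}{N_s}\sum_i L_s^{cls}(s_i, s_i^*) + \frac{\lambda_1}{N_v}\sum_j L_v^{reg}(v_j, v_j^*) + \frac{\lambda_2}{N_o}\sum_k L_o^{reg}(o_k, o_k^*)$$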
Training: we use SGD as the optimizer with two learning rates, a larger one for the first N epochs and a smaller one afterwards for better convergence. During training we track four losses: total_cls_loss, total_v_reg_loss, total_o_reg_loss, and total_loss (the sum of the first three).
net = Net.CTPN()  # build the network
for name, value in net.named_parameters():
    if name in no_grad:
        value.requires_grad = False
    else:
        value.requires_grad = True
# for name, value in net.named_parameters():
#     print('name: {0}, grad: {1}'.format(name, value.requires_grad))

net.load_state_dict(torch.load('./lib/vgg16.model'))
# net.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
lib.utils.init_weight(net)
if using_cuda:
    net.cuda()
net.train()
print(net)

criterion = Loss.CTPN_Loss(using_cuda=using_cuda)  # build the loss

train_im_list, train_gt_list, val_im_list, val_gt_list = create_train_val()  # collect training and validation data
total_iter = len(train_im_list)
print("total training image num is %s" % len(train_im_list))
print("total val image num is %s" % len(val_im_list))

train_loss_list = []
test_loss_list = []

# start training
for i in range(epoch):
    if i >= change_epoch:
        lr = lr_behind
    else:
        lr = lr_front
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    # optimizer = optim.Adam(net.parameters(), lr=lr)
    iteration = 1
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()

    random.shuffle(train_im_list)  # shuffle the training set
    # print(random_im_list)
    for im in train_im_list:
        root, file_name = os.path.split(im)
        root, _ = os.path.split(root)
        name, _ = os.path.splitext(file_name)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(root, "train_gt", gt_name)

        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = lib.dataset_handler.read_gt_file(gt_path)  # read the corresponding label
        # print("processing image %s" % os.path.join(img_root1, im))
        img = cv2.imread(im)
        if img is None:
            iteration += 1
            continue

        img, gt_txt = lib.dataset_handler.scale_img(img, gt_txt)  # rescale the image and its labels
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)

        vertical_pred, score, side_refinement = net(tensor_img)  # forward pass to get the predictions
        del tensor_img

        # transform bbox gt to anchor gt for training
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []

        visual_img = copy.deepcopy(img)  # this copy is used to visualise the labels

        try:
            # loop over all bboxes in one image
            for box in gt_txt:
                # generate anchors from one bbox
                gt_anchor, visual_img = lib.generate_gt_anchor.generate_gt_anchor(img, box, draw_img_gt=visual_img)  # get the anchor labels of the image
                positive1, negative1, vertical_reg1, side_refinement_reg1 = lib.tag_anchor.tag_anchor(gt_anchor, score, box)  # map the predictions onto the anchor level
                positive += positive1
                negative += negative1
                vertical_reg += vertical_reg1
                side_refinement_reg += side_refinement_reg1
        except:
            print("warning: img %s raise error!" % im)
            iteration += 1
            continue

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            iteration += 1
            continue

        cv2.imwrite(os.path.join(DRAW_PREFIX, file_name), visual_img)

        optimizer.zero_grad()
        # compute the loss
        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        # back-propagate
        loss.backward()
        optimizer.step()
        iteration += 1
        # save gpu memory by transferring loss to float
        total_loss += float(loss)
        total_cls_loss += float(cls_loss)
        total_v_reg_loss += float(v_reg_loss)
        total_o_reg_loss += float(o_reg_loss)

        if iteration % display_iter == 0:
            end_time = time.time()
            total_time = end_time - start_time
            print('Epoch: {2}/{3}, Iteration: {0}/{1}, loss: {4}, cls_loss: {5}, v_reg_loss: {6}, o_reg_loss: {7}, {8}'.
                  format(iteration, total_iter, i, epoch, total_loss / display_iter, total_cls_loss / display_iter,
                         total_v_reg_loss / display_iter, total_o_reg_loss / display_iter, im))

            logger.info('Epoch: {2}/{3}, Iteration: {0}/{1}'.format(iteration, total_iter, i, epoch))
            logger.info('loss: {0}'.format(total_loss / display_iter))
            logger.info('classification loss: {0}'.format(total_cls_loss / display_iter))
            logger.info('vertical regression loss: {0}'.format(total_v_reg_loss / display_iter))
            logger.info('side-refinement regression loss: {0}'.format(total_o_reg_loss / display_iter))

            train_loss_list.append(total_loss)

            total_loss = 0
            total_cls_loss = 0
            total_v_reg_loss = 0
            total_o_reg_loss = 0
            start_time = time.time()

        # periodically evaluate the model
        if iteration % val_iter == 0:
            net.eval()
            logger.info('Start evaluate at {0} epoch {1} iteration.'.format(i, iteration))
            val_loss = evaluate.val(net, criterion, val_batch_size, using_cuda, logger, val_im_list)
            logger.info('End evaluate.')
            net.train()
            start_time = time.time()
            test_loss_list.append(val_loss)

        # periodically save the model
        if iteration % save_iter == 0:
            print('Model saved at ./model/ctpn-{0}-{1}.model'.format(i, iteration))
            torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration))

    print('Model saved at ./model/ctpn-{0}-end.model'.format(i))
    torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-end.model'.format(i))

# plot the loss curves
draw_loss_plot(train_loss_list, test_loss_list)
Image rescaling follows a fixed rule: the shortest side of the image must be scaled to 600 pixels. We compute the scaling factor with scale = float(shortest_side)/float(min(height, width)) and resize the original image with it. The labels must then be scaled by the same factor.
def scale_img(img, gt, shortest_side=600):
    height = img.shape[0]
    width = img.shape[1]
    scale = float(shortest_side)/float(min(height, width))
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
    # cv2.resize takes dsize as (width, height); force the shortest side to exactly 600 after rounding
    if img.shape[0] < img.shape[1] and img.shape[0] != 600:
        img = cv2.resize(img, (img.shape[1], 600))
    elif img.shape[0] > img.shape[1] and img.shape[1] != 600:
        img = cv2.resize(img, (600, img.shape[0]))
    elif img.shape[0] != 600:
        img = cv2.resize(img, (600, 600))
    h_scale = float(img.shape[0])/float(height)
    w_scale = float(img.shape[1])/float(width)
    scale_gt = []
    for box in gt:
        scale_box = []
        for i in range(len(box)):
            # x coordinate
            if i % 2 == 0:
                scale_box.append(int(int(box[i]) * w_scale))
            # y coordinate
            else:
                scale_box.append(int(int(box[i]) * h_scale))
        scale_gt.append(scale_box)
    return img, scale_gt
Evaluation on the validation set:
def val(net, criterion, batch_num, using_cuda, logger):
    img_root = '../dataset/OCR_dataset/ctpn/test_im'
    gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
    img_list = os.listdir(img_root)
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()
    for im in random.sample(img_list, batch_num):
        name, _ = os.path.splitext(im)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(gt_root, gt_name)
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
        img = cv2.imread(os.path.join(img_root, im))
        img, gt_txt = Dataset.scale_img(img, gt_txt)
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)

        vertical_pred, score, side_refinement = net(tensor_img)
        del tensor_img
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        for box in gt_txt:
            gt_anchor = Dataset.generate_gt_anchor(img, box)
            positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box)
            positive += positive1
            negative += negative1
            vertical_reg += vertical_reg1
            side_refinement_reg += side_refinement_reg1

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            batch_num -= 1
            continue

        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        total_loss += loss
        total_cls_loss += cls_loss
        total_v_reg_loss += v_reg_loss
        total_o_reg_loss += o_reg_loss
    end_time = time.time()
    total_time = end_time - start_time
    print('#################### Start evaluate ####################')
    print('loss: {0}'.format(total_loss / float(batch_num)))
    logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num)))

    print('classification loss: {0}'.format(total_cls_loss / float(batch_num)))
    logger.info('Evaluate classification loss: {0}'.format(total_cls_loss / float(batch_num)))

    print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
    logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))

    print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
    logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))

    print('{1} iterations for {0} seconds.'.format(total_time, batch_num))
    print('##################### Evaluate end #####################')
    print('\n')
The training process:
Testing: given an input image, produce the final detection result.
def infer_one(im_name, net):
    im = cv2.imread(im_name)
    im = lib.dataset_handler.scale_img_only(im)  # rescale the image
    img = copy.deepcopy(im)
    img = img.transpose(2, 0, 1)
    img = img[np.newaxis, :, :, :]
    img = torch.Tensor(img)
    v, score, side = net(img, val=True)  # forward pass through the network
    result = []

    # collect the anchors whose text score exceeds the threshold
    for i in range(score.shape[0]):
        for j in range(score.shape[1]):
            for k in range(score.shape[2]):
                if score[i, j, k, 1] > THRESH_HOLD:
                    result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))

    # filter with NMS
    for_nms = []
    for box in result:
        pt = lib.utils.trans_to_2pt(box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
        for_nms.append([pt[0], pt[1], pt[2], pt[3], box[3], box[0], box[1], box[2]])
    for_nms = np.array(for_nms, dtype=np.float32)
    nms_result = lib.nms.cpu_nms(for_nms, NMS_THRESH)

    out_nms = []
    for i in nms_result:
        out_nms.append(for_nms[i, 0:8])

    # decide which anchors belong to the same group
    connect = get_successions(v, out_nms)
    # merge each group of anchors into one text line
    texts = get_text_lines(connect, im.shape)

    for box in texts:
        box = np.array(box)
        print(box)
        lib.draw_image.draw_ploy_4pt(im, box[0:8])

    _, basename = os.path.split(im_name)
    cv2.imwrite('./infer_' + basename, im)
The inference code above calls get_successions, which gathers all the anchors belonging to one predicted text line. In other words, we end up with many anchors predicted to contain text, but how do we know which anchors can be joined into one text line? We need an anchor merging algorithm, and this is the most difficult part of the CTPN implementation.
The CTPN paper describes text line construction as follows: text lines are built simply by connecting consecutive text proposals whose text/no-text score is > 0.7. Concretely, a proposal Bj is a paired neighbour of Bi if it is the nearest proposal within 50 pixels horizontally and their vertical overlap is greater than 0.7; if Bi is in turn the best neighbour of Bj in the opposite direction, the two are grouped into the same text line.
The theory looks simple, but implementing it yourself is another matter entirely; as the saying goes, "knowledge from books is shallow, true understanding comes from doing". get_successions takes two arguments: v, which carries the predicted h and y information of every anchor, and anchors, which holds the four corner coordinates of each anchor.
def get_successions(v, anchors=[]):
    texts = []
    for i, anchor in enumerate(anchors):
        neighbours = []  # anchors belonging to the same group
        neighbours.append(i)
        center_x1 = (anchor[2] + anchor[0]) / 2
        h1 = get_anchor_h(anchor, v)  # height of this anchor
        # find i's neighbours among the remaining anchors
        for j in range(i + 1, len(anchors)):
            center_x2 = (anchors[j][2] + anchors[j][0]) / 2  # center x coordinate
            h2 = get_anchor_h(anchors[j], v)
            # two anchors are neighbours if their horizontal distance is less than 50 pixels
            # and their vertical overlap exceeds the threshold
            if abs(center_x1 - center_x2) < NEIGHBOURS_MIN_DIST and \
                    meet_v_iou(max(anchor[1], anchors[j][1]), min(anchor[3], anchors[j][3]), h1, h2):  # less than 50 pixels between anchors
                neighbours.append(j)

        if len(neighbours) != 0:
            texts.append(neighbours)

    # every anchor's neighbours have now been collected into a group;
    # keep merging groups that share an anchor until nothing changes
    need_merge = True
    while need_merge:
        need_merge = False
        # ok, we combine again.
        for i, line in enumerate(texts):
            if len(line) == 0:
                continue
            for index in line:
                for j in range(i + 1, len(texts)):
                    if index in texts[j]:
                        texts[i] += texts[j]
                        texts[i] = list(set(texts[i]))
                        texts[j] = []
                        need_merge = True

    result = []
    # print(texts)
    for text in texts:
        if len(text) < 2:
            continue
        local = []
        for j in text:
            local.append(anchors[j])
        result.append(local)
    return result
Once we have the group of anchors belonging to a text line, the next step is to connect the anchors in the group into a single text box; this is what the get_text_lines function does.
def get_text_lines(text_proposals, im_size, scores=0):
    """
    text_proposals: boxes
    """
    text_lines = np.zeros((len(text_proposals), 8), np.float32)

    for index, tp_indices in enumerate(text_proposals):
        text_line_boxes = np.array(tp_indices)  # all fine-scale boxes of one text line
        # print(text_line_boxes)
        # print(type(text_line_boxes))
        # print(text_line_boxes.shape)
        X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2  # center x of each fine-scale box
        Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2  # center y of each fine-scale box
        # print(X)
        # print(Y)

        z1 = np.polyfit(X, Y, 1)  # least-squares fit of a straight line through the box centers

        x0 = np.min(text_line_boxes[:, 0])  # smallest x coordinate of the text line
        x1 = np.max(text_line_boxes[:, 2])  # largest x coordinate of the text line

        offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5  # half the width of a fine-scale box

        # fit a line through all top-left corners and evaluate it at the extreme left/right x of the text line
        lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
        # fit a line through all bottom-left corners and evaluate it at the extreme left/right x of the text line
        lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)

        # score = scores[list(tp_indices)].sum() / float(len(tp_indices))  # mean score of all fine-scale boxes as the text line score

        text_lines[index, 0] = x0
        text_lines[index, 1] = min(lt_y, rt_y)  # smaller y of the text line's top segment
        text_lines[index, 2] = x1
        text_lines[index, 3] = max(lb_y, rb_y)  # larger y of the text line's bottom segment
        text_lines[index, 4] = scores           # text line score
        text_lines[index, 5] = z1[0]            # k and b of the line fitted through the centers
        text_lines[index, 6] = z1[1]
        height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1]))  # mean height of the fine-scale boxes
        text_lines[index, 7] = height + 2.5

    text_recs = np.zeros((len(text_lines), 9), np.float32)
    index = 0
    for line in text_lines:
        b1 = line[6] - line[7] / 2  # from the center line and the height, derive the b of the top and bottom lines
        b2 = line[6] + line[7] / 2
        x1 = line[0]
        y1 = line[5] * line[0] + b1  # top-left
        x2 = line[2]
        y2 = line[5] * line[2] + b1  # top-right
        x3 = line[0]
        y3 = line[5] * line[0] + b2  # bottom-left
        x4 = line[2]
        y4 = line[5] * line[2] + b2  # bottom-right
        disX = x2 - x1
        disY = y2 - y1
        width = np.sqrt(disX * disX + disY * disY)  # width of the text line

        fTmp0 = y3 - y1  # height of the text line
        fTmp1 = fTmp0 * disY / width
        x = np.fabs(fTmp1 * disX / width)  # compensation
        y = np.fabs(fTmp1 * disY / width)
        if line[5] < 0:
            x1 -= x
            y1 += y
            x4 += x
            y4 -= y
        else:
            x2 += x
            y2 += y
            x3 -= x
            y3 -= y

        # clock-wise order
        text_recs[index, 0] = x1
        text_recs[index, 1] = y1
        text_recs[index, 2] = x2
        text_recs[index, 3] = y2
        text_recs[index, 4] = x4
        text_recs[index, 5] = y4
        text_recs[index, 6] = x3
        text_recs[index, 7] = y3
        text_recs[index, 8] = line[4]
        index = index + 1

    text_recs = clip_boxes(text_recs, im_size)

    return text_recs
First, let's look at the text detection results of the trained model. For easier inspection, I draw both the anchors and the final merged text boxes:
Here are some of the better detection results:
Some takeaways and thoughts from the implementation:
The complete CTPN implementation is available on my GitHub.