天池ICPR2018和MSRA_TD500兩個數據集:html
1)天池ICPR的數據集爲網絡圖像,都是一些淘寶商家上傳到淘寶的一些商品介紹圖像,其標籤方式參考了ICDAR2015的數據標籤格式,即一個文本框用4個座標來表示,即左上、右上、右下、左下四個座標,共八個值,記做[x1 y1 x2 y2 x3 y3 x4 y4]python
2)MSRA_TD500使微軟收集的一個文本檢測和識別的一個數據集,裏面的圖像可能是街景圖,背景比較複雜,但文本位置比較明顯,一目瞭然。git
由於MSRA_TD500的標籤格式不同,最後一個參數表示矩形框的旋轉角度。github
因此咱們第一步就是將這兩個數據集的標籤格式統一,個人作法是將MSRA數據集格式改成ICDAR格式,方便後面的模型訓練。算法
由於MSRA_TD500採起的標籤格式是[index difficulty_label x y w h angle],因此咱們須要根據這個文本框的旋轉角度來求得水平文本框旋轉後的4個座標位置。實現以下:網絡
""" This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format. MSRA_TD500 format: [index difficulty_label x y w h angle] ICDAR2015 format: [left_top_x left_top_y right_top_X right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y] """ import math import cv2 import os # 求旋轉後矩形的4個座標 def get_box_img(x, y, w, h, angle): # 矩形框中點(x0,y0) x0 = x + w/2 y0 = y + h/2 l = math.sqrt(pow(w/2, 2) + pow(h/2, 2)) # 即對角線的一半 # angle小於0,逆時針轉 if angle < 0: a1 = -angle + math.atan(h / float(w)) # 旋轉角度-對角線與底線所成的角度 a2 = -angle - math.atan(h / float(w)) # 旋轉角度+對角線與底線所成的角度 pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2)) pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1)) pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2)) # x0+左下點旋轉後在水平線上的投影, y0-左下點在垂直線上的投影,顯然逆時針轉時,左下點上一和左移了。 pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1)) else: a1 = angle + math.atan(h / float(w)) a2 = angle - math.atan(h / float(w)) pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1)) pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2)) pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1)) pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2)) return [pt1[0], pt1[1], pt2[0], pt2[1], pt3[0], pt3[1], pt4[0], pt4[1]] def read_file(path): result = [] for line in open(path): info = [] data = line.split(' ') info.append(int(data[2])) info.append(int(data[3])) info.append(int(data[4])) info.append(int(data[5])) info.append(float(data[6])) info.append(data[0]) result.append(info) return result if __name__ == '__main__': file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/' save_img_path = '../dataset/OCR_dataset/ctpn/test_im/' save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/' file_list = os.listdir(file_path) for f in file_list: if '.gt' in f: continue name = f[0:8] txt_path = file_path + name + '.gt' im_path = file_path + f im = cv2.imread(im_path) coordinate = read_file(txt_path) # 仿照ICDAR格式,圖片名字寫作img_xx.jpg,對應的標籤文件寫作gt_img_xx.txt cv2.imwrite(save_img_path + name.lower() + '.jpg', im) save_gt = open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w') for i in coordinate: box = get_box_img(i[0], i[1], i[2], i[3], i[4]) box = [int(box[i]) for i in range(len(box))] box = [str(box[i]) for i in range(len(box))] save_gt.write(','.join(box)) save_gt.write('\n')
通過格式處理後,咱們兩份數據集算是整理好了。固然咱們還須要對整個數據集劃分爲訓練集和測試集,個人文件組織習慣以下:app
train_im, test_im文件夾裝的是訓練和測試圖像,train_gt和test_gt裝的是訓練和測試標籤。dom
由於CTPN的核心思想也是基於Faster RCNN中的region proposal機制的,因此原始數據標籤須要轉化爲 anchor標籤。訓練數據的標籤的生成的代碼是最難寫,ide
由於從一個完整的文本框標籤轉化爲一個個小尺度文本框標籤確實有點難度,並且這個anchor標籤的生成方式也與Faster RCNN生成方式略有不一樣。下面講一講個人實現思路:oop
第一步咱們須要將原先每張圖的bbox標籤轉化爲每一個anchor標籤。爲了實現該功能,咱們先將一張圖劃分爲寬度爲16的各個anchor。
def generate_gt_anchor(img, box, anchor_width=16): """ calsulate ground truth fine-scale box :param img: input image :param box: ground truth box (4 point) :param anchor_width: :return: tuple (position, h, cy) """ if not isinstance(box[0], float): box = [float(box[i]) for i in range(len(box))] result = [] # 求解一個bbox下,能分解爲多少個16寬度的小anchor,並求出最左和最右的小achor的id left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width)) # the left side anchor of the text box, downwards right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width)) # the right side anchor of the text box, upwards # handle extreme case, the right side anchor may exceed the image width if right_anchor_num * 16 + 15 > img.shape[1]: right_anchor_num -= 1 # combine the left-side and the right-side x_coordinate of a text anchor into one pair position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1) for i in range(left_anchor_num, right_anchor_num)] # 計算每一個gt anchor的真實位置,其實就是求解gt anchor的上邊界和下邊界 y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box) # 最後將每一個anchor的位置(水平ID)、anchor中心y座標、anchor高度存儲並返回 for i in range(len(position_pair)): position = int(position_pair[i][0] / anchor_width) # the index of anchor box h = y_bottom[i] - y_top[i] + 1 # the height of anchor box cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0 # the center point of anchor box result.append((position, cy, h)) return result
計算anchor上下邊界的方法:
# cal the gt anchor box's bottom and top coordinate def cal_y_top_and_bottom(raw_img, position_pair, box): """ :param raw_img: :param position_pair: for example:[(0, 15), (16, 31), ...] :param box: gt box (4 point) :return: top and bottom coordinates for y-axis """ img = copy.deepcopy(raw_img) y_top = [] y_bottom = [] height = img.shape[0] # 設置圖像mask,channel 0爲全黑圖 for i in range(img.shape[0]): for j in range(img.shape[1]): img[i, j, 0] = 0 top_flag = False bottom_flag = False # 根據bbox四點畫出文本框,channel 0下文本框爲白色 img = other.draw_box_4pt(img, box, color=(255, 0, 0)) for k in range(len(position_pair)): # 從左到右遍歷anchor gt,對每一個anchor從上往下掃描像素,遇到白色像素點(255)就停下來,此時像素點座標y就是該anchor gt的上邊界 # calc top y coordinate for y in range(0, height-1): # loop each anchor, from left to right for x in range(position_pair[k][0], position_pair[k][1] + 1): if img[y, x, 0] == 255: y_top.append(y) top_flag = True break if top_flag is True: break # 從左到右遍歷anchor gt,對每一個anchor從下往上掃描像素,遇到白色像素點(255)就停下來,此時像素點座標y就是該anchor gt的下邊界 # calc bottom y coordinate, pixel from down to top loop for y in range(height - 1, -1, -1): # loop each anchor, from left to right for x in range(position_pair[k][0], position_pair[k][1] + 1): if img[y, x, 0] == 255: y_bottom.append(y) bottom_flag = True break if bottom_flag is True: break top_flag = False bottom_flag = False return y_top, y_bottom
通過上面的標籤處理,咱們已經將原先的標準的文本框標籤轉化爲一個一個小尺度anchor標籤,如下是標籤轉化後的效果:
以上標籤可視化後看來anchor標籤作得不錯,可是這裏須要提出的是,我發現這種anchor生成方法是不太精準的,好比一個文本框邊緣像素恰好落在一個新的anchor上,那麼咱們就要爲這個像素分配一個16像素的anchor,顯然致使了文本框標籤的不許確,引入了15像素的偏差,這個是須要思考的。這個問題咱們先不作處理,繼續下面的工做。
固然轉化期間咱們也遇到不少奇怪的問題,好比下圖這種標籤都已經超出圖像範圍的,咱們必須作相應的特殊處理,好比限定標籤橫座標的最大尺寸爲圖像寬度。
left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width)) # the left side anchor of the text box, downwards right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width)) # the right side anchor of the text box, upwards
練:優化器咱們選擇SGD,learning rate咱們設置了兩個,前N個epoch使用較大的lr,後面的epoch使用較小的lr以更好地收斂。
訓練過程咱們定義了4個loss,分別是total_cls_loss,total_v_reg_loss, total_o_reg_loss, total_loss(前面三個loss相加)。
net = Net.CTPN() # 獲取網絡結構 for name, value in net.named_parameters(): if name in no_grad: value.requires_grad = False else: value.requires_grad = True # for name, value in net.named_parameters(): # print('name: {0}, grad: {1}'.format(name, value.requires_grad)) net.load_state_dict(torch.load('./lib/vgg16.model')) # net.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) lib.utils.init_weight(net) if using_cuda: net.cuda() net.train() print(net) criterion = Loss.CTPN_Loss(using_cuda=using_cuda) # 獲取loss train_im_list, train_gt_list, val_im_list, val_gt_list = create_train_val() # 獲取訓練、測試數據 total_iter = len(train_im_list) print("total training image num is %s" % len(train_im_list)) print("total val image num is %s" % len(val_im_list)) train_loss_list = [] test_loss_list = [] # 開始迭代訓練 for i in range(epoch): if i >= change_epoch: lr = lr_behind else: lr = lr_front optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005) #optimizer = optim.Adam(net.parameters(), lr=lr) iteration = 1 total_loss = 0 total_cls_loss = 0 total_v_reg_loss = 0 total_o_reg_loss = 0 start_time = time.time() random.shuffle(train_im_list) # 打亂訓練集 # print(random_im_list) for im in train_im_list: root, file_name = os.path.split(im) root, _ = os.path.split(root) name, _ = os.path.splitext(file_name) gt_name = 'gt_' + name + '.txt' gt_path = os.path.join(root, "train_gt", gt_name) if not os.path.exists(gt_path): print('Ground truth file of image {0} not exists.'.format(im)) continue gt_txt = lib.dataset_handler.read_gt_file(gt_path) # 讀取對應的標籤 #print("processing image %s" % os.path.join(img_root1, im)) img = cv2.imread(im) if img is None: iteration += 1 continue img, gt_txt = lib.dataset_handler.scale_img(img, gt_txt) # 圖像和標籤作歸一化 tensor_img = img[np.newaxis, :, :, :] tensor_img = tensor_img.transpose((0, 3, 1, 2)) if using_cuda: tensor_img = torch.FloatTensor(tensor_img).cuda() else: tensor_img = torch.FloatTensor(tensor_img) vertical_pred, score, side_refinement = net(tensor_img) # 正向計算,獲取預測結果 del tensor_img # transform bbox gt to anchor gt for training positive = [] negative = [] vertical_reg = [] side_refinement_reg = [] visual_img = copy.deepcopy(img) # 該圖用於可視化標籤 try: # loop all bbox in one image for box in gt_txt: # generate anchors from one bbox gt_anchor, visual_img = lib.generate_gt_anchor.generate_gt_anchor(img, box, draw_img_gt=visual_img) # 獲取圖像的anchor標籤 positive1, negative1, vertical_reg1, side_refinement_reg1 = lib.tag_anchor.tag_anchor(gt_anchor, score, box) # 計算預測值反映在anchor層面的數據 positive += positive1 negative += negative1 vertical_reg += vertical_reg1 side_refinement_reg += side_refinement_reg1 except: print("warning: img %s raise error!" % im) iteration += 1 continue if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0: iteration += 1 continue cv2.imwrite(os.path.join(DRAW_PREFIX, file_name), visual_img) optimizer.zero_grad() # 計算偏差 loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive, negative, vertical_reg, side_refinement_reg) # 反向傳播 loss.backward() optimizer.step() iteration += 1 # save gpu memory by transferring loss to float total_loss += float(loss) total_cls_loss += float(cls_loss) total_v_reg_loss += float(v_reg_loss) total_o_reg_loss += float(o_reg_loss) if iteration % display_iter == 0: end_time = time.time() total_time = end_time - start_time print('Epoch: {2}/{3}, Iteration: {0}/{1}, loss: {4}, cls_loss: {5}, v_reg_loss: {6}, o_reg_loss: {7}, {8}'. format(iteration, total_iter, i, epoch, total_loss / display_iter, total_cls_loss / display_iter, total_v_reg_loss / display_iter, total_o_reg_loss / display_iter, im)) logger.info('Epoch: {2}/{3}, Iteration: {0}/{1}'.format(iteration, total_iter, i, epoch)) logger.info('loss: {0}'.format(total_loss / display_iter)) logger.info('classification loss: {0}'.format(total_cls_loss / display_iter)) logger.info('vertical regression loss: {0}'.format(total_v_reg_loss / display_iter)) logger.info('side-refinement regression loss: {0}'.format(total_o_reg_loss / display_iter)) train_loss_list.append(total_loss) total_loss = 0 total_cls_loss = 0 total_v_reg_loss = 0 total_o_reg_loss = 0 start_time = time.time() # 按期驗證模型性能 if iteration % val_iter == 0: net.eval() logger.info('Start evaluate at {0} epoch {1} iteration.'.format(i, iteration)) val_loss = evaluate.val(net, criterion, val_batch_size, using_cuda, logger, val_im_list) logger.info('End evaluate.') net.train() start_time = time.time() test_loss_list.append(val_loss) # 按期存儲模型 if iteration % save_iter == 0: print('Model saved at ./model/ctpn-{0}-{1}.model'.format(i, iteration)) torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration)) print('Model saved at ./model/ctpn-{0}-end.model'.format(i)) torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-end.model'.format(i)) # 畫出loss的變化圖 draw_loss_plot(train_loss_list, test_loss_list)
縮放圖像具備必定規則:首先要保證文本框label的最短邊也要等於600。咱們經過
scale = float(shortest_side)/float(min(height, width))
來求得圖像的縮放係數,對原始圖像進行縮放。 同時咱們也要對咱們的label也要根據該縮放係數進行縮放。
def scale_img(img, gt, shortest_side=600): height = img.shape[0] width = img.shape[1] scale = float(shortest_side)/float(min(height, width)) img = cv2.resize(img, (0, 0), fx=scale, fy=scale) if img.shape[0] < img.shape[1] and img.shape[0] != 600: img = cv2.resize(img, (600, img.shape[1])) elif img.shape[0] > img.shape[1] and img.shape[1] != 600: img = cv2.resize(img, (img.shape[0], 600)) elif img.shape[0] != 600: img = cv2.resize(img, (600, 600)) h_scale = float(img.shape[0])/float(height) w_scale = float(img.shape[1])/float(width) scale_gt = [] for box in gt: scale_box = [] for i in range(len(box)): # x座標 if i % 2 == 0: scale_box.append(int(int(box[i]) * w_scale)) # y座標 else: scale_box.append(int(int(box[i]) * h_scale)) scale_gt.append(scale_box) return img, scale_gt
驗證集評估:
def val(net, criterion, batch_num, using_cuda, logger): img_root = '../dataset/OCR_dataset/ctpn/test_im' gt_root = '../dataset/OCR_dataset/ctpn/test_gt' img_list = os.listdir(img_root) total_loss = 0 total_cls_loss = 0 total_v_reg_loss = 0 total_o_reg_loss = 0 start_time = time.time() for im in random.sample(img_list, batch_num): name, _ = os.path.splitext(im) gt_name = 'gt_' + name + '.txt' gt_path = os.path.join(gt_root, gt_name) if not os.path.exists(gt_path): print('Ground truth file of image {0} not exists.'.format(im)) continue gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True) img = cv2.imread(os.path.join(img_root, im)) img, gt_txt = Dataset.scale_img(img, gt_txt) tensor_img = img[np.newaxis, :, :, :] tensor_img = tensor_img.transpose((0, 3, 1, 2)) if using_cuda: tensor_img = torch.FloatTensor(tensor_img).cuda() else: tensor_img = torch.FloatTensor(tensor_img) vertical_pred, score, side_refinement = net(tensor_img) del tensor_img positive = [] negative = [] vertical_reg = [] side_refinement_reg = [] for box in gt_txt: gt_anchor = Dataset.generate_gt_anchor(img, box) positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box) positive += positive1 negative += negative1 vertical_reg += vertical_reg1 side_refinement_reg += side_refinement_reg1 if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0: batch_num -= 1 continue loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive, negative, vertical_reg, side_refinement_reg) total_loss += loss total_cls_loss += cls_loss total_v_reg_loss += v_reg_loss total_o_reg_loss += o_reg_loss end_time = time.time() total_time = end_time - start_time print('#################### Start evaluate ####################') print('loss: {0}'.format(total_loss / float(batch_num))) logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num))) print('classification loss: {0}'.format(total_cls_loss / float(batch_num))) logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num))) print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num))) logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num))) print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num))) logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num))) print('{1} iterations for {0} seconds.'.format(total_time, batch_num)) print('##################### Evaluate end #####################') print('\n')
測試效果:輸入一張圖片,給出最後的檢測結果
def infer_one(im_name, net): im = cv2.imread(im_name) im = lib.dataset_handler.scale_img_only(im) # 歸一化圖像 img = copy.deepcopy(im) img = img.transpose(2, 0, 1) img = img[np.newaxis, :, :, :] img = torch.Tensor(img) v, score, side = net(img, val=True) # 送入網絡預測 result = [] # 根據分數獲取有文字的anchor for i in range(score.shape[0]): for j in range(score.shape[1]): for k in range(score.shape[2]): if score[i, j, k, 1] > THRESH_HOLD: result.append((j, k, i, float(score[i, j, k, 1].detach().numpy()))) # nms過濾 for_nms = [] for box in result: pt = lib.utils.trans_to_2pt(box[1], box[0] * 16 + 7.5, anchor_height[box[2]]) for_nms.append([pt[0], pt[1], pt[2], pt[3], box[3], box[0], box[1], box[2]]) for_nms = np.array(for_nms, dtype=np.float32) nms_result = lib.nms.cpu_nms(for_nms, NMS_THRESH) out_nms = [] for i in nms_result: out_nms.append(for_nms[i, 0:8]) # 肯定哪幾個anchors是屬於一組的 connect = get_successions(v, out_nms) # 將一組anchors合併成一條文本線 texts = get_text_lines(connect, im.shape) for box in texts: box = np.array(box) print(box) lib.draw_image.draw_ploy_4pt(im, box[0:8]) _, basename = os.path.split(im_name) cv2.imwrite('./infer_'+basename, im)
推斷時提到了get_successions
用於獲取一個預測文本行裏的全部anchors,換句話說,咱們獲得的不少預測有字符的anchor,可是咱們怎麼知道哪些acnhors能夠組成一個文本線呢?因此咱們須要實現一個anchor合併算法,這也是CTPN代碼實現中最爲困難的一步。
CTPN論文提到,文本線構造法以下:文本行構建很簡單,經過將那些text/no-text score > 0.7的連續的text proposals相鏈接便可。文本行的構建以下。
一看理論很簡單,可是一到本身實現就困難重重了。真是應了那句「紙上得來終覺淺,絕知此事要躬行」啊!get_successions
傳入的參數是v表明每一個預測anchor的h和y信息,anchors表明每一個anchors的四個頂點座標信息。
首先看一下訓練出來的模型的文字檢測效果,爲了便於觀察,我把anchor和最終合併好的文本框一併畫出:
在實現過程當中的一些總結和想法:
CTPN的完整實現能夠參考博主: Github
https://www.cnblogs.com/skyfsm/p/10054386.html