ICDAR2015 Data Processing and Training

Training data processing:

We work with two datasets: Tianchi ICPR2018 and MSRA_TD500.

1) The Tianchi ICPR dataset consists of web images, mostly product-introduction pictures uploaded by Taobao merchants. Its labels follow the ICDAR2015 format: a text box is given by four corner coordinates (top-left, top-right, bottom-right, bottom-left), eight values in all, recorded as [x1 y1 x2 y2 x3 y3 x4 y4].

2) MSRA_TD500 is a text detection and recognition dataset collected by Microsoft. Its images are mostly street scenes: the backgrounds are fairly complex, but the text locations stand out at a glance.

MSRA_TD500 uses a different label format, in which the last field is the rotation angle of the rectangle.

So our first step is to unify the two datasets' label formats. My approach is to convert the MSRA labels into the ICDAR format, which simplifies the later model training.

Since MSRA_TD500 labels take the form [index difficulty_label x y w h angle], we need to use the rotation angle to recover the four corner coordinates of the rotated horizontal box. The implementation is as follows:

"""
This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format.

MSRA_TD500 format: [index difficulty_label x y w h angle]

ICDAR2015 format: [left_top_x left_top_y right_top_x right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y]

"""


import math
import cv2
import os

# Compute the 4 corner coordinates of the rotated rectangle
def get_box_img(x, y, w, h, angle):
    # center of the rectangle (x0, y0)
    x0 = x + w/2
    y0 = y + h/2
    l = math.sqrt(pow(w/2, 2) + pow(h/2, 2))  # half the length of the diagonal
    # angle < 0 means a counter-clockwise rotation
    if angle < 0:
        a1 = -angle + math.atan(h / float(w))  # |rotation angle| plus the angle between the diagonal and the base
        a2 = -angle - math.atan(h / float(w))  # |rotation angle| minus the angle between the diagonal and the base
        pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2))
        pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1))
        pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2))  # x0 plus the corner's horizontal projection, y0 minus its vertical projection
        pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1))
    else:
        a1 = angle + math.atan(h / float(w))
        a2 = angle - math.atan(h / float(w))
        pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1))
        pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2))
        pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1))
        pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2))
    return [pt1[0], pt1[1], pt2[0], pt2[1], pt3[0], pt3[1], pt4[0], pt4[1]]
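
As a quick sanity check of get_box_img (a hypothetical example of mine, not from the original script): with angle = 0 the function should return the four axis-aligned corners unchanged.

# a 100x50 box at (10, 20) with no rotation
print(get_box_img(10, 20, 100, 50, 0.0))
# expected (up to float rounding): [10, 20, 110, 20, 110, 70, 10, 70]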


def read_file(path):
    result = []
    with open(path) as f:
        for line in f:
            # MSRA_TD500 line format: index difficulty_label x y w h angle
            data = line.split(' ')
            info = [int(data[2]), int(data[3]), int(data[4]), int(data[5]), float(data[6]), data[0]]
            result.append(info)
    return result


if __name__ == '__main__':
    file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/'
    save_img_path = '../dataset/OCR_dataset/ctpn/test_im/'
    save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/'
    file_list = os.listdir(file_path)
    for f in file_list:
        if '.gt' in f:
            continue
        name = f[0:8]  # MSRA file names share a fixed-length stem, e.g. 'IMG_0059'
        txt_path = file_path + name + '.gt'
        im_path = file_path + f
        im = cv2.imread(im_path)
        coordinate = read_file(txt_path)
        # following the ICDAR convention: images are named img_xx.jpg and the
        # matching label file gt_img_xx.txt
        cv2.imwrite(save_img_path + name.lower() + '.jpg', im)
        with open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w') as save_gt:
            for box_info in coordinate:
                box = get_box_img(box_info[0], box_info[1], box_info[2], box_info[3], box_info[4])
                box = [str(int(v)) for v in box]  # round each coordinate, then stringify
                save_gt.write(','.join(box))
                save_gt.write('\n')

With the format conversion done, the two datasets are ready. Of course, we still need to split the data into a training set and a test set. My file organization habit is as follows:

The train_im and test_im folders hold the training and test images, while train_gt and test_gt hold the training and test labels, laid out as in the tree below.
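
Concretely (the dataset root below matches the paths used in the conversion script above; the exact location is just my own habit):

../dataset/OCR_dataset/ctpn/
├── train_im/    # training images  (img_xx.jpg)
├── train_gt/    # training labels  (gt_img_xx.txt)
├── test_im/     # test images
└── test_gt/     # test labels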

Generating training labels

Since CTPN's core idea is also built on the region-proposal mechanism of Faster R-CNN, the raw labels must be converted into anchor labels. The label-generation code is the hardest part to write,

because decomposing a complete text-box label into a series of fine-scale text-box labels is genuinely tricky, and the anchor labels here are generated slightly differently than in Faster R-CNN. Here is my implementation approach:

First we need to convert each image's bbox labels into per-anchor labels. To do this, we partition the image into anchors of width 16.

  • First, compute how many width-16 anchors the image divides into (for an image of width w, there are w/16 horizontal anchors), then work out how many anchors the text-box label spans and which are the leftmost and rightmost;
  • Compute the height and center of each anchor inside the text box: draw the text-box label in white onto an all-black mask, then scan from top to bottom and from bottom to top for the first white pixel, which gives the anchor's top and bottom boundaries;
  • Finally, store and return each anchor's position (horizontal ID), center y coordinate, and height.
def generate_gt_anchor(img, box, anchor_width=16):
    """
    calsulate ground truth fine-scale box
    :param img: input image
    :param box: ground truth box (4 point)
    :param anchor_width:
    :return: tuple (position, h, cy)
    """
    if not isinstance(box[0], float):
        box = [float(box[i]) for i in range(len(box))]
    result = []
    # work out how many 16-wide anchors the bbox decomposes into, and the ids of the leftmost and rightmost ones
    left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # leftmost anchor id covered by the text box (round down)
    right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # rightmost anchor id covered by the text box (round up)
    
    # handle extreme case, the right side anchor may exceed the image width
    if right_anchor_num * 16 + 15 > img.shape[1]:
        right_anchor_num -= 1
        
    # combine the left-side and the right-side x_coordinate of a text anchor into one pair
    position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1) for i in range(left_anchor_num, right_anchor_num)]
    
    # compute each gt anchor's true extent, i.e. its top and bottom boundaries
    y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box)
    # finally, store and return each anchor's position (horizontal ID), center y coordinate, and height
    for i in range(len(position_pair)):
        position = int(position_pair[i][0] / anchor_width)  # the index of anchor box
        h = y_bottom[i] - y_top[i] + 1  # the height of anchor box
        cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0  # the center point of anchor box
        result.append((position, cy, h))
    return result

The method for computing an anchor's top and bottom boundaries:

# compute the gt anchor box's top and bottom y coordinates
# (this helper relies on `import copy` and the project's `other.draw_box_4pt`)
def cal_y_top_and_bottom(raw_img, position_pair, box):
    """
    :param raw_img:
    :param position_pair: for example: [(0, 15), (16, 31), ...]
    :param box: gt box (4 points)
    :return: top and bottom coordinates for y-axis
    """
    img = copy.deepcopy(raw_img)
    y_top = []
    y_bottom = []
    height = img.shape[0]
    # zero out channel 0 so it becomes an all-black mask
    img[:, :, 0] = 0

    top_flag = False
    bottom_flag = False
    # draw the text box from its 4 points; on channel 0 the box appears white
    img = other.draw_box_4pt(img, box, color=(255, 0, 0))
    
    
    for k in range(len(position_pair)):
        # walk the gt anchors from left to right; scan each anchor's pixels from top to
        # bottom and stop at the first white pixel (255): its y is the anchor's top boundary
        for y in range(0, height):
            # loop each anchor, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_top.append(y)
                    top_flag = True
                    break
            if top_flag is True:
                break
        
        # for each anchor, scan pixels from bottom to top and stop at the first white
        # pixel (255): its y is that anchor's bottom boundary
        for y in range(height - 1, -1, -1):
            # loop each anchor, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_bottom.append(y)
                    bottom_flag = True
                    break
            if bottom_flag is True:
                break
        top_flag = False
        bottom_flag = False
    return y_top, y_bottom

After the label processing above, we have converted the original full text-box labels into a series of fine-scale anchor labels. Here is what the converted labels look like:

(figure: fine-scale anchor labels visualized on sample images)

Visualized, the anchor labels look decent. But I should point out that this anchor-generation scheme is not very precise: if a text box's edge pixel just barely falls into a new anchor, we still have to assign that pixel a full 16-pixel anchor, which clearly makes the text-box label inaccurate and can introduce up to 15 pixels of error. This deserves more thought, but for now we leave it alone and press on.
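
To make this error concrete, here is a tiny illustration with made-up numbers showing how snapping a box edge to 16-pixel anchors overshoots:

# hypothetical example: a text box whose right edge sits at x = 33 still claims
# the whole anchor covering x in [32, 47], so the label overshoots the true edge
anchor_width = 16
right_edge = 33
anchor_id = right_edge // anchor_width                      # anchor 2, covering [32, 47]
overshoot = (anchor_id + 1) * anchor_width - 1 - right_edge
print(overshoot)  # 14 pixels here; the worst case (edge at x = 32) is 15 pixels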

We also ran into plenty of odd cases during conversion, for example labels that extend beyond the image boundary. These need special handling, such as clamping the label's x coordinates to the image width:

left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # leftmost anchor id covered by the text box (round down)
right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # rightmost anchor id covered by the text box (round up)

  

Training process:

Training: we choose SGD as the optimizer and use two learning rates, a larger lr for the first N epochs and a smaller lr afterwards for better convergence.

During training we track 4 losses: total_cls_loss, total_v_reg_loss, total_o_reg_loss, and total_loss (the sum of the first three).
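
The internals of Loss.CTPN_Loss are not shown in this post. As a rough sketch of how the three terms might combine (the lambda weights of 1.0 and 2.0 follow the CTPN paper; the concrete loss functions below are my assumptions, not the repository's exact code):

import torch
import torch.nn as nn

class CTPNLossSketch(nn.Module):
    # a minimal sketch, assuming cross-entropy for text/non-text classification
    # and smooth-L1 for both regression terms
    def __init__(self, lambda1=1.0, lambda2=2.0):
        super().__init__()
        self.lambda1 = lambda1  # weight of the vertical-coordinate regression
        self.lambda2 = lambda2  # weight of the side-refinement regression
        self.cls_loss = nn.CrossEntropyLoss()
        self.v_reg_loss = nn.SmoothL1Loss()
        self.o_reg_loss = nn.SmoothL1Loss()

    def forward(self, cls_pred, cls_gt, v_pred, v_gt, o_pred, o_gt):
        cls = self.cls_loss(cls_pred, cls_gt)
        v_reg = self.v_reg_loss(v_pred, v_gt)
        o_reg = self.o_reg_loss(o_pred, o_gt)
        total = cls + self.lambda1 * v_reg + self.lambda2 * o_reg
        return total, cls, v_reg, o_reg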

    net = Net.CTPN()  # build the network structure
    for name, value in net.named_parameters():
        if name in no_grad:
            value.requires_grad = False
        else:
            value.requires_grad = True
    # for name, value in net.named_parameters():
    #     print('name: {0}, grad: {1}'.format(name, value.requires_grad))
    net.load_state_dict(torch.load('./lib/vgg16.model'))
    # net.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
    lib.utils.init_weight(net)
    if using_cuda:
        net.cuda()
    net.train()
    print(net)

    criterion = Loss.CTPN_Loss(using_cuda=using_cuda)  # build the loss

    train_im_list, train_gt_list, val_im_list, val_gt_list = create_train_val()  # build the training and validation splits
    total_iter = len(train_im_list)
    print("total training image num is %s" % len(train_im_list))
    print("total val image num is %s" % len(val_im_list))

    train_loss_list = []
    test_loss_list = []

    # start the training loop
    for i in range(epoch):
        if i >= change_epoch:
            lr = lr_behind
        else:
            lr = lr_front
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
        #optimizer = optim.Adam(net.parameters(), lr=lr)
        iteration = 1
        total_loss = 0
        total_cls_loss = 0
        total_v_reg_loss = 0
        total_o_reg_loss = 0
        start_time = time.time()

        random.shuffle(train_im_list)  # shuffle the training set
        # print(random_im_list)
        for im in train_im_list:
            root, file_name = os.path.split(im)
            root, _ = os.path.split(root)
            name, _ = os.path.splitext(file_name)
            gt_name = 'gt_' + name + '.txt'

            gt_path = os.path.join(root, "train_gt", gt_name)

            if not os.path.exists(gt_path):
                print('Ground truth file of image {0} not exists.'.format(im))
                continue

            gt_txt = lib.dataset_handler.read_gt_file(gt_path)  # read the matching label file
            #print("processing image %s" % os.path.join(img_root1, im))
            img = cv2.imread(im)
            if img is None:
                iteration += 1
                continue

            img, gt_txt = lib.dataset_handler.scale_img(img, gt_txt)  # rescale the image and labels together
            tensor_img = img[np.newaxis, :, :, :]
            tensor_img = tensor_img.transpose((0, 3, 1, 2))
            if using_cuda:
                tensor_img = torch.FloatTensor(tensor_img).cuda()
            else:
                tensor_img = torch.FloatTensor(tensor_img)

            vertical_pred, score, side_refinement = net(tensor_img)  # forward pass to get the predictions
            del tensor_img

            # transform bbox gt to anchor gt for training
            positive = []
            negative = []
            vertical_reg = []
            side_refinement_reg = []

            visual_img = copy.deepcopy(img)  # this copy is used to visualize the labels

            try:
                # loop over all bboxes in one image
                for box in gt_txt:
                    # generate anchors from one bbox
                    gt_anchor, visual_img = lib.generate_gt_anchor.generate_gt_anchor(img, box, draw_img_gt=visual_img)  # build the image's anchor labels
                    positive1, negative1, vertical_reg1, side_refinement_reg1 = lib.tag_anchor.tag_anchor(gt_anchor, score, box)  # map predictions onto anchor-level training targets
                    positive += positive1
                    negative += negative1
                    vertical_reg += vertical_reg1
                    side_refinement_reg += side_refinement_reg1
            except Exception:
                print("warning: img %s raise error!" % im)
                iteration += 1
                continue

            if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
                iteration += 1
                continue

            cv2.imwrite(os.path.join(DRAW_PREFIX, file_name), visual_img)
            optimizer.zero_grad()
            # compute the loss
            loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                               negative, vertical_reg, side_refinement_reg)
            # backpropagation
            loss.backward()
            optimizer.step()
            iteration += 1
            # save gpu memory by transferring loss to float
            total_loss += float(loss)
            total_cls_loss += float(cls_loss)
            total_v_reg_loss += float(v_reg_loss)
            total_o_reg_loss += float(o_reg_loss)

            if iteration % display_iter == 0:
                end_time = time.time()
                total_time = end_time - start_time
                print('Epoch: {2}/{3}, Iteration: {0}/{1}, loss: {4}, cls_loss: {5}, v_reg_loss: {6}, o_reg_loss: {7}, {8}'.
                      format(iteration, total_iter, i, epoch, total_loss / display_iter, total_cls_loss / display_iter,
                             total_v_reg_loss / display_iter, total_o_reg_loss / display_iter, im))

                logger.info('Epoch: {2}/{3}, Iteration: {0}/{1}'.format(iteration, total_iter, i, epoch))
                logger.info('loss: {0}'.format(total_loss / display_iter))
                logger.info('classification loss: {0}'.format(total_cls_loss / display_iter))
                logger.info('vertical regression loss: {0}'.format(total_v_reg_loss / display_iter))
                logger.info('side-refinement regression loss: {0}'.format(total_o_reg_loss / display_iter))

                train_loss_list.append(total_loss)

                total_loss = 0
                total_cls_loss = 0
                total_v_reg_loss = 0
                total_o_reg_loss = 0
                start_time = time.time()

            # periodically evaluate the model
            if iteration % val_iter == 0:
                net.eval()
                logger.info('Start evaluate at {0} epoch {1} iteration.'.format(i, iteration))
                val_loss = evaluate.val(net, criterion, val_batch_size, using_cuda, logger, val_im_list)
                logger.info('End evaluate.')
                net.train()
                start_time = time.time()
                test_loss_list.append(val_loss)

            # periodically save the model
            if iteration % save_iter == 0:
                print('Model saved at ./model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration))
                torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration))

        print('Model saved at ./model/ctpn-msra_ali-{0}-end.model'.format(i))
        torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-end.model'.format(i))

        # plot the loss curves
        draw_loss_plot(train_loss_list, test_loss_list)

Image rescaling follows a fixed rule: the shortest side of the image must be scaled to 600. We use

  scale = float(shortest_side)/float(min(height, width))

to obtain the scaling coefficient and resize the original image; the labels must then be scaled by the same coefficient. For example, a 1280x720 image gives scale = 600/720 ≈ 0.83, so the image becomes roughly 1067x600 and every label coordinate is multiplied by the same factor.

 

def scale_img(img, gt, shortest_side=600):
    height = img.shape[0]
    width = img.shape[1]
    scale = float(shortest_side)/float(min(height, width))
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
    # snap the short side to exactly 600 in case of rounding drift;
    # note that cv2.resize takes dsize as (width, height)
    if img.shape[0] < img.shape[1] and img.shape[0] != shortest_side:
        img = cv2.resize(img, (img.shape[1], shortest_side))
    elif img.shape[0] > img.shape[1] and img.shape[1] != shortest_side:
        img = cv2.resize(img, (shortest_side, img.shape[0]))
    elif img.shape[0] != shortest_side:
        img = cv2.resize(img, (shortest_side, shortest_side))
    h_scale = float(img.shape[0])/float(height)
    w_scale = float(img.shape[1])/float(width)
    scale_gt = []
    for box in gt:
        scale_box = []
        for i in range(len(box)):
            # even indices are x coordinates
            if i % 2 == 0:
                scale_box.append(int(int(box[i]) * w_scale))
            # odd indices are y coordinates
            else:
                scale_box.append(int(int(box[i]) * h_scale))
        scale_gt.append(scale_box)
    return img, scale_gt

Validation-set evaluation:

def val(net, criterion, batch_num, using_cuda, logger):
    img_root = '../dataset/OCR_dataset/ctpn/test_im'
    gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
    img_list = os.listdir(img_root)
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()
    for im in random.sample(img_list, batch_num):
        name, _ = os.path.splitext(im)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(gt_root, gt_name)
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
        img = cv2.imread(os.path.join(img_root, im))
        img, gt_txt = Dataset.scale_img(img, gt_txt)
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)

        vertical_pred, score, side_refinement = net(tensor_img)
        del tensor_img
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        for box in gt_txt:
            gt_anchor = Dataset.generate_gt_anchor(img, box)
            positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box)
            positive += positive1
            negative += negative1
            vertical_reg += vertical_reg1
            side_refinement_reg += side_refinement_reg1

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            batch_num -= 1
            continue

        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        total_loss += loss
        total_cls_loss += cls_loss
        total_v_reg_loss += v_reg_loss
        total_o_reg_loss += o_reg_loss
    end_time = time.time()
    total_time = end_time - start_time
    print('####################  Start evaluate  ####################')
    print('loss: {0}'.format(total_loss / float(batch_num)))
    logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num)))

    print('classification loss: {0}'.format(total_cls_loss / float(batch_num)))
    logger.info('Evaluate classification loss: {0}'.format(total_cls_loss / float(batch_num)))

    print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
    logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))

    print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
    logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))

    print('{1} iterations for {0} seconds.'.format(total_time, batch_num))
    print('#####################  Evaluate end  #####################')
    print('\n')
    # return the average validation loss so the training loop can record it
    return total_loss / float(batch_num)

  

Training and inference results

Testing: feed in one image and produce the final detection result.

def infer_one(im_name, net):
    im = cv2.imread(im_name)
    im = lib.dataset_handler.scale_img_only(im)  # rescale the image
    img = copy.deepcopy(im)
    img = img.transpose(2, 0, 1)
    img = img[np.newaxis, :, :, :]
    img = torch.Tensor(img)
    v, score, side = net(img, val=True)  # forward pass through the network
    result = []
    # collect the anchors whose text score exceeds the threshold
    for i in range(score.shape[0]):
        for j in range(score.shape[1]):
            for k in range(score.shape[2]):
                if score[i, j, k, 1] > THRESH_HOLD:
                    result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))

    # NMS filtering
    for_nms = []
    for box in result:
        pt = lib.utils.trans_to_2pt(box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
        for_nms.append([pt[0], pt[1], pt[2], pt[3], box[3], box[0], box[1], box[2]])
    for_nms = np.array(for_nms, dtype=np.float32)
    nms_result = lib.nms.cpu_nms(for_nms, NMS_THRESH)

    out_nms = []
    for i in nms_result:
        out_nms.append(for_nms[i, 0:8])

    # decide which anchors belong to the same group
    connect = get_successions(v, out_nms)
    # merge each group of anchors into one text line
    texts = get_text_lines(connect, im.shape)

    for box in texts:
        box = np.array(box)
        print(box)
        lib.draw_image.draw_ploy_4pt(im, box[0:8])

    _, basename = os.path.split(im_name)
    cv2.imwrite('./infer_'+basename, im)
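
A hypothetical usage example (the checkpoint name follows the save pattern in the training loop above; the image path is made up):

net = Net.CTPN()
net.load_state_dict(torch.load('./model/ctpn-msra_ali-9-end.model'))
net.eval()  # inference mode
infer_one('../dataset/OCR_dataset/ctpn/test_im/img_001.jpg', net)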

  

Inference uses get_successions to gather all the anchors in one predicted text line. In other words, we end up with many anchors predicted to contain text, but how do we know which anchors join into a single text line? We need an anchor-merging algorithm, and this is the hardest part of the whole CTPN implementation.

The CTPN paper describes text-line construction like this: it is straightforward, simply connect consecutive text proposals whose text/non-text score is > 0.7. A text line is built as follows.

  • First, define a neighbour Bj for a proposal Bi (Bj−>Bi), where:
  1. Bj is the nearest to Bi in horizontal distance;
  2. that distance is less than 50 pixels;
  3. their vertical overlap is > 0.7.

The theory looks simple, but implementing it yourself is a different story. It really proves the saying: knowledge from paper is shallow, true understanding comes from doing. get_successions takes v, which holds each predicted anchor's h and y information, and anchors, which holds each anchor's four corner coordinates. A minimal sketch of the pairing rule appears below.
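
Here is a minimal sketch of just the pairing rule under the assumptions above. The real get_successions also uses the regressed v values, so treat this as an illustration, not the actual implementation; each box is an axis-aligned anchor rectangle [x1, y1, x2, y2].

def vertical_overlap(a, b):
    # overlap of the two y-ranges, normalised by the shorter box height (an assumption)
    inter = min(a[3], b[3]) - max(a[1], b[1])
    return max(inter, 0) / float(min(a[3] - a[1], b[3] - b[1]))

def is_neighbour(bi, bj, max_gap=50, min_overlap=0.7):
    # Bj -> Bi: Bj lies to the right of Bi within 50 px and overlaps > 0.7 vertically
    gap = bj[0] - bi[2]
    return 0 <= gap < max_gap and vertical_overlap(bi, bj) > min_overlap

def group_anchors(boxes):
    # greedy left-to-right chaining: keep extending a line with the first box
    # to the right that satisfies the neighbour rule
    boxes = sorted(boxes, key=lambda b: b[0])
    used, lines = set(), []
    for i in range(len(boxes)):
        if i in used:
            continue
        line, cur = [boxes[i]], i
        for j in range(i + 1, len(boxes)):
            if j not in used and is_neighbour(boxes[cur], boxes[j]):
                line.append(boxes[j])
                used.add(j)
                cur = j
        used.add(i)
        lines.append(line)
    return lines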

Detection results and summary

First, let's look at the trained model's text-detection results. To make inspection easier, I draw both the anchors and the final merged text boxes:

(figures: detection results with anchors and merged text boxes drawn)

Some takeaways and thoughts from the implementation:

  1. CTPN detects rotated text poorly, and this follows from the nature of the algorithm: fixed-width quadrilaterals are hard to merge into an accurate text box. Some anchors are hard to group at all, and even a well-formed group is hard to recover into a complete, precise text rectangle (a weakness of the inference stage). For horizontally laid-out text, though, I think the approach works well.
  2. The side-refinement in CTPN contributes little in practice: if the detected text goes straight into recognition, the few pixels it recovers are negligible;
  3. CTPN has many intermediate steps: anchor-label generation, the loss computation in the middle, and the text-line construction at inference each introduce some error. This is exactly the weakness the EAST paper points out: the simpler the training pipeline and the fewer the intermediate stages, the better the accuracy can be guaranteed.
  4. Judging from the results, CTPN gives low precision but high recall. Detection based on 16-pixel anchors misfires badly on large non-text graphics (such as road signs), because the anchor width is simply too small; even though an LSTM relates neighbouring anchors, it still feels a bit like missing the forest for the trees. So CTPN does not cope well with text that is too large or too small.
  5. CTPN is a fairly old algorithm (2016). Its ideas were quite innovative at the time, but it has many weaknesses. Newer methods have largely fixed these shortcomings; EAST and PixelNet, for example, are excellent newer algorithms.

For a complete CTPN implementation, see the blogger's Github.

 

Source: https://www.cnblogs.com/skyfsm/p/10054386.html
