EAST結構分析+pytorch源碼實現

EAST結構分析+pytorch源碼實現

@

一. U-Net的前車可鑑

在介紹EAST網絡以前咱們先介紹一下前面的幾個網絡,看看這個EAST網絡怎麼來的?爲何來的?python

固然這裏的介紹僅僅是引出EAST而不是詳細的講解其餘網絡,有須要的讀者能夠去看看這三個優秀網絡。git

1.1 FCN網絡結構

​ FCN網絡,在以前FCN從原理到代碼的理解已經詳細分析了,有須要的能夠去看看,順便跑一跑代碼。github

圖1-1

  • 網絡的由來

不論是識別(傳統機器學習、CNN)仍是檢測(SSD、YOLO等),都只是基於大塊的特徵進行的,檢測以後都是以長方形去表示檢測結果,因爲這是其算法內部迴歸的結果致使,並且feature map通過卷積一直減少,若是強行進行256X256512X512的插值,那麼結果能夠想象,邊界很是很差。算法

那麼如何實現圖1-1所示的結果呢?把每一個像素都進行分割?服務器

  • 網絡的成果

FCN給出的方法是使用反捲積進行上採樣操做,使得通過CNN以後減少的圖可以恢復大小。網絡

固然做者還提出一個好方法,不一樣的feature map進行組合,使得感覺野進行擴充。架構

註釋:筆者認爲使用反捲積有兩個做用,其一是使得計算LOSS比較方便,標籤和結果能夠直接進行計算。其二是能夠進行參數的學習,更爲智能化。app

1.2 U-NET網絡

U-net網絡以前沒怎麼看過,如今也僅僅是大概看了論文和相關資料,內部實現不是很瞭解。dom

圖1-2

  • 網絡的由來

FCN徹底能夠作到基於像素點的分割,爲何還要這個U-net網絡啊?

FCN網絡檢測的效果還能夠,可是其邊緣的處理就特別的差。雖說多個層進行合併,可是合併的內容雜亂無章,致使最後的信息沒有徹底獲得。

總的來講FCN分割的效果不夠,精度也不夠。

  • 網絡的成果

U-net提出了對稱的網絡結構,使得網絡參數的學習效果更好(爲何對稱網絡學習更好,這個理解不透,若是是結果再放大一倍使得不對稱不也同樣嗎?感受仍是網絡結構設計的好,而不是對稱)

不一樣feature map合併的方式更加優化,使得在邊緣分割(細節)上更加優秀。

網絡架構清晰明瞭,分割效果也很好,如今醫學圖像分割領域還能看見身影。

1.3 CTPN網絡

剛開始準備使用CTPN進行文本的檢測,因此看了一些相關資料,致命缺點是不能檢測帶角度文字和網絡比較複雜。

圖1-3

  • 網絡的由來

文本檢測和其餘檢測卻別很大,好比用SSD檢測文本就比較困難(邊緣檢測很差),如何針對文本進行檢測?

  • 網絡的成果

CTPN網絡有不少創造的想法-->>

目標分割小塊,而後一一進行檢測,針對文本分割成height>width的方式,使得檢測的邊緣更爲精確。

使用BiLSTM對小塊進行鏈接,針對文本之間的相關性。

CTPN想法具備創造性,可是太過複雜。

  1. 首先樣本的製做麻煩
  2. 每一個小框進行迴歸,框的大小本身定義
  3. 邊緣特地進行偏移處理
  4. 使用RNN進行鏈接

檢測水平效果仍是不錯的,可是對於傾斜的文本就不行了。

爲何不加一個angle進行迴歸?

本就很複雜的網絡,若是再給每一個小box加一個angle參數會更復雜,固然是能夠實施的。

二. EAST結構分析

2.1 結構簡述

EAST原名爲: An Efficient and Accurate Scene Text Detector

結構:檢測層(PVANet) + 合併層 + 輸出層

圖2-1

下圖圖2-2是檢測效果,任意角度的文本均可以檢測到。

注意:EAST只是一個檢測網絡,如需識別害的使用CRNN等識別網絡進行後續操做。

圖2-2

具體網絡在2-2節進行詳細介紹=====>>>

2.2 結構詳解

  • 總體結構

EAST根據他的名字,咱們知道就是高效的文本檢測方法。

上面咱們介紹了CTPN網絡,其標籤製做很麻煩,結構很複雜(分割成小方框而後迴歸還要RNN進行合併)

看下圖圖2-3,只要進行相似FCN的結構,計算LOSS就能夠進行訓練。測試的時候走過網絡,運行NMS就能夠得出結果。太簡單了是否是?

圖2-3

  • 特徵提取層

特徵的提取能夠任意網絡(VGG、RES-NET等檢測網絡),本文以VGG爲基礎進行特徵提取。這個比較簡單,看一下源碼就能夠清楚,見第四章源碼分析

  • 特徵合併層

在合併層中,首先在定義特徵提取層的時候把須要的輸出給保留下來,經過forward函數把結構進行輸出。以後再合併層調用便可

以下代碼定義,其中合併的過程再下面介紹

#提取VGG模型訓練參數
class extractor(nn.Module):
    def __init__(self, pretrained):
        super(extractor, self).__init__()
        vgg16_bn = VGG(make_layers(cfg, batch_norm=True))
        if pretrained:
            vgg16_bn.load_state_dict(torch.load('./pths/vgg16_bn-6c64b313.pth'))
        self.features = vgg16_bn.features
    
    def forward(self, x):
        out = []
        for m in self.features:
            x = m(x)
            #提取maxpool層爲後續合併
            if isinstance(m, nn.MaxPool2d):
                out.append(x)
        return out[1:]
  • 特徵合併層

合併特徵提取層的輸出,具體的定義以下代碼所示,代碼部分已經註釋.

其中x中存放的是特徵提取層的四個輸出

def forward(self, x):

        y = F.interpolate(x[3], scale_factor=2, mode='bilinear', align_corners=True)
        y = torch.cat((y, x[2]), 1)
        y = self.relu1(self.bn1(self.conv1(y)))     
        y = self.relu2(self.bn2(self.conv2(y)))
        
        y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
        y = torch.cat((y, x[1]), 1)
        y = self.relu3(self.bn3(self.conv3(y)))     
        y = self.relu4(self.bn4(self.conv4(y)))
        
        y = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
        y = torch.cat((y, x[0]), 1)
        y = self.relu5(self.bn5(self.conv5(y)))     
        y = self.relu6(self.bn6(self.conv6(y)))
        
        y = self.relu7(self.bn7(self.conv7(y)))
        return y
  • 輸出層

輸出層包括三個部分,這裏以RBOX爲例子,發現網上都沒有QUAN爲例子的?

首先QUAN的計算是爲了防止透視變換的存在,正常狀況下不存在這些問題,正常的斜框能夠解決。

由於QUAN的計算沒啥好處,前者已經徹底能夠解決正常的檢測問題,後者迴歸四個點相對來講較爲困難(若是文本變化較大就更困難,因此SSD和YOLO沒法檢測文本的緣由)。

若是想獲得特殊的文本,基本考慮別的網絡了(好比彎曲文字的檢測)

def forward(self, x):
        score = self.sigmoid1(self.conv1(x))
        loc   = self.sigmoid2(self.conv2(x)) * self.scope
        angle = (self.sigmoid3(self.conv3(x)) - 0.5) * math.pi
        geo   = torch.cat((loc, angle), 1) 
        return score, geo

三. EAST細節分析

3.1 標籤製做

注意:這裏是重點和難點!!!

文章說要把標籤向裏縮進0.3

筆者認爲這樣作的目的是提取到更爲準確的信息,不管是人工標註的好與很差,咱們按照0.3縮小以後提取的特徵都是所有的文本信息。

可是這樣作也會丟失一些邊緣信息,若是按照上述的推斷,那麼SSD或YOLO均可以這樣設計標籤了。

做者確定是通過測試的,有好處有壞處吧!

圖3-1

標籤格式爲:5個geometry(4個location+1個angle) + 1個score ==6 × N × M

其中(b)爲score圖 ,(d)爲四個location圖, (e)爲angle圖

上圖可能看的不清楚,下面以手繪圖進行說明:

圖3-2

上圖可能看不清楚,下面再用文字大概說一下吧!

  1. 先進行0.3縮放,這個時候的圖就是score
  2. 沒縮放的圖像爲基準,畫最小外接矩形,這個外接矩形的角度就是angle。這個大小是縮放的的圖大小。感受直接以score圖作角度也同樣的。
  3. score圖的每一個像素點到最小外接矩形的距離爲四個location圖。

3.2 LOSS計算

LOSS計算就比較簡單的,直接回歸location、angle、score便可。

def forward(self, gt_score, pred_score, gt_geo, pred_geo, ignored_map):
        #圖像中不存在目標直接返回0
        if torch.sum(gt_score) < 1:
            return torch.sum(pred_score + pred_geo) * 0
        #score loss 採用Dice方式計算,沒有采用log熵計算,爲了防止樣本不均衡問題
        classify_loss = get_dice_loss(gt_score, pred_score*(1-ignored_map))
        #geo loss採用Iou方式計算(計算每一個像素點的loss)
        iou_loss_map, angle_loss_map = get_geo_loss(gt_geo, pred_geo)
        #計算一整張圖的loss,angle_loss_map*gt_score去除不是目標點的像素(感受這句話應該放在前面減小計算量,放在這裏沒有減小計算loss的計算量)
        angle_loss = torch.sum(angle_loss_map*gt_score) / torch.sum(gt_score)
        iou_loss = torch.sum(iou_loss_map*gt_score) / torch.sum(gt_score)
        geo_loss = self.weight_angle * angle_loss + iou_loss#這裏的權重設置爲1
        print('classify loss is {:.8f}, angle loss is {:.8f}, iou loss is {:.8f}'.format(classify_loss, angle_loss, iou_loss))
        return geo_loss + classify_loss

注意:這裏score的LOSS使用Dice方式,由於普通的交叉熵沒法解決樣本不均衡問題!!!

圖3-3

3.3 NMS計算

NMS使用的是locality NMS,也就是爲了針對EAST而提出來的。

首先咱們先來看看這個LANMS的原理和過程:

import numpy as np
from shapely.geometry import Polygon

def intersection(g, p):
    #取g,p中的幾何體信息組成多邊形
    g = Polygon(g[:8].reshape((4, 2)))
    p = Polygon(p[:8].reshape((4, 2)))

    # 判斷g,p是否爲有效的多邊形幾何體
    if not g.is_valid or not p.is_valid:
        return 0

    # 取兩個幾何體的交集和並集
    inter = Polygon(g).intersection(Polygon(p)).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    else:
        return inter/union

def weighted_merge(g, p):
    # 取g,p兩個幾何體的加權(權重根據對應的檢測得分計算獲得)
    g[:8] = (g[8] * g[:8] + p[8] * p[:8])/(g[8] + p[8])
    
    #合併後的幾何體的得分爲兩個幾何體得分的總和
    g[8] = (g[8] + p[8])
    return g

def standard_nms(S, thres):
    #標準NMS
    order = np.argsort(S[:, 8])[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
        inds = np.where(ovr <= thres)[0]
        order = order[inds+1]
        
    return S[keep]

def nms_locality(polys, thres=0.3):
    '''
    locality aware nms of EAST
    :param polys: a N*9 numpy array. first 8 coordinates, then prob
    :return: boxes after nms
    '''
    S = []    #合併後的幾何體集合
    p = None   #合併後的幾何體
    for g in polys:
        if p is not None and intersection(g, p) > thres:    #若兩個幾何體的相交面積大於指定的閾值,則進行合併
            p = weighted_merge(g, p)
        else:    #反之,則保留當前的幾何體
            if p is not None:
                S.append(p)
            p = g
    if p is not None:
        S.append(p)
    if len(S) == 0:
        return np.array([])
    return standard_nms(np.array(S), thres)

if __name__ == '__main__':
    # 343,350,448,135,474,143,369,359
    print(Polygon(np.array([[343, 350], [448, 135],
                            [474, 143], [369, 359]])).area)

別看那麼多代碼,講的很玄乎,其實很簡單:

  1. 遍歷每一個預測的框,而後按照交集大於某個值K就合併相鄰的兩個框。
  2. 合併完以後就按照正常NMS消除不合理的框就好了。

注意: 爲何相鄰的框合併?

  1. 由於每一個像素預測一個框(不明白就本身去看上面LOSS計算),一個目標的幾百上千個框基本都是重合的(若是預測的準的話),因此說相鄰的框直接進行合併就好了。
  2. 其實豎直和橫向都合併一次最好,反正原理同樣的。

四. Pytorch源碼分析

源碼就不進行分析了,上面已經說得很是明白了,基本每一個難點和重點都說到了。

有一點小bug,現進行說明:

  1. 訓練的時候出現孔樣本跑死
SampleNum = 3400 #定義樣本數量,應對空標籤的文本bug,臨時處理方案
class custom_dataset(data.Dataset):
    def __init__(self, img_path, gt_path, scale=0.25, length=512):
        super(custom_dataset, self).__init__()
        self.img_files = [os.path.join(img_path, img_file) for img_file in sorted(os.listdir(img_path))]
        self.gt_files  = [os.path.join(gt_path, gt_file) for gt_file in sorted(os.listdir(gt_path))]
        self.scale = scale
        self.length = length

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        with open(self.gt_files[index], 'r') as f:
            lines = f.readlines()
        while(len(lines)<1):
            index = int(SampleNum*np.random.rand())
            with open(self.gt_files[index], 'r') as f:
                lines = f.readlines()
        vertices, labels = extract_vertices(lines)
        
        img = Image.open(self.img_files[index])
        img, vertices = adjust_height(img, vertices) 
        img, vertices = rotate_img(img, vertices)
        img, vertices = crop_img(img, vertices, labels, self.length,index)
        transform = transforms.Compose([transforms.ColorJitter(0.5, 0.5, 0.5, 0.25), \
                                        transforms.ToTensor(), \
                                        transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])
        
        score_map, geo_map, ignored_map = get_score_geo(img, vertices, labels, self.scale, self.length)
        return transform(img), score_map, geo_map, ignored_map
  1. 測試的時候讀取PIL會出現RGBA狀況
img_path    = './013.jpg'
    model_path  = './pths/model_epoch_225.pth'
    res_img     = './res.bmp'
    img = Image.open(img_path)
    img = np.array(img)[:,:,:3]
    img = Image.fromarray(img)
  • 後續工做

    1. 這個代碼感受有點問題,訓練速度很慢,猜想是數據處理部分。
    2. 原版EAST每一個點都進行迴歸,太浪費時間了,後續參考AdvanceEAST進行修改,同時加我的理解優化
    3. 網絡太大了,只適合服務器或者PC上跑,當前網絡已經修改到15MB,感受仍是有點大。
    4. 後續還要加識別部分,困難重重。。。。。。

這裏的代碼都是github上的,筆者只是搬運工而已!!!

原做者下載地址

五. 第一次更新內容

  • 2019-6-30更新

以前提到這個工程的代碼有幾個缺陷,在這裏進行詳細的解決

  1. 訓練速度很慢

這是因爲源代碼的數據處理部分編寫有問題致使,隨機crop中對於邊界問題處理
如下給出解決方案,具體修改請讀者對比源代碼便可:

def crop_img(img, vertices, labels, length, index):
    '''crop img patches to obtain batch and augment
    Input:
        img         : PIL Image
        vertices    : vertices of text regions <numpy.ndarray, (n,8)>
        labels      : 1->valid, 0->ignore, <numpy.ndarray, (n,)>
        length      : length of cropped image region
    Output:
        region      : cropped image region
        new_vertices: new vertices in cropped region
    '''
    try:
        h, w = img.height, img.width
        # confirm the shortest side of image >= length
        if h >= w and w < length:
            img = img.resize((length, int(h * length / w)), Image.BILINEAR)
        elif h < w and h < length:
            img = img.resize((int(w * length / h), length), Image.BILINEAR)
        ratio_w = img.width / w
        ratio_h = img.height / h
        assert(ratio_w >= 1 and ratio_h >= 1)

        new_vertices = np.zeros(vertices.shape)
        if vertices.size > 0:
            new_vertices[:,[0,2,4,6]] = vertices[:,[0,2,4,6]] * ratio_w
            new_vertices[:,[1,3,5,7]] = vertices[:,[1,3,5,7]] * ratio_h
        #find four limitate point by vertices
        vertice_x = [np.min(new_vertices[:, [0, 2, 4, 6]]), np.max(new_vertices[:, [0, 2, 4, 6]])]
        vertice_y = [np.min(new_vertices[:, [1, 3, 5, 7]]), np.max(new_vertices[:, [1, 3, 5, 7]])]
        # find random position
        remain_w = [0,img.width - length]
        remain_h = [0,img.height - length]
        if vertice_x[1]>length:
            remain_w[0] = vertice_x[1] - length
        if vertice_x[0]<remain_w[1]:
            remain_w[1] = vertice_x[0]
        if vertice_y[1]>length:
            remain_h[0] = vertice_y[1] - length
        if vertice_y[0]<remain_h[1]:
            remain_h[1] = vertice_y[0]

        start_w = int(np.random.rand() * (remain_w[1]-remain_w[0]))+remain_w[0]
        start_h = int(np.random.rand() * (remain_h[1]-remain_h[0]))+remain_h[0]
        box = (start_w, start_h, start_w + length, start_h + length)
        region = img.crop(box)
        if new_vertices.size == 0:
            return region, new_vertices

        new_vertices[:,[0,2,4,6]] -= start_w
        new_vertices[:,[1,3,5,7]] -= start_h
    except IndexError:
        print("\n crop_img function index error!!!\n,imge is %d"%(index))
    else:
        pass
    return region, new_vertices
  1. LOSS剛開始收斂降低,到後面就呈現抖動(像過擬合現象),檢測效果角度不好

因爲Angle Loss角度計算錯誤致使,請讀者閱讀做者原文進行對比

def find_min_rect_angle(vertices):
    '''find the best angle to rotate poly and obtain min rectangle
    Input:
        vertices: vertices of text region <numpy.ndarray, (8,)>
    Output:
        the best angle <radian measure>
    '''
    angle_interval = 1
    angle_list = list(range(-90, 90, angle_interval))
    area_list = []
    for theta in angle_list: 
        rotated = rotate_vertices(vertices, theta / 180 * math.pi)
        x1, y1, x2, y2, x3, y3, x4, y4 = rotated
        temp_area = (max(x1, x2, x3, x4) - min(x1, x2, x3, x4)) * \
                    (max(y1, y2, y3, y4) - min(y1, y2, y3, y4))
        area_list.append(temp_area)
    
    sorted_area_index = sorted(list(range(len(area_list))), key=lambda k : area_list[k])
    min_error = float('inf')
    best_index = -1
    rank_num = 10
    # find the best angle with correct orientation
    for index in sorted_area_index[:rank_num]:
        rotated = rotate_vertices(vertices, angle_list[index] / 180 * math.pi)
        temp_error = cal_error(rotated)
        if temp_error < min_error:
            min_error = temp_error
            best_index = index

    if angle_list[best_index]>0:
        return (angle_list[best_index] - 90) / 180 * math.pi

    return (angle_list[best_index]+90) / 180 * math.pi
  1. 修改網絡從50MB到15MB,對於小樣本訓練效果很好

這裏比較簡單,直接修改VGG和U-NET網絡feature map便可

cfg = [32, 32, 'M', 64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M']
#合併不一樣的feature map
class merge(nn.Module):
    def __init__(self):
        super(merge, self).__init__()

        self.conv1 = nn.Conv2d(512, 128, 1)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(256, 64, 1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()

        self.conv5 = nn.Conv2d(128, 32, 1)
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.ReLU()
        self.conv6 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn6 = nn.BatchNorm2d(32)
        self.relu6 = nn.ReLU()

        self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
        self.bn7 = nn.BatchNorm2d(32)
        self.relu7 = nn.ReLU()
        #初始化網絡參數
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
  1. 小的字體檢測很好,大的字體檢測不到(部分檢測不到)狀況

這裏是模仿AdvanceEAST的方法進行訓練,先在小圖像進行訓練,而後遷移到大圖像便可。

意思就是先將圖像縮小到254254訓練獲得modeul_254.pth
而後在將圖像resize到384
384,網絡參數使用modeul_254.pth,訓練獲得modeul_384.pth
。。。一次進行512或者更大的圖像便可

  1. 針對圖像訓練和檢測的慢(相對於其餘檢測網絡)

這裏須要根據原理來講了,是由於所有的像素都須要預測和計算loss,能夠看看AdvanceEAST的網絡進行處理便可

  1. 修改網絡說明

    訓練樣本3000
    測試樣本100
    檢測精度85%,IOU準確度80%
    5個epoch收斂結束(這些都是這裏測試的)
    兩塊1080TI,訓練時間10分鐘左右

這裏是我完整的工程


五. 參考文獻

相關文章
相關標籤/搜索