摘要:本文提出了一種基於Transformer的端到端的線段檢測模型。採用多尺度的Encoder/Decoder算法,能夠獲得比較準確的線端點座標。做者直接用預測的線段端點和Ground truth的端點的距離做爲目標函數,能夠更好的對線段端點座標進行迴歸。
本文分享自華爲雲社區《論文解讀系列十七:基於Transformer的直線段檢測》,做者:cver。算法
傳統的形態學線段檢測首先要對圖像進行邊緣檢測,而後進行後處理獲得線段的檢測結果。通常的深度學習方法,首先要獲得線段端點和線的熱力圖特徵,而後進行融合處理獲得線的檢測結果。做者提出了一種新的基於Transformer的方法,無需進行邊緣檢測、也無需端點和線的熱力圖特徵,端到端的直接獲得線段的檢測結果,也即線段的端點座標。函數
線段檢測屬於目標檢測的範疇,本文提出的線段檢測模型LETR是在DETR(End-to-End Object Detection with Transformers)的基礎上的擴展,區別就是Decoder在最後預測和迴歸的時候,一個是迴歸的box的中心點、寬、高值,一個是迴歸的線的端點座標。性能
所以,接下來首先介紹一下DETR是如何利用Transformer進行目標檢測的。以後重點介紹一下LETR獨有的一些內容。學習
圖1. DETR模型結構測試
上圖是DETR的模型結構。DETR首先利用一個CNN 的backbone提取圖像的features,編碼以後輸入Transformer模型獲得N個預測的box,而後利用FFN進行分類和座標迴歸,這一部分和傳統的目標檢測相似,以後把N個預測的box和M個真實的box進行二分匹配(N>M,多出的爲空類,即沒有物體,座標值直接設置爲0)。利用匹配結果和匹配的loss更新權重參數,獲得最終的box的檢測結果和類別。這裏有幾個關鍵點:編碼
CNN-backbone輸出的特徵的維度爲C*H*W,首先用1*1的conv進行降維,將channel從C壓縮到d, 獲得d*H*W的特徵圖。以後合併H、W兩個維度,特徵圖的維度變爲d*HW。序列化的特徵圖丟失了原圖的位置信息,所以須要再加上position encoding特徵,獲得最終序列化編碼的特徵。3d
目標檢測的Transformer的Decoder是一次處理所有的Decoder輸入,也即 object queries,和原始的Transformer從左到右一個一個輸出略有不一樣。code
另一點Decoder的輸入是隨機初始化的,而且是能夠訓練更新的。orm
Transformer的Decoder輸出了N個object proposal ,咱們並不知道它和真實的Ground truth的對應關係,所以須要經二分圖匹配,採用的是匈牙利算法,獲得一個使的匹配loss最小的匹配。匹配loss以下:blog
獲得最終匹配後,利用這個loss和分類loss更新參數。
圖2. LETR模型結構
Transformer的結構主要包括Encoder、Decoder 和 FFN。每一個Encoder包含一個self-attention和feed-forward兩個子層。Decoder 除了self-attention和feed-forward還包含cross-attention。注意力機制:注意力機制和原始的Transformer相似,惟一的不一樣就是Decoder的cross-attention,上文已經作了介紹,就再也不贅述。
從上圖中能夠看出LETR包含了兩個Transformer。做者稱此爲a multi-scale Encoder/Decoder strategy,兩個Transformer分別稱之爲Coarse Encoder/Decoder,Fine Encoder/Decoder。也就是先用CNN backbone深層的小尺度的feature map(ResNet的conv5,feature map的尺寸爲原圖尺寸的1/32,通道數爲2048) 訓練一個Transformer,即Coarse Encoder/Decoder,獲得粗粒度的線段的特徵(訓練的時候固定Fine Encoder/Decoder,只更新Coarse Encoder/Decoder的參數)。而後把Coarse Decoder的輸出做爲Fine Decoder的輸入,再訓練一個Transformer,即Fine Encoder/Decoder。Fine Encoder的輸入是CNN backbone淺層的feature map(ResNet的conv4,feature map的尺寸爲原圖尺寸的1/16,通道數爲1024),比深層的feature map具備更大的維度,能夠更好的利用圖像的高分辨率信息。
注:CNN的backbone深層和淺層的feature map特徵都須要先經過1*1的卷積把通道數都降到256維,再做爲Transformer的輸入
和DETR同樣, 利用fine Decoder的N個輸出進行分類和迴歸,獲得N個線段的預測結果。可是咱們並不知道N個預測結果和M個真實的線段的對應關係,而且N還要大於M。這個時候咱們就要進行二分匹配。所謂的二分匹配就是找到一個對應關係,使得匹配loss最小,所以咱們須要給出匹配的loss,和上面DERT的表達式同樣,只不過這一項略有不一樣,一個是GIou一個是線段的端點距離。
模型在Wireframe和YorkUrban數據集上達到了state-of–the-arts。
圖3. 線段檢測方法效果對比
圖四、線段檢測方法在兩種數據集上的性能指標對比(Table 1);線段檢測方法的PR曲線(Figure 6)