Attention Cluster 模型html
視頻分類問題在視頻標籤、監控、自動駕駛等領域有着普遍的應用,但它同時也是計算機視覺領域面臨的一項重要挑戰之一。網絡
目前的視頻分類問題大可能是基於 CNN 或者 RNN 網絡實現的。衆所周知,CNN 在圖像領域已經發揮了重大做用。它具備很好的特徵提取能力,經過卷積層和池化層,能夠在圖像的不一樣區域提取特徵。RNN 則在獲取時間相關的特徵方面有很強的能力。學習
Attention Cluster 在設計上僅利用了 CNN 模型,而沒有使用 RNN,主要是基於視頻的如下幾個特色考慮:測試
圖 1 視頻幀的分析優化
首先,一段視頻的連續幀經常有必定的類似性。在圖 1(上)能夠看到,除了擊球的動做之外,不一樣幀幾乎是同樣的。所以,對於分類,可能從總體上關注這些類似的特徵就足夠了,而沒有必要去特地觀察它們隨着時間的細節變化。ui
其次,視頻幀中的局部特徵有時就足夠表達出視頻的類別。好比圖 1(中),經過一些局部特徵,如牙刷、水池,就可以分辨出『刷牙』這個動做。所以,對於分類問題,關鍵在於找到幀中的關鍵的局部特徵,而非去找時間上的線索。google
最後,在一些視頻的分類中,幀的時間順序對於分類不必定是重要的。好比圖 1(下),能夠看到,雖然幀順序被打亂,依然可以看出這屬於『撐杆跳』這個類別。設計
基於以上考慮,該模型沒有考慮時間相關的線索,而是使用了 Attention 機制。它有如下幾點好處:orm
固然,一些視頻的局部特徵還有一個特色,那就是它可能會由多個部分組成。好比圖 1(下)的『撐杆跳』,跳、跑和着陸同時對這個分類起到做用。所以,若是隻用單一的 Attention 單元,只能獲取視頻的單一關鍵信息。而若是使用多個 Attention 單元,就可以提取更多的有用信息。因而,Attention Cluster 就應運而生了!在實現過程當中,百度計算機視覺團隊還發現,將不一樣的 Attention 單元進行一次簡單有效的『位移操做』(shifting operation),能夠增長不一樣單元的多樣性,從而提升準確率。視頻
接下來咱們看一下整個 Attention Cluster 的結構。
整個模型能夠分爲三個部分:
(1)。X 的維度爲 L,表明 L 個不一樣的特徵。
用 PaddlePaddle 訓練 Attention Cluster
PaddlePaddle 開源的 Attention Cluster 模型,使用了 2nd-Youtube-8M 數據集。該數據集已經使用了在 ImageNet 訓練集上 InceptionV3 模型對特徵進行了抽取。
若是運行該模型的樣例代碼,要求使用 PaddlePaddle Fluid V1.2.0 或以上的版本。
數據準備:首先請使用 Youtube-8M 官方提供的連接下載訓練集和測試集,或者使用官方腳本下載。數據下載完成後,將會獲得 3844 個訓練數據文件和 3844 個驗證數據文件(TFRecord 格式)。爲了適用於 PaddlePaddle 訓練,須要將下載好的 TFRecord 文件格式轉成了 pickle 格式,轉換腳本請使用 PaddlePaddle 提供的腳本 dataset/youtube8m/tf2pkl.py。
訓練集:http://us.data.yt8m.org/2/fra...
測試集:http://us.data.yt8m.org/2/fra...
官方腳本:https://research.google.com/y...
模型訓練:數據準備完畢後,經過如下方式啓動訓練(方法 1),同時咱們也提供快速啓動腳本 (方法 2)
用戶也可下載 Paddle Github 上已發佈模型經過--resume 指定權重存放路徑進行 finetune 等開發。
數據預處理說明: 模型讀取 Youtube-8M 數據集中已抽取好的 rgb 和 audio 數據,對於每一個視頻的數據,均勻採樣 100 幀,該值由配置文件中的 seg_num 參數指定。
模型設置: 模型主要可配置參數爲 cluster_nums 和 seg_num 參數。其中 cluster_nums 是 attention 單元的數量。當配置 cluster_nums 爲 32, seg_num 爲 100 時,在 Nvidia Tesla P40 上單卡可跑 batch_size=256。
訓練策略:
採用 Adam 優化器,初始 learning_rate=0.001
訓練過程當中不使用權重衰減
參數主要使用 MSRA 初始化
模型評估:可經過如下方式(方法 1)進行模型評估,一樣咱們也提供了快速啓動的腳本(方法 2):
使用 scripts/test/test_attention_cluster.sh 進行評估時,須要修改腳本中的--weights 參數指定須要評估的權重。
若未指定--weights 參數,腳本會下載已發佈模型進行評估
模型推斷:可經過以下命令進行模型推斷:
模型推斷結果存儲於 AttentionCluster_infer_result 中,經過 pickle 格式存儲。
若未指定--weights 參數,腳本會下載已發佈模型 model 進行推斷
模型精度:當模型取以下參數時,在 Youtube-8M 數據集上的指標爲:
參數取值:
評估精度: