深度學習的模型是怎麼訓練/優化出來的

以典型的分類問題爲例,來梳理模型的訓練過程。訓練的過程就是問題發現的過程,一次訓練是爲下一步迭代作好指引。算法

1.數據準備

準備:網絡

  • 數據標註前的標籤體系設定要合理
  • 用於標註的數據集須要無偏、全面、儘量均衡
  • 標註過程要審覈

整理數據集框架

  1. 將各個標籤的數據放於不一樣的文件夾中,並統計各個標籤的數目
    如:第一列是路徑,最後一列是圖片數目。

    PS:可能會存在某些標籤樣本不多/多,記下來模型效果很差就怨它。
  2. 樣本均衡,樣本不會絕對均衡,差很少就好了
    如:控制最大類/最小類<\(\delta\)\(\delta=5\),最後一列爲均衡的目標值。
  3. 切分樣本集
    如:90%用於訓練,10%留着測試,比例本身定。訓練集合,對於弱勢類要重採樣,最後的圖片列表要shuffle;測試集合就不用重採樣了。
    訓練中要保證樣本均衡,學習到弱勢類的特徵,測試過程要反應真實的數據集分佈。
    第一列是圖片路徑,後面幾列是標籤(多任務)。

    學習

  4. 按須要的格式生成tfrecord
    按照train.list和validation.list生成須要的格式。生成和解析tfrecord的代碼要根據具體狀況編寫。測試

2.訓練

  • 預處理,根據本身的喜愛,編寫預處理策略。
    preprocessing的方法,變換方案諸如:隨機裁剪、隨機變換框、添加光照飽和度、修改壓縮係數、各類縮放方案、多尺度等。進而,減均值除方差或歸一化到[-1,1],將float類型的Tensor送入網絡。
    這一步的目的是:讓網絡接受的訓練樣本儘量多樣,不要最後出現原圖沒問題,改改分辨率或寬高比就跪了的狀況。
  • 網絡設計,基礎網絡的選擇和Loss的設計。
    基礎網絡的選擇和問題的複雜程度息息相關,用ResNet18能夠解決的不必用101;還有一些SE、GN等模塊加上去有沒有提高也能夠去嘗試。
    Loss的設計,通常問題的抽象就是設計Loss數據公式的過程。好比多任務中的各個任務權重配比,centorLoss可讓特徵分佈更緊湊,SmoothL1Loss更平滑避免梯度爆炸等。
  • 優化算法
    通常來講,只要時間足夠,Adam和SGD+Momentum能夠達到的效果差別不大。用框架提供的理論上最好的優化策略就是了。
  • 訓練過程
    finetune網絡,我習慣分兩步:首先訓練fc層,迭代幾個epoch後保存模型;而後基於獲得的模型,訓練整個網絡,通常迭代40-60個epoch能夠獲得穩定的結果。

    total_loss會一直降低的,過程當中能夠評測下模型在測試集上的表現。真正的loss每每包括兩部分。後面total_loss的降低主要是正則項的功勞了。

3.評估模型

1.混淆矩陣必不可少
混淆矩陣能夠發現哪些類是難區分的。基於混淆矩陣能夠獲得各種的準召,進而能夠獲得哪些類比較差。
如:列爲真值,行爲檢測的值。優化

gt/pl 靴子 單鞋 運動 休閒 棉鞋 雪地靴 帆布 拖鞋 涼鞋 雨鞋
靴子 4524 45 39 79 12 59 5 6 0 20
單鞋 51 4088 15 44 115 9 18 80 43 6
運動 38 6 817 247 0 2 18 8 1 0
休閒 53 47 171 806 17 8 118 15 1 2
棉鞋 12 110 5 15 424 55 2 32 1 1
雪地靴 53 6 5 10 73 628 0 13 2 1
帆布鞋 5 28 16 158 1 1 515 17 3 4
拖鞋 6 139 1 12 33 3 18 2316 60 6
涼鞋 7 69 3 6 0 0 2 55 633 1
雨鞋 26 6 1 3 0 1 2 5 1 499

進而可得:spa

label 召回 精度
靴子 0.9446648569638756 0.947434554973822
單鞋 0.9147460281942269 0.8996478873239436
運動 0.7185576077396658 0.7614165890027959
休閒 0.6510500807754442 0.5840579710144927
... ... ...

PS:運動-休閒容易混淆。設計

2.抽樣看測試數據
從測試數據中每類抽1000張,把它們的模型結果放在不一樣的文件夾下。對於分析問題仍是頗有效的,爲何它會分錯,要拿出來看看!
有些確實是人工標錯了。
3d

3.CAM
經過CAM能夠查看網絡究竟學到了什麼(是否是學錯了)。對於細粒度問題就不用分析CAM了,通常7x7的特徵圖原本就很小了,根本就看不出細節學到了什麼,只能粗略看看部位定位是否準確。

也能夠必定程度上幫助理解爲何網絡會搞錯,好比下面的單鞋被誤判爲了拖鞋。
blog

相關文章
相關標籤/搜索