使用數據加強技術提高模型泛化能力

在《提升模型性能,你能夠嘗試這幾招...》一文中,咱們給出了幾種提升模型性能的方法,但這篇文章是在訓練數據集不變的前提下提出的優化方案。其實對於深度學習而言,數據量的多寡一般對模型性能的影響更大,因此擴充數據規模通常狀況是一個很是有效的方法。python

對於Google、Facebook來講,收集幾百萬張圖片,訓練超大規模的深度學習模型,天然不在話下。可是對於我的或者小型企業而言,收集現實世界的數據,特別是帶標籤的數據,將是一件很是費時費力的事。本文探討一種技術,在現有數據集的基礎上,進行數據加強(data augmentation),增長參與模型訓練的數據量,從而提高模型的性能。git

什麼是數據加強

所謂數據加強,就是採用在原有數據上隨機增長抖動和擾動,從而生成新的訓練樣本,新樣本的標籤和原始數據相同。這個也很好理解,對於一張標籤爲「狗」的圖片,作必定的模糊、裁剪、變形等處理,並不會改變這張圖片的類別。數據加強也不只侷限於圖片分類應用,好比有以下圖所示的數據,數據知足正態分佈:github

咱們在數據集的基礎上,增長一些擾動處理,數據分佈以下:web

數據就在原來的基礎上增長了幾倍,但總體上仍然知足正態分佈。有人可能會說,這樣的出來的模型不是沒有原來精確了嗎?考慮到現實世界的複雜性,咱們採集到的數據很難徹底知足正態分佈,因此這樣增長數據擾動,不只不會下降模型的精確度,然而加強了泛化能力。算法

對於圖片數據而言,可以作的數據加強的方法有不少,一般的方法是:bash

  • 平移
  • 旋轉
  • 縮放
  • 裁剪
  • 切變(shearing)
  • 水平/垂直翻轉
  • ...

上面幾種方法,可能切變(shearing)比較難以理解,看一張圖就明白了:微信

咱們要親自編寫這些數據加強算法嗎?一般不須要,好比keras就提供了批量處理圖片變形的方法。post

keras中的數據加強方法

keras中提供了ImageDataGenerator類,其構造方法以下:性能

ImageDataGenerator(featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization = False,
    samplewise_std_normalization = False,
    zca_whitening = False,
    rotation_range = 0.,
    width_shift_range = 0.,
    height_shift_range = 0.,
    shear_range = 0.,
    zoom_range = 0.,
    channel_shift_range = 0.,
    fill_mode = 'nearest',
    cval = 0.0,
    horizontal_flip = False,
    vertical_flip = False,
    rescale = None,
    preprocessing_function = None,
    data_format = K.image_data_format(),
)
複製代碼

參數不少,經常使用的參數有:學習

  • rotation_range: 控制隨機的度數範圍旋轉。
  • width_shift_range和height_shift_range: 分別用於水平和垂直移位。
  • zoom_range: 根據[1 - zoom_range,1 + zoom_range]範圍均勻將圖像「放大」或「縮小」。
  • horizontal_flip:控制是否水平翻轉。

完整的參數說明請參考keras文檔。

下面一段代碼將1張給定的圖片擴充爲10張,固然你還能夠擴充更多:

image = load_img(args["image"])
image = img_to_array(image)
image = np.expand_dims(image, axis=0)

aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1, height_shift_range=0.1,
                         shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode="nearest")

aug.fit(image)

imageGen = aug.flow(image, batch_size=1, save_to_dir=args["output"], save_prefix=args["prefix"],
                    save_format="jpeg")

total = 0
for image in imageGen:
  # increment out counter
  total += 1

  if total == 10:
    break
複製代碼

須要指出的是,上述代碼的最後一個迭代是必須的,否在不會在output目錄下生成圖片,另外output目錄必須存在,不然會出現一下錯誤:

Traceback (most recent call last):
  File "augmentation_demo.py", line 35, in <module>
    for image in imageGen:
  File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1526, in __next__
    return self.next(*args, **kwargs)
  File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1704, in next
    return self._get_batches_of_transformed_samples(index_array)
  File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1681, in _get_batches_of_transformed_samples
    img.save(os.path.join(self.save_to_dir, fname))
  File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/PIL/Image.py", line 1947, in save
    fp = builtins.open(filename, "w+b")
FileNotFoundError: [Errno 2] No such file or directory: 'output/image_0_1091.jpeg'
複製代碼

以下一張狗狗的圖片:

通過數據加強技術處理以後,能夠獲得以下10張形態稍微不一樣的狗狗的圖片,這至關於在原有數據集上增長了10倍的數據,其實咱們還能夠擴充得最多:

數據加強以後的比較

咱們以MiniVGGNet模型爲例,說明在其在17flowers數據集上進行訓練的效果。17flowers是一個很是小的數據集,包含17中品類的花卉圖案,每一個品類包含80張圖片,這對於深度學習而言,數據量實在是過小了。通常而言,要讓深度學習模型有必定的精確度,每一個類別的圖片至少須要1000~5000張。這樣的數據集能夠很好的說明數據加強技術的必要性。

從網站上下載的17flowers數據,全部的圖片都放在一個目錄下,而咱們一般訓練時的目錄結構爲:

{類別名}/{圖片文件}
複製代碼

爲此我寫了一個organize_flowers17.py腳本。

在沒有使用數據加強的狀況下,在訓練數據集和驗證數據集上精度、損失隨着訓練輪次的變化曲線圖:

能夠看到,大約通過十幾輪的訓練,在訓練數據集上的準確率很快就達到了接近100%,然而在驗證數據集上的準確率卻沒法再上升,只能達到60%左右。這個圖能夠明顯的看出模型出現了很是嚴重的過擬合。

若是採用數據加強技術呢?曲線圖以下:

從圖中能夠看到,雖然在訓練數據集上的準確率有所降低,但在驗證數據集上的準確率有比較明顯的提高,說明模型的泛化能力有所加強。

也許在咱們看來,準確率從60%多增長到70%,只有10%的提高,並非什麼了不起的成績。但要考慮到咱們採用的數據集樣本數量實在是太少,可以達到這樣的提高已是很是可貴,在實際項目中,有時爲了提高1%的準確率,都會花費很多的功夫。

總結

數據加強技術在必定程度上可以提升模型的泛化能力,減小過擬合,但在實際中,咱們若是可以收集到更多真實的數據,仍是要儘可能使用真實數據。另外,數據加強只需應用於訓練數據集,驗證集上則不須要,畢竟咱們但願在驗證集上測試真實數據的準確。

以上實例均有完整的代碼,點擊閱讀原文,跳轉到我在github上建的示例代碼。

另外,我在閱讀《Deep Learning for Computer Vision with Python》這本書,在微信公衆號後臺回覆「計算機視覺」關鍵字,能夠免費下載這本書的電子版。

參考閱讀

提升模型性能,你能夠嘗試這幾招...

計算機視覺與深度學習,看這本書就夠了

相關文章
相關標籤/搜索