AI:拿來主義——預訓練網絡(二)

上一篇文章咱們聊的是使用預訓練網絡中的一種方法,特徵提取,今天咱們討論另一種方法,微調模型,這也是遷移學習的一種方法。html

微調模型python

爲何須要微調模型?咱們猜想和以前的實驗,咱們有這樣的共識,數據量越少,網絡的特徵節點越多,會越容易致使過擬合,這固然不是咱們所但願的,但對於那些預先訓練好的模型,還有可能最終沒法很好的完成所要作的工做,所以咱們還須要對其更改,基於此緣由,咱們須要作的就是拿來一個訓練好的模型,更改其中更加抽象的層,即網絡後面的層,而後再採用新的分類器,這樣能夠比較好的解決上面所提出的過擬合問題了。網絡

進行微調網絡的步驟是:app

  1. 在已經訓練好的網絡(基網絡)基礎上,添加自定義的層;學習

  2. 凍結基網絡並訓練新添加的層;優化

  3. 凍結基網絡的一部分層,另外一部分可訓練;3d

  4. 聯合訓練解凍的這些層和添加的部分。rest

咱們上一篇提到的方法就能夠完成前兩個步驟,接下來咱們看如何解決後兩個步驟。這裏咱們還要更明確一下調整的層數若是過多會帶來什麼問題:隨着可變層數的增多,過擬合的風險會隨之加大。還要明確調整網絡中識別像素和線條的層不如調整識別耳朵的層更有效,由於不管是識別貓仍是桌子識別線條的方法層更通用。code

完成這項任務所須要寫的代碼也是很簡單的,就是設置模型是可訓練的,而後遍歷網絡的每一層,針對每一層分別設置是不是可訓練的,直到 layer_name 層,前面的層都是不可訓練的:htm

conv_base.trainable = True
set_trainable = False
for layer in conv_base.layers:
    if layer.name == 'layer_name':
        set_trainable = True
    if set_trainable:
        layer.trainable = True
    else:
        layer.trainable = False

這裏是關鍵部分代碼,老規矩,最後將給出所有代碼,咱們先來看看結果:

image

須要注意一下這裏的數據,在開始的時候不穩定,迅速爬升,所以縱座標的數據沒有那麼好,但咱們仔細看一下後期的數據,訓練精度和驗證精度都在百分之九十到百分之百,驗證精度一直有一些波動,是網絡的一些噪聲引發的,我不想去強制讓它們那麼漂亮了,一是由於訓練時間會比較長,而是由於我以爲沒有特別大的必要,波動的最高點和最低點都在可接受的範圍內,應該把關注點放在更重要的問題上去。

image

基於本篇文章和上一篇文章,咱們作個小結:

  1. 計算機視覺領域中,卷積神經網絡的表現很是不錯,而且在數據集較小的狀況下,表現讓人是很是優秀的。

  2. 數據加強是很好的避免過擬合的方法,過擬合產生的主要緣由多是數據量太少或者是參數過多。

  3. 特徵提取能夠比較好的將現有的神經網絡應用於小型數據集,還可使用微調的方式進行優化。

咱們看看代碼吧,這裏還有一個建議,若是可能儘可能使用 GPU 去作網絡模型的訓練,CPU 在現階段處理這些問題會有點力不從心,耗時較長,讀者也能夠考慮減小一些數據量加快速度,但要避免過擬合,請讀者心中記住此類問題,在遇到問題的時候是一個方向(固然,筆者是很是慘的,沒有好用的 GPU,所以等待數據畫圖截圖是很是痛苦的一件事):

#!/usr/bin/env python3
​
import os
import time
​
import matplotlib.pyplot as plt
from keras import layers
from keras import models
from keras import optimizers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
​
​
def cat():
    base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small'
    train_dir = os.path.join(base_dir, 'train')
    validation_dir = os.path.join(base_dir, 'validation')
​
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')
​
    test_datagen = ImageDataGenerator(rescale=1. / 255)
​
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='binary')
​
    validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='binary')
​
    # 定義密集鏈接分類器
    conv_base = VGG16(weights='imagenet',
                      include_top=False,
                      input_shape=(150, 150, 3))
    conv_base.trainable = True
    set_trainable = False
    for layer in conv_base.layers:
        if layer.name == 'block5_conv1':
            set_trainable = True
        if set_trainable:
            layer.trainable = True
        else:
            layer.trainable = False
    model = models.Sequential()
    model.add(conv_base)
    model.add(layers.Flatten())
    model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(1, activation='sigmoid'))
​
    conv_base.summary()
​
    # 對模型進行配置
    model.compile(loss='binary_crossentropy',
                  optimizer=optimizers.RMSprop(lr=1e-5),
                  metrics=['acc'])
​
    # 對模型進行訓練
    history = model.fit_generator(
        train_generator,
        steps_per_epoch=100,
        epochs=100,
        validation_data=validation_generator,
        validation_steps=50)
​
    # 畫圖
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(len(acc))
    plt.plot(epochs, acc, 'bo', label='Training acc')
    plt.plot(epochs, val_acc, 'b', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend()
    plt.show()
    plt.figure()
    plt.plot(epochs, loss, 'bo', label='Training loss')
    plt.plot(epochs, val_loss, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()
    plt.show()
​
​
if __name__ == "__main__":
    time_start = time.time()
    cat()
    time_end = time.time()
    print('Time Used: ', time_end - time_start)

本文首發自公衆號:RAIS

原文出處:https://www.cnblogs.com/renyuzhuo/p/12405447.html

相關文章
相關標籤/搜索