一個例子瞭解遷移學習

遷移學習

對於傳統機器學習而言,要求訓練樣本與測試樣本知足獨立同分布,並且必需要有足夠多的訓練樣本。而遷移學習能把一個領域(即源領域)的知識,遷移到另一個領域(即目標領域),目標領域每每只有少許有標籤樣本,使得目標領域可以取得更好的學習效果。mysql

image

遷移方式

  • 樣本遷移,在源領域中找出與目標領域類似的樣本,增長該樣本的權重,使其在預測目標與的比重加大。
  • 特徵遷移,源領域與目標領域包含共同的交叉特徵,經過特徵變換將源領域和目標領域的的特徵變換到相同空間,使它們具備相同分佈。
  • 模型遷移,源領域和目標領域共享模型參數,將源領域已訓練好的網絡模型應用到目標領域的新問題上。
  • 關係遷移,源領域和目標領域具備某種類似關係,能夠將源領域的邏輯關係應用到目標領域中。

模型遷移

這裏基於預訓練的卷積神經網絡訓練一組新參數,而後將其用於分類任務,這樣就能共享模型參數,避免了從頭開始訓練模型的參數,大大減小訓練時間。git

數據集

在示例中使用flower17數據集,它是一個包含17種花卉類別的數據集,每一個類別有80張圖像。收集的花都是英國一些常見的花,這些圖像具備大比例、不一樣姿態和光線變化等性質。github

使用水仙花和款冬這兩類花,而且在預訓練的VGG16網絡之上構建分類器。sql

image

image

實現

首先導入全部必需的庫,包括應用程序、預處理、模型檢查點以及相關對象,cv2庫和NumPy庫用於圖像處理和數值的基本操做。數組

from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Model
from keras.layers import Dropout, Flatten, Dense
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.applications.vgg16 import preprocess_input
import cv2
import numpy as np
複製代碼

定義輸入、數據源及與訓練參數相關的全部變量。bash

img_width, img_height = 224, 224
train_data_dir = "data/train"
validation_data_dir = "data/validation"
nb_train_samples = 300
nb_validation_samples = 100
batch_size = 16
epochs = 1
複製代碼

調用VGG16預訓練模型,其中不包括頂部的平整化層。凍結不參與訓練的層,這裏咱們凍結前五層,而後添加自定義層,從而建立最終的模型。網絡

model = applications.VGG16(weights="imagenet", include_top=False, input_shape=(img_width, img_height, 3))
for layer in model.layers[:5]:
    layer.trainable = False
x = model.output
x = Flatten()(x)
x = Dense(1024, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(2, activation="softmax")(x)
model_final = Model(inputs=model.input, output=predictions)
複製代碼

接着開始編譯模型,併爲訓練、測試數據集建立圖像數據加強生成器。併發

model_final.compile(loss="categorical_crossentropy", optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
                    metrics=["accuracy"])
train_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                   width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
test_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                  width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
複製代碼

生成加強後新的數據,根據狀況保存模型。app

train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(img_height, img_width),
                                                    batch_size=batch_size, class_mode="categorical")
validation_generator = test_datagen.flow_from_directory(validation_data_dir, target_size=(img_height, img_width),
                                                        class_mode="categorical")
checkpoint = ModelCheckpoint("vgg16_1.h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False,
                             mode='auto', period=1)
early = EarlyStopping(monitor='val_acc', min_delta=0, patience=10, verbose=1, mode='auto')
複製代碼

開始對模型中新的網絡層進行擬合。機器學習

model_final.fit_generator(train_generator, samples_per_epoch=nb_train_samples, nb_epoch=epochs,
                          validation_data=validation_generator, nb_val_samples=nb_validation_samples,
                          callbacks=[checkpoint, early])
複製代碼

練完成後用水仙花圖像測試這個新模型,輸出的正確值應該爲接近[1.,0.]的數組。

im = cv2.resize(cv2.imread('data/test/gaff2.jpg'), (img_width, img_height))
im = np.expand_dims(im, axis=0).astype(np.float32)
im = preprocess_input(im)
out = model_final.predict(im)
print(out)
print(np.argmax(out))
複製代碼
1/18 [>.............................] - ETA: 16:43 - loss: 0.9380 - acc: 0.3750
 2/18 [==>...........................] - ETA: 13:51 - loss: 0.8720 - acc: 0.4062
 3/18 [====>.........................] - ETA: 12:32 - loss: 0.8382 - acc: 0.4167
 4/18 [=====>........................] - ETA: 10:53 - loss: 0.8103 - acc: 0.4663
 5/18 [=======>......................] - ETA: 10:00 - loss: 0.8208 - acc: 0.4606
 6/18 [=========>....................] - ETA: 9:12 - loss: 0.8083 - acc: 0.4567 
 7/18 [==========>...................] - ETA: 8:24 - loss: 0.7891 - acc: 0.4718
 8/18 [============>.................] - ETA: 7:37 - loss: 0.7994 - acc: 0.4832
 9/18 [==============>...............] - ETA: 6:51 - loss: 0.7841 - acc: 0.4850Epoch 00001: val_acc improved from -inf to 0.40000, saving model to vgg16_1.h5

 9/18 [==============>...............] - ETA: 7:16 - loss: 0.7841 - acc: 0.4850 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00[[0.2213877  0.77861226]]
複製代碼

github

github.com/sea-boat/De…

-------------推薦閱讀------------

個人開源項目彙總(機器&深度學習、NLP、網絡IO、AIML、mysql協議、chatbot)

爲何寫《Tomcat內核設計剖析》

個人2017文章彙總——機器學習篇

個人2017文章彙總——Java及中間件

個人2017文章彙總——深度學習篇

個人2017文章彙總——JDK源碼篇

個人2017文章彙總——天然語言處理篇

個人2017文章彙總——Java併發篇


跟我交流,向我提問:

歡迎關注:

相關文章
相關標籤/搜索