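Traditional machine learning requires that training and test samples be independent and identically distributed, and that there be enough training samples. Transfer learning instead carries knowledge from one domain (the source domain) over to another (the target domain). The target domain often has only a small number of labeled samples, and the transferred knowledge lets it achieve better learning results.
Here we train a new set of parameters on top of a pretrained convolutional neural network and then use it for a classification task. Sharing the pretrained parameters avoids training the model from scratch and greatly reduces training time.
The example uses the flower17 dataset, which contains 17 flower categories with 80 images per category. The flowers are all common in the UK, and the images show large variations in scale, pose, and lighting.
We use two of the classes, daffodil and coltsfoot, and build a classifier on top of the pretrained VGG16 network.
First, import all the required libraries, including applications, preprocessing, model checkpoints, and related objects. The cv2 and NumPy libraries handle image processing and basic numerical operations.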
from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Model
from keras.layers import Dropout, Flatten, Dense
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.applications.vgg16 import preprocess_input
import cv2
import numpy as np
Define the input size, the data sources, and all the variables related to the training parameters.
img_width, img_height = 224, 224
train_data_dir = "data/train"
validation_data_dir = "data/validation"
nb_train_samples = 300
nb_validation_samples = 100
batch_size = 16
epochs = 1
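To see how these values translate into training iterations, the sketch below computes how many batches make up one pass over the data. It assumes the usual convention of integer-dividing the sample count by the batch size; the helper `steps_for` is a hypothetical name used only for illustration, not a Keras function.

```python
import math

nb_train_samples = 300
nb_validation_samples = 100
batch_size = 16


def steps_for(num_samples, batch_size, drop_remainder=True):
    """Batches per epoch (hypothetical helper, not part of Keras)."""
    if drop_remainder:
        # Floor division: the last, incomplete batch is not counted.
        return num_samples // batch_size
    # Ceiling division: the last, smaller batch is counted too.
    return math.ceil(num_samples / batch_size)


train_steps = steps_for(nb_train_samples, batch_size)
full_steps = steps_for(nb_train_samples, batch_size, drop_remainder=False)
val_steps = steps_for(nb_validation_samples, batch_size)
print(train_steps, full_steps, val_steps)  # 18 19 6
```

With 300 training images and a batch size of 16, floor division gives 18 steps per epoch, which matches the 18 steps shown in the training log at the end of this post.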
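Load the pretrained VGG16 model without its top fully connected layers (include_top=False). Freeze the layers that should not be trained (here, the first five), then add custom layers to build the final model.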
model = applications.VGG16(weights="imagenet", include_top=False, input_shape=(img_width, img_height, 3))
# Freeze the first five layers so their weights are not updated during training.
for layer in model.layers[:5]:
    layer.trainable = False
x = model.output
x = Flatten()(x)
x = Dense(1024, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(2, activation="softmax")(x)
model_final = Model(inputs=model.input, outputs=predictions)
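To get a feel for how much of the network is frozen and how large the new head is, the following sketch counts parameters by hand. It assumes the standard Keras VGG16 layer ordering (input layer, block1_conv1, block1_conv2, block1_pool, block2_conv1, ...) and a 224x224 input; `conv_params` and `dense_params` are hypothetical helper names, and nothing here calls Keras.

```python
# Convolution layers of VGG16 without the top: (name, in_channels, out_channels).
# All kernels are 3x3. Pooling layers have no parameters and are omitted.
VGG16_CONVS = [
    ("block1_conv1", 3, 64), ("block1_conv2", 64, 64),
    ("block2_conv1", 64, 128), ("block2_conv2", 128, 128),
    ("block3_conv1", 128, 256), ("block3_conv2", 256, 256), ("block3_conv3", 256, 256),
    ("block4_conv1", 256, 512), ("block4_conv2", 512, 512), ("block4_conv3", 512, 512),
    ("block5_conv1", 512, 512), ("block5_conv2", 512, 512), ("block5_conv3", 512, 512),
]


def conv_params(c_in, c_out, k=3):
    """Weights plus biases of a k x k convolution."""
    return k * k * c_in * c_out + c_out


def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


base_total = sum(conv_params(i, o) for _, i, o in VGG16_CONVS)

# model.layers[:5] covers the input layer, block1_conv1, block1_conv2,
# block1_pool and block2_conv1; only the three conv layers hold weights.
frozen_names = {"block1_conv1", "block1_conv2", "block2_conv1"}
frozen = sum(conv_params(i, o) for n, i, o in VGG16_CONVS if n in frozen_names)

# Five 2x2 poolings shrink 224x224 to 7x7, with 512 channels.
flat = 7 * 7 * 512
head = dense_params(flat, 1024) + dense_params(1024, 1024) + dense_params(1024, 2)

print(base_total)  # 14714688
print(frozen)      # 112576
print(head)        # 26742786
```

Under these assumptions, freezing the first five layers fixes only a small fraction of the convolutional base, so most of VGG16 is still fine-tuned, and the newly added head holds more parameters than the base itself.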
Next, compile the model and create image data augmentation generators for the training and test datasets.
model_final.compile(loss="categorical_crossentropy",
                    optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
                    metrics=["accuracy"])
train_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                   width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
test_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                  width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
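The optimizer above is SGD with momentum. As a rough illustration of what one update does, here is a minimal sketch of the classic (non-Nesterov) momentum rule applied to a single scalar weight. The function name `sgd_momentum_step` and the constant gradient of 1.0 are assumptions made purely for the example.

```python
def sgd_momentum_step(w, v, grad, lr=0.0001, momentum=0.9):
    """One classic momentum update for a scalar weight (illustrative only)."""
    # The velocity keeps a decaying sum of past gradients.
    v = momentum * v - lr * grad
    # The weight moves along the velocity, not the raw gradient.
    return w + v, v


w, v = 0.0, 0.0
history = []
for _ in range(3):
    # A constant gradient of 1.0 is a made-up input for demonstration.
    w, v = sgd_momentum_step(w, v, grad=1.0)
    history.append(w)

print(history)  # roughly [-0.0001, -0.00029, -0.000561]
```

Because the velocity accumulates, each step is larger than the previous one while the gradient keeps pointing the same way. A small learning rate such as 0.0001 keeps these steps gentle, which is a common choice when fine-tuning pretrained weights.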
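Generate the augmented data, and set up callbacks that save the model when validation accuracy improves and stop training early when it stalls.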
train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(img_height, img_width),
                                                    batch_size=batch_size, class_mode="categorical")
validation_generator = test_datagen.flow_from_directory(validation_data_dir, target_size=(img_height, img_width),
                                                        class_mode="categorical")
checkpoint = ModelCheckpoint("vgg16_1.h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False,
                             mode='auto', period=1)
early = EarlyStopping(monitor='val_acc', min_delta=0, patience=10, verbose=1, mode='auto')
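The two callbacks can be pictured as simple bookkeeping over the validation accuracy of each epoch. The sketch below imitates that logic in plain Python. It is a simplification rather than the Keras implementation, `simulate_callbacks` is a hypothetical name, and the accuracy values are invented for the demonstration.

```python
def simulate_callbacks(val_accs, patience=10, min_delta=0.0):
    """Mimic save_best_only checkpointing and early stopping (simplified)."""
    checkpoint_best = float("-inf")
    saved_epochs = []          # epochs at which a checkpoint would be written
    stopping_best = float("-inf")
    wait = 0                   # epochs since the last improvement
    stopped_epoch = None

    for epoch, acc in enumerate(val_accs, start=1):
        # Checkpoint: save only when the monitored value improves.
        if acc > checkpoint_best:
            checkpoint_best = acc
            saved_epochs.append(epoch)

        # Early stopping: reset the counter on improvement, otherwise count up.
        if acc - min_delta > stopping_best:
            stopping_best = acc
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                stopped_epoch = epoch
                break

    return saved_epochs, stopped_epoch


# Invented validation accuracies, one per epoch.
fake_val_accs = [0.40, 0.55, 0.50, 0.60, 0.58, 0.59, 0.60]
saved, stopped = simulate_callbacks(fake_val_accs, patience=3)
print(saved, stopped)  # [1, 2, 4] 7
```

In this post epochs is set to 1, so a patience of 10 can never trigger. Early stopping only becomes relevant once the number of epochs is raised.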
Start fitting the new layers of the model.
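# Keras 2 argument names: steps are counted in batches, not samples.
model_final.fit_generator(train_generator, steps_per_epoch=nb_train_samples // batch_size, epochs=epochs,
                          validation_data=validation_generator,
                          validation_steps=nb_validation_samples // validation_generator.batch_size,
                          callbacks=[checkpoint, early])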
After training completes, test the new model on a daffodil image. The correct output should be an array close to [1., 0.].
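im = cv2.resize(cv2.imread('data/test/gaff2.jpg'), (img_width, img_height))
# OpenCV loads BGR; convert to RGB and rescale to [0, 1] to match the training generator.
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = np.expand_dims(im, axis=0).astype(np.float32) / 255.0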
out = model_final.predict(im)
print(out)
print(np.argmax(out))
1/18 [>.............................] - ETA: 16:43 - loss: 0.9380 - acc: 0.3750
2/18 [==>...........................] - ETA: 13:51 - loss: 0.8720 - acc: 0.4062
3/18 [====>.........................] - ETA: 12:32 - loss: 0.8382 - acc: 0.4167
4/18 [=====>........................] - ETA: 10:53 - loss: 0.8103 - acc: 0.4663
5/18 [=======>......................] - ETA: 10:00 - loss: 0.8208 - acc: 0.4606
6/18 [=========>....................] - ETA: 9:12 - loss: 0.8083 - acc: 0.4567
7/18 [==========>...................] - ETA: 8:24 - loss: 0.7891 - acc: 0.4718
8/18 [============>.................] - ETA: 7:37 - loss: 0.7994 - acc: 0.4832
9/18 [==============>...............] - ETA: 6:51 - loss: 0.7841 - acc: 0.4850
Epoch 00001: val_acc improved from -inf to 0.40000, saving model to vgg16_1.h5
9/18 [==============>...............] - ETA: 7:16 - loss: 0.7841 - acc: 0.4850 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00
[[0.2213877 0.77861226]]
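To read the printed array, the sketch below takes the values from the output above, picks the most likely class, and computes the cross-entropy against the target [1., 0.] that this post treats as correct for a daffodil. It uses only the standard library. Which class index corresponds to which flower depends on the directory names, so treating index 0 as daffodil is an assumption carried over from the text.

```python
import math

# Softmax output copied from the run above.
out = [[0.2213877, 0.77861226]]
probs = out[0]

# Index of the largest probability, the same result np.argmax would give.
predicted = max(range(len(probs)), key=probs.__getitem__)

# Categorical cross-entropy when the true label is [1., 0.]:
# only the probability assigned to class 0 contributes.
loss = -math.log(probs[0])

print(predicted)       # 1
print(round(loss, 2))  # 1.51
```

The model assigns about 78% to class 1, so after a single epoch it does not yet produce the expected [1., 0.] for this image. Training for more epochs is the natural next step.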