Data Splitting and Cross-Validation with sklearn and Keras

When training a deep learning model, the dataset is usually split into a training set and a validation set. Keras provides two ways to evaluate model performance:

  • using an automatically split validation set
  • using a manually split validation set

 

1. Automatic Splitting

In Keras, part of the dataset can be split off as a validation set, and the model's performance is evaluated on it at the end of every epoch.

Specifically, when calling model.fit() to train the model, the validation_split parameter specifies what fraction of the dataset to split off as the validation set.

# MLP with automatic validation set
from keras.models import Sequential
from keras.layers import Dense
import numpy
# fix random seed for reproducibility
numpy.random.seed(7)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10)

validation_split: a float between 0 and 1 that specifies the fraction of the training data to hold out as the validation set. The validation set does not take part in training; at the end of each epoch the model's metrics, such as loss and accuracy, are evaluated on it.

Note that the validation_split partition is taken before shuffling, so if your data is ordered you need to shuffle it manually before using validation_split; otherwise the validation samples may not be representative.
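As a minimal sketch (reusing X, Y, and model from the listing above), the samples can be shuffled with a NumPy permutation before being handed to validation_split; the variable names here are only illustrative:

# shuffle the data manually before relying on validation_split
permutation = numpy.random.permutation(len(X))   # random ordering of sample indices
X_shuffled = X[permutation]
Y_shuffled = Y[permutation]
# the held-out 33% is now a random sample rather than the tail of the original ordering
model.fit(X_shuffled, Y_shuffled, validation_split=0.33, epochs=150, batch_size=10)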

 

2. Manual Splitting

Keras also allows you to specify the validation set manually when training a model.

For example, split the dataset with the train_test_split() function from sklearn, and then pass the held-out validation set to Keras via the validation_data parameter of model.fit().

# MLP with manual validation set
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# split into 67% for train and 33% for test
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=seed)
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test,y_test), epochs=150, batch_size=10)

 

3. K-Fold Cross-Validation

Split the dataset into k folds. In each round, train on (k-1) of the folds and validate on the remaining one; after k rounds this yields k models, and the average of the k performance scores is taken as the overall performance of the algorithm. k is usually set to 5 or 10.

  • Advantage: it gives a fairly robust estimate of how the model will perform on unseen data.
  • Disadvantage: it is computationally expensive, so it may not be practical when the dataset is large, the model is complex, or computing resources are limited, especially when training deep learning models.

sklearn.model_selection provides KFold as well as variants such as RepeatedKFold, LeaveOneOut, LeavePOut, ShuffleSplit, StratifiedKFold, GroupKFold, and TimeSeriesSplit.
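As a minimal sketch of the splitting itself (on a small dummy array, independent of the Pima dataset), KFold simply yields one train/validation index pair per fold:

from sklearn.model_selection import KFold
import numpy

data = numpy.arange(10)                            # 10 dummy samples
kf = KFold(n_splits=5, shuffle=True, random_state=7)
for train_idx, val_idx in kf.split(data):
    # each of the 5 folds uses 8 samples for training and 2 for validation
    print("train:", train_idx, "validation:", val_idx)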

The example below uses StratifiedKFold, which performs stratified sampling: it ensures that the class proportions in each fold match those of the original dataset.

# MLP for Pima Indians Dataset with 10-fold cross validation
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# define 10-fold cross validation test harness
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X, Y):
    # create model
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(8, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Fit the model
    model.fit(X[train], Y[train], epochs=150, batch_size=10, verbose=0)
    # evaluate the model
    scores = model.evaluate(X[test], Y[test], verbose=0)
    print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
    cvscores.append(scores[1] * 100)
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))

  

References:

Evaluate the Performance Of Deep Learning Models in Keras

3.1. Cross-validation: evaluating estimator performance — scikit-learn 0.19.1 documentation

sklearn中的交叉驗證與參數選擇 (Cross-validation and parameter selection in sklearn)
