estimator = KerasClassifier

如何在scikit-learn模型中使用Keras

經過用 KerasClassifier 或 KerasRegressor 類包裝Keras模型,可將其用於scikit-learn。python

要使用這些包裝,必須定義一個函數,以便按順序模式建立並返回Keras,而後當構建 KerasClassifier 類時,把該函數傳遞給 build_fn 參數。函數

例如:測試

def create_model(): ... return model model = KerasClassifier(build_fn=create_model)

KerasClassifier類 的構建器爲能夠採起默認參數,並將其被傳遞給 model.fit() 的調用函數,好比 epochs數目和批尺寸(batch size)。ui

例如:spa

def create_model(): ... return model model = KerasClassifier(build_fn=create_model, nb_epoch=10)

KerasClassifier類的構造也能夠使用新的參數,使之可以傳遞給自定義的create_model()函數。這些新的參數,也必須由使用默認參數的 create_model() 函數的簽名定義。code

例如:ci

def create_model(dropout_rate=0.0): ... return model model = KerasClassifier(build_fn=create_model, dropout_rate=0.2)

 

pred = estimator.predict(X_test)#返回給定測試數據的類預測。
pred1=estimator.predict_proba(X_test)#返回給定測試數據的類機率估計。
# pred3=estimator.score(X_test,Y_test)#返回給定測試數據和標籤的平均精度。
print(X_test)#
print(Y_test)#實際類別
print(pred)#預測類別it


print(pred1)class

 

 

 

[[0. 1. 0. ... 1. 0. 0.]
[0. 0. 1. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 1. 1. ... 0. 0. 0.]]
[[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
...
[0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1.]
[0. 0. 1. 0. 0. 0.]]
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
5 5 5 5]
[[0.02377683 0.0266185 0.04945414 0.08426233 0.04495123 0.77093697]
[0.02115186 0.01721832 0.03360457 0.05283894 0.05303674 0.82214963]
[0.00838055 0.01647644 0.02293482 0.05378568 0.057558 0.8408645 ]
...
[0.01674003 0.01713392 0.03502046 0.03685626 0.03512193 0.85912746]
[0.0494712 0.0336375 0.05689533 0.03956604 0.04415505 0.77627486]
[0.04764625 0.04542363 0.08352048 0.15077472 0.10701337 0.5656215 ]]test

相關文章
相關標籤/搜索