Tensorflow從1.3版本開始推出了官方支持的高層封裝tf.estimator
。Estimators API提供了一整套訓練模型、測試模型以及生成預測的方法。python
Tensorflow支持自定義estimator,首先須要定義一個模型函數model_fn,函數有4個輸入:features,labels,mode和params。
features爲模型的輸入,labels爲預測的真實值
mode的取值有3種:tf.estimator.ModeKeys.TRAIN
,tf.estimator.ModeKeys.EVAL
和tf.estimator.ModeKeys.PREDICT
,分別對應訓練,驗證和測試。經過mode的值,能夠判斷當前屬於哪個階段。params是一個字典,包含模型相關的超參數,例如learning rate等。
自定義函數model_fn返回值必須是一個tf.estimator.EstimatorSpec
對象,git
def __new__(cls, mode, predictions=None, loss=None, train_op=None, eval_metric_ops=None, export_outputs=None, training_chief_hooks=None, training_hooks=None, scaffold=None, evaluation_hooks=None, prediction_hooks=None):
其中,mode
表示模型的使用模式,對應model_fn的參數mode;predictions
表示根據輸入的特徵features
計算返回的預測值;loss
表示損失;train_op
表示對模型的損失進行最小化的op;eval_metric_ops
表示模型在eval時,須要額外輸出的指標。export_outputs
表示導出模型的路徑。還有一些鉤子函數。
當mode不一樣,EstimatorSpec所需的參數也不同。若是mode爲TRAIN
,則實例化EstimatorSpec時,必須設置參數loss
和train_op
,當mode爲EVAL
時,必須設置參數loss
,當mode爲PREDICT
時,必須設置參數predictions
。github
def my_model(features, labels, mode, params): W = tf.Variable(tf.random_normal([1]), name="weight") b = tf.Variable(tf.zeros([1]), name="bias") predictions = tf.multiply(W, tf.cast(features, dtype=tf.float32)) + b if mode == tf.estimator.ModeKeys.PREDICT: return tf.estimator.EstimatorSpec(mode, predictions=predictions) loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions) mean_loss = tf.metrics.mean(loss) metrics = {'mean_loss':mean_loss} if mode == tf.estimator.ModeKeys.EVAL: # eval_metric_ops`用來定義評價指標,在運行eval的時候會計算這裏定義的全部評測標準。 return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metrics) assert mode == tf.estimator.ModeKeys.TRAIN optimizer = tf.train.AdagradDAOptimizer(learning_rate=params["learning_rate"], global_step=tf.train.get_or_create_global_step()) train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_or_create_global_step()) return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
最後經過實例化tf.estimator.Estimator
就能夠獲得一個自定義的estimator。dom
def __init__(self, model_fn: Any, model_dir: Any = None, config: Any = None, params: Any = None, warm_start_from: Any = None) -> Any
參數model_fn
即爲自定義的模型函數,model_dir
用於保存模型的參數和模型圖等內容。warm_start_from
用來指定檢查點路徑,並導入checkpoint開始訓練。warm_start_from能夠經過tf.estimator.WarmStartSettings
實例化。函數
def __new__(cls, ckpt_to_initialize_from: Any, vars_to_warm_start: str = '.*', var_name_to_vocab_info: Any = None, var_name_to_prev_var_name: Any = None) -> _T
ckpt_to_initialize_from
能夠指定加載checkpoint的路徑,vars_to_warm_start
指定哪些參數須要熱啓動。學習
代碼自定義estimator測試