自定義estimator

Tensorflow從1.3版本開始推出了官方支持的高層封裝tf.estimator。Estimators API提供了一整套訓練模型、測試模型以及生成預測的方法。python

自定義模型函數

Tensorflow支持自定義estimator,首先須要定義一個模型函數model_fn,函數有4個輸入:features,labels,mode和params。
features爲模型的輸入,labels爲預測的真實值
mode的取值有3種:tf.estimator.ModeKeys.TRAINtf.estimator.ModeKeys.EVALtf.estimator.ModeKeys.PREDICT,分別對應訓練,驗證和測試。經過mode的值,能夠判斷當前屬於哪個階段。params是一個字典,包含模型相關的超參數,例如learning rate等。
自定義函數model_fn返回值必須是一個tf.estimator.EstimatorSpec對象,git

def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metric_ops=None,
              export_outputs=None,
              training_chief_hooks=None,
              training_hooks=None,
              scaffold=None,
              evaluation_hooks=None,
              prediction_hooks=None):

其中,mode表示模型的使用模式,對應model_fn的參數mode;predictions表示根據輸入的特徵features計算返回的預測值;loss表示損失;train_op表示對模型的損失進行最小化的op;eval_metric_ops表示模型在eval時,須要額外輸出的指標。export_outputs表示導出模型的路徑。還有一些鉤子函數。
當mode不一樣,EstimatorSpec所需的參數也不同。若是mode爲TRAIN,則實例化EstimatorSpec時,必須設置參數losstrain_op,當mode爲EVAL時,必須設置參數loss,當mode爲PREDICT時,必須設置參數predictionsgithub

def my_model(features, labels, mode, params):
    W = tf.Variable(tf.random_normal([1]), name="weight")
    b = tf.Variable(tf.zeros([1]), name="bias")
    predictions = tf.multiply(W, tf.cast(features, dtype=tf.float32)) + b
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
    mean_loss = tf.metrics.mean(loss)
    metrics = {'mean_loss':mean_loss}
    if mode == tf.estimator.ModeKeys.EVAL:
        # eval_metric_ops`用來定義評價指標,在運行eval的時候會計算這裏定義的全部評測標準。
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metrics) 
    assert mode == tf.estimator.ModeKeys.TRAIN
    optimizer = tf.train.AdagradDAOptimizer(learning_rate=params["learning_rate"], global_step=tf.train.get_or_create_global_step())
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

實例化estimator

最後經過實例化tf.estimator.Estimator就能夠獲得一個自定義的estimator。dom

def __init__(self,
             model_fn: Any,
             model_dir: Any = None,
             config: Any = None,
             params: Any = None,
             warm_start_from: Any = None) -> Any

參數model_fn即爲自定義的模型函數,model_dir用於保存模型的參數和模型圖等內容。warm_start_from用來指定檢查點路徑,並導入checkpoint開始訓練。warm_start_from能夠經過tf.estimator.WarmStartSettings實例化。函數

def __new__(cls,
            ckpt_to_initialize_from: Any,
            vars_to_warm_start: str = '.*',
            var_name_to_vocab_info: Any = None,
            var_name_to_prev_var_name: Any = None) -> _T

ckpt_to_initialize_from能夠指定加載checkpoint的路徑,vars_to_warm_start指定哪些參數須要熱啓動。學習

代碼

代碼自定義estimator測試

參考

  1. 深度學習之tensorflow工程化項目實戰。
相關文章
相關標籤/搜索