tf.estimator.Estimator類的用法

官網連接:https://www.tensorflow.org/api_docs/python/tf/estimator/Estimatorpython

 

 Estimator - 一種可極大地簡化機器學習編程的高階 TensorFlow API。Estimator 會封裝下列操做:編程

 

  • 訓練
  • 評估
  • 預測
  • 導出以供使用

 

您可使用官方提供的預建立的 Estimator,也能夠編寫自定義 Estimator。全部 Estimator(不管是預建立的仍是自定義)都是基於 tf.estimator.Estimator 類的類。api

Estimator 的優點

Estimator 具備下列優點:安全

  • 您能夠在本地主機上或分佈式多服務器環境中運行基於 Estimator 的模型,而無需更改模型。此外,您能夠在 CPU、GPU 或 TPU 上運行基於 Estimator 的模型,而無需從新編碼模型。
  • Estimator 簡化了在模型開發者之間共享實現的過程。
  • 您可使用高級直觀代碼開發先進的模型。簡言之,採用 Estimator 建立模型一般比採用低階 TensorFlow API 更簡單。
  • Estimator 自己在 tf.layers 之上構建而成,能夠簡化自定義過程。
  • Estimator 會爲您構建圖。
  • Estimator 提供安全的分佈式訓練循環,能夠控制如何以及什麼時候:
    • 構建圖
    • 初始化變量
    • 開始排隊
    • 處理異常
    • 建立檢查點文件並從故障中恢復
    • 保存 TensorBoard 的摘要

使用 Estimator 編寫應用時,您必須將數據輸入管道從模型中分離出來。這種分離簡化了不一樣數據集的實驗流程。服務器

預建立的 Estimator

藉助預建立的 Estimator,您可以在比基本 TensorFlow API 高級不少的概念層面上進行操做。因爲 Estimator 會爲您處理全部「管道工做」,所以您沒必要再爲建立計算圖或會話而操心。也就是說,預建立的 Estimator 會爲您建立和管理 Graph 和 Session 對象。此外,藉助預建立的 Estimator,您只需稍微更改下代碼,就能夠嘗試不一樣的模型架構。例如,DNNClassifier 是一個預建立的 Estimator 類,它根據密集的前饋神經網絡訓練分類模型。網絡

預建立的 Estimator 程序的結構

依賴預建立的 Estimator 的 TensorFlow 程序一般包含下列四個步驟:session

  • 編寫一個或多個數據集導入函數。 例如,您能夠建立一個函數來導入訓練集,並建立另外一個函數來導入測試集。每一個數據集導入函數都必須返回兩個對象:數據結構

    • 一個字典,其中鍵是特徵名稱,值是包含相應特徵數據的張量(或 SparseTensor)
    • 一個包含一個或多個標籤的張量

    例如,如下代碼展現了輸入函數的基本框架:架構

def input_fn(dataset):
   ...  # manipulate dataset, extracting the feature dict and the label
   return feature_dict, label
  • 定義特徵列。 每一個 tf.feature_column 都標識了特徵名稱、特徵類型和任何輸入預處理操做。例如,如下代碼段建立了三個存儲整數或浮點數據的特徵列。前兩個特徵列僅標識了特徵的名稱和類型。第三個特徵列還指定了一個 lambda,該程序將調用此 lambda 來調節原始數據:
# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column('median_education',
                    normalizer_fn=lambda x: x - global_education_mean)
  • 實例化相關的預建立的 Estimator。 例如,下面是對名爲 LinearClassifier 的預建立 Estimator 進行實例化的示例代碼:
# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
    feature_columns=[population, crime_rate, median_education],
    )
  • 調用訓練、評估或推理方法。例如,全部 Estimator 都提供訓練模型的 train 方法。
# my_training_set is the function created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)

從 Keras 模型建立 Estimator

您能夠將現有的 Keras 模型轉換爲 Estimator。這樣作以後,Keras 模型就能夠利用 Estimator 的優點,例如分佈式訓練。調用 tf.keras.estimator.model_to_estimator,以下例所示:app

# Instantiate a Keras inception v3 model.
keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
                          loss='categorical_crossentropy',
                          metric='accuracy')
# Create an Estimator from the compiled Keras model. Note the initial model
# state of the keras model is preserved in the created Estimator.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)

# Treat the derived Estimator as you would with any other Estimator.
# First, recover the input name(s) of Keras model, so we can use them as the
# feature column name(s) of the Estimator input function:
keras_inception_v3.input_names  # print out: ['input_1']
# Once we have the input name(s), we can create the input function, for example,
# for input(s) in the format of numpy ndarray:
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
    x={"input_1": train_data},
    y=train_labels,
    num_epochs=1,
    shuffle=False)
# To train, we call Estimator's train function:
est_inception_v3.train(input_fn=train_input_fn, steps=2000)

 

class Estimator(builtins.object)
一 介紹
Estimator 類,用來訓練和驗證 TensorFlow 模型。
Estimator 對象包含了一個模型 model_fn,這個模型給定輸入和參數,會返回訓練、驗證或者預測等所須要的操做節點。
全部的輸出(檢查點、事件文件等)會寫入到 model_dir,或者其子文件夾中。若是 model_dir 爲空,則默認爲臨時目錄。
config 參數爲 tf.estimator.RunConfig 對象,包含了執行環境的信息。若是沒有傳遞 config,則它會被 Estimator 實例化,使用的是默認配置。
params 包含了超參數。Estimator 只傳遞超參數,不會檢查超參數,所以 params 的結構徹底取決於開發者。
Estimator 的全部方法都不能被子類覆蓋(它的構造方法強制決定的)。子類應該使用 model_fn 來配置母類,或者增添方法來實現特殊的功能。
Estimator 不支持 Eager Execution(eager execution可以使用Python 的debug工具、數據結構與控制流。而且無需使用placeholder、session,計算結果可以當即得出)。


二 類內方法

一、__init__(self, model_fn, model_dir=None, config=None, params=None, warm_start_from=None)
構造一個 Estimator 的實例.。
參數:

model_fn: 模型函數。函數的格式以下:
  參數:
  一、features: 這是 input_fn 返回的第一項(input_fn 是 train, evaluate 和 predict 的參數)。類型應該是單一的 Tensor 或者 dict。
  二、labels: 這是 input_fn 返回的第二項。類型應該是單一的 Tensor 或者 dict。若是 mode 爲 ModeKeys.PREDICT,則會默認爲 labels=None。若是 model_fn 不接受 mode,model_fn 應該仍然能夠處理 labels=None。
  三、mode: 可選。指定是訓練、驗證仍是測試。參見 ModeKeys。
  四、params: 可選,超參數的 dict。 能夠從超參數調整中配置 Estimators。
  五、config: 可選,配置。若是沒有傳則爲默認值。能夠根據 num_ps_replicas 或 model_dir 等配置更新 model_fn。
  返回:
  EstimatorSpec
model_dir: 保存模型參數、圖等的地址,也能夠用來將路徑中的檢查點加載至 estimator 中來繼續訓練以前保存的模型。若是是 PathLike, 那麼路徑就固定爲它了。若是是 None,那麼 config 中的 model_dir 會被使用(若是設置了的話),若是兩個都設置了,那麼必須相同;若是兩個都是 None,則會使用臨時目錄。
config: 配置類。
params: 超參數的dict,會被傳遞到 model_fn。keys 是參數的名稱,values 是基本 python 類型。
warm_start_from: 可選,字符串,檢查點的文件路徑,用來指示從哪裏開始熱啓動。或者是 tf.estimator.WarmStartSettings 類來所有配置熱啓動。若是是字符串路徑,則全部的變量都是熱啓動,而且須要 Tensor 和詞彙的名字都沒有變。
異常:

RuntimeError: 開啓了 eager execution

ValueError:model_fn 的參數與 params 不匹配

ValueError:這個函數被 Estimator 的子類所覆蓋


二、train(self, input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None)
根據所給數據 input_fn, 對模型進行訓練。
參數:

input_fn:一個函數,提供由小 batches 組成的數據, 供訓練使用。必須返回如下之一:
  一、一個 'tf.data.Dataset'對象:Dataset的輸出必須是一個元組 (features, labels),元組要求以下。
  二、一個元組 (features, labels):features 是一個 Tensor 或者一個字典(特徵名爲 Tensor),labels 是一個 Tensor 或者一個字典(特徵名爲 Tensor)。features 和 labels 都被 model_fn 所使用,應該符合 model_fn 輸入的要求。

hooks:SessionRunHook 子類實例的列表。用於在訓練循環內部執行。

steps:模型訓練的步數。若是是 None, 則一直訓練,直到input_fn 拋出了超過界限的異常。steps 是遞進式進行的。若是執行了兩次訓練(steps=10),則總共訓練了 20 次。若是中途拋出了越界異常,則訓練在 20 次以前就會中止。若是你不想遞進式進行,請換爲設置 max_steps。若是設置了 steps,則 max_steps 必須是 None。

max_steps:模型訓練的最大步數。若是爲 None,則一直訓練,直到input_fn 拋出了超過界限的異常。若是設置了 max_steps, 則 steps 必須是 None。若是中途拋出了越界異常,則訓練在 max_steps 次以前就會中止。執行兩次 train(steps=100) 意味着 200 次訓練;可是,執行兩次 train(max_steps=100) 意味着第二次執行不會進行任何訓練,由於第一次執行已經作完了全部的 100 次。

saving_listeners:CheckpointSaverListener 對象的列表。用於在保存檢查點以前或以後當即執行的回調函數。

返回:
self:爲了連接下去。
異常:
ValueError:steps 和 max_steps 都不是 None
ValueError:steps 或 max_steps <= 0


三、evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)
根據所給數據 input_fn, 對模型進行驗證。
對於每一步,執行 input_fn(返回數據的一個 batch)。
一直進行驗證,直到:

steps 個 batches 進行完畢,或者
input_fn 拋出了越界異常(OutOfRangeError 或 StopIteration)
參數:

input_fn:一個函數,構造了驗證所需的輸入數據,必須返回如下之一:
  一、一個 'tf.data.Dataset'對象:Dataset的輸出必須是一個元組 (features, labels),元組要求以下。
  二、一個元組 (features, labels):features 是一個 Tensor 或者一個字典(特徵名爲 Tensor),labels 是一個 Tensor 或者一個字典(特徵名爲 Tensor)。features 和 labels 都被 model_fn 所使用,應該符合 model_fn 輸入的要求。
steps:模型驗證的步數。若是是 None, 則一直驗證,直到input_fn 拋出了超過界限的異常。
hooks:SessionRunHook 子類實例的列表。用於在驗證內部執行。
checkpoint_path: 用於驗證的檢查點路徑。若是是 None, 則使用 model_dir 中最新的檢查點。
name:驗證的名字。使用者能夠針對不一樣的數據集運行多個驗證操做,好比訓練集 vs 測試集。不一樣驗證的結果被保存在不一樣的文件夾中,且分別出如今 tensorboard 中。
返回:
返回一個字典,包括 model_fn 中指定的評價指標、global_step(包含驗證進行的全局步數)
異常:
ValueError:若是 step 小於等於0
ValueError:若是 model_dir 指定的模型沒有被訓練,或者指定的 checkpoint_path 爲空。

四、predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None, yield_single_examples=True)
對給出的特徵進行預測
參數:

input_fn:一個函數,構造特徵。預測一直進行下去,直到 input_fn 拋出了越界異常(OutOfRangeError 或 StopIteration)。函數必須返回如下之一:
  一、一個 'tf.data.Dataset'對象:Dataset的輸出和如下的限制相同。
  二、features:一個 Tensor 或者一個字典(特徵名爲 Tensor)。features 被 model_fn 所使用,應該符合 model_fn 輸入的要求。
  三、一個元組,其中第一項爲 features。
predict_keys:字符串列表,要預測的鍵值。當 EstimatorSpec.predictions 是一個 dict 時使用。若是使用了 predict_keys, 那麼剩下的預測值會從字典中過濾掉。若是是 None,則返回所有。
hooks:SessionRunHook 子類實例的列表。用於在預測內部回調。
checkpoint_path: 用於預測的檢查點路徑。若是是 None, 則使用 model_dir 中最新的檢查點。
yield_single_examples:If False, yield the whole batch as returned by the model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.
返回:
predictions tensors 的值
異常:
ValueError:model_dir 中找不到訓練好的模型。
ValueError:預測值的 batch 長度不一樣,且 yield_single_examples 爲 True。
ValueError:predict_keys 和 predictions 之間有衝突。例如,predict_keys 不是 None,可是 EstimatorSpec.predictions 不是一個 dict。

相關文章
相關標籤/搜索