官網連接:https://www.tensorflow.org/api_docs/python/tf/estimator/Estimatorpython
Estimator - 一種可極大地簡化機器學習編程的高階 TensorFlow API。Estimator 會封裝下列操做:編程
您可使用官方提供的預建立的 Estimator,也能夠編寫自定義 Estimator。全部 Estimator(不管是預建立的仍是自定義)都是基於 tf.estimator.Estimator
類的類。api
Estimator 具備下列優點:安全
tf.layers
之上構建而成,能夠簡化自定義過程。使用 Estimator 編寫應用時,您必須將數據輸入管道從模型中分離出來。這種分離簡化了不一樣數據集的實驗流程。服務器
藉助預建立的 Estimator,您可以在比基本 TensorFlow API 高級不少的概念層面上進行操做。因爲 Estimator 會爲您處理全部「管道工做」,所以您沒必要再爲建立計算圖或會話而操心。也就是說,預建立的 Estimator 會爲您建立和管理 Graph
和 Session
對象。此外,藉助預建立的 Estimator,您只需稍微更改下代碼,就能夠嘗試不一樣的模型架構。例如,DNNClassifier
是一個預建立的 Estimator 類,它根據密集的前饋神經網絡訓練分類模型。網絡
依賴預建立的 Estimator 的 TensorFlow 程序一般包含下列四個步驟:session
編寫一個或多個數據集導入函數。 例如,您能夠建立一個函數來導入訓練集,並建立另外一個函數來導入測試集。每一個數據集導入函數都必須返回兩個對象:數據結構
例如,如下代碼展現了輸入函數的基本框架:架構
def input_fn(dataset): ... # manipulate dataset, extracting the feature dict and the label return feature_dict, label
tf.feature_column
都標識了特徵名稱、特徵類型和任何輸入預處理操做。例如,如下代碼段建立了三個存儲整數或浮點數據的特徵列。前兩個特徵列僅標識了特徵的名稱和類型。第三個特徵列還指定了一個 lambda,該程序將調用此 lambda 來調節原始數據:# Define three numeric feature columns. population = tf.feature_column.numeric_column('population') crime_rate = tf.feature_column.numeric_column('crime_rate') median_education = tf.feature_column.numeric_column('median_education', normalizer_fn=lambda x: x - global_education_mean)
LinearClassifier
的預建立 Estimator 進行實例化的示例代碼:# Instantiate an estimator, passing the feature columns. estimator = tf.estimator.LinearClassifier( feature_columns=[population, crime_rate, median_education], )
train
方法。# my_training_set is the function created in Step 1 estimator.train(input_fn=my_training_set, steps=2000)
您能夠將現有的 Keras 模型轉換爲 Estimator。這樣作以後,Keras 模型就能夠利用 Estimator 的優點,例如分佈式訓練。調用 tf.keras.estimator.model_to_estimator
,以下例所示:app
# Instantiate a Keras inception v3 model. keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None) # Compile model with the optimizer, loss, and metrics you'd like to train with. keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metric='accuracy') # Create an Estimator from the compiled Keras model. Note the initial model # state of the keras model is preserved in the created Estimator. est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3) # Treat the derived Estimator as you would with any other Estimator. # First, recover the input name(s) of Keras model, so we can use them as the # feature column name(s) of the Estimator input function: keras_inception_v3.input_names # print out: ['input_1'] # Once we have the input name(s), we can create the input function, for example, # for input(s) in the format of numpy ndarray: train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"input_1": train_data}, y=train_labels, num_epochs=1, shuffle=False) # To train, we call Estimator's train function: est_inception_v3.train(input_fn=train_input_fn, steps=2000)
class Estimator(builtins.object)
一 介紹
Estimator 類,用來訓練和驗證 TensorFlow 模型。
Estimator 對象包含了一個模型 model_fn,這個模型給定輸入和參數,會返回訓練、驗證或者預測等所須要的操做節點。
全部的輸出(檢查點、事件文件等)會寫入到 model_dir,或者其子文件夾中。若是 model_dir 爲空,則默認爲臨時目錄。
config 參數爲 tf.estimator.RunConfig 對象,包含了執行環境的信息。若是沒有傳遞 config,則它會被 Estimator 實例化,使用的是默認配置。
params 包含了超參數。Estimator 只傳遞超參數,不會檢查超參數,所以 params 的結構徹底取決於開發者。
Estimator 的全部方法都不能被子類覆蓋(它的構造方法強制決定的)。子類應該使用 model_fn 來配置母類,或者增添方法來實現特殊的功能。
Estimator 不支持 Eager Execution(eager execution可以使用Python 的debug工具、數據結構與控制流。而且無需使用placeholder、session,計算結果可以當即得出)。
二 類內方法
一、__init__(self, model_fn, model_dir=None, config=None, params=None, warm_start_from=None)
構造一個 Estimator 的實例.。
參數:
model_fn: 模型函數。函數的格式以下:
參數:
一、features: 這是 input_fn 返回的第一項(input_fn 是 train, evaluate 和 predict 的參數)。類型應該是單一的 Tensor 或者 dict。
二、labels: 這是 input_fn 返回的第二項。類型應該是單一的 Tensor 或者 dict。若是 mode 爲 ModeKeys.PREDICT,則會默認爲 labels=None。若是 model_fn 不接受 mode,model_fn 應該仍然能夠處理 labels=None。
三、mode: 可選。指定是訓練、驗證仍是測試。參見 ModeKeys。
四、params: 可選,超參數的 dict。 能夠從超參數調整中配置 Estimators。
五、config: 可選,配置。若是沒有傳則爲默認值。能夠根據 num_ps_replicas 或 model_dir 等配置更新 model_fn。
返回:
EstimatorSpec
model_dir: 保存模型參數、圖等的地址,也能夠用來將路徑中的檢查點加載至 estimator 中來繼續訓練以前保存的模型。若是是 PathLike, 那麼路徑就固定爲它了。若是是 None,那麼 config 中的 model_dir 會被使用(若是設置了的話),若是兩個都設置了,那麼必須相同;若是兩個都是 None,則會使用臨時目錄。
config: 配置類。
params: 超參數的dict,會被傳遞到 model_fn。keys 是參數的名稱,values 是基本 python 類型。
warm_start_from: 可選,字符串,檢查點的文件路徑,用來指示從哪裏開始熱啓動。或者是 tf.estimator.WarmStartSettings 類來所有配置熱啓動。若是是字符串路徑,則全部的變量都是熱啓動,而且須要 Tensor 和詞彙的名字都沒有變。
異常:
RuntimeError: 開啓了 eager execution
ValueError:model_fn 的參數與 params 不匹配
ValueError:這個函數被 Estimator 的子類所覆蓋
二、train(self, input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None)
根據所給數據 input_fn, 對模型進行訓練。
參數:
input_fn:一個函數,提供由小 batches 組成的數據, 供訓練使用。必須返回如下之一:
一、一個 'tf.data.Dataset'對象:Dataset的輸出必須是一個元組 (features, labels),元組要求以下。
二、一個元組 (features, labels):features 是一個 Tensor 或者一個字典(特徵名爲 Tensor),labels 是一個 Tensor 或者一個字典(特徵名爲 Tensor)。features 和 labels 都被 model_fn 所使用,應該符合 model_fn 輸入的要求。
hooks:SessionRunHook 子類實例的列表。用於在訓練循環內部執行。
steps:模型訓練的步數。若是是 None, 則一直訓練,直到input_fn 拋出了超過界限的異常。steps 是遞進式進行的。若是執行了兩次訓練(steps=10),則總共訓練了 20 次。若是中途拋出了越界異常,則訓練在 20 次以前就會中止。若是你不想遞進式進行,請換爲設置 max_steps。若是設置了 steps,則 max_steps 必須是 None。
max_steps:模型訓練的最大步數。若是爲 None,則一直訓練,直到input_fn 拋出了超過界限的異常。若是設置了 max_steps, 則 steps 必須是 None。若是中途拋出了越界異常,則訓練在 max_steps 次以前就會中止。執行兩次 train(steps=100) 意味着 200 次訓練;可是,執行兩次 train(max_steps=100) 意味着第二次執行不會進行任何訓練,由於第一次執行已經作完了全部的 100 次。
saving_listeners:CheckpointSaverListener 對象的列表。用於在保存檢查點以前或以後當即執行的回調函數。
返回:
self:爲了連接下去。
異常:
ValueError:steps 和 max_steps 都不是 None
ValueError:steps 或 max_steps <= 0
三、evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)
根據所給數據 input_fn, 對模型進行驗證。
對於每一步,執行 input_fn(返回數據的一個 batch)。
一直進行驗證,直到:
steps 個 batches 進行完畢,或者
input_fn 拋出了越界異常(OutOfRangeError 或 StopIteration)
參數:
input_fn:一個函數,構造了驗證所需的輸入數據,必須返回如下之一:
一、一個 'tf.data.Dataset'對象:Dataset的輸出必須是一個元組 (features, labels),元組要求以下。
二、一個元組 (features, labels):features 是一個 Tensor 或者一個字典(特徵名爲 Tensor),labels 是一個 Tensor 或者一個字典(特徵名爲 Tensor)。features 和 labels 都被 model_fn 所使用,應該符合 model_fn 輸入的要求。
steps:模型驗證的步數。若是是 None, 則一直驗證,直到input_fn 拋出了超過界限的異常。
hooks:SessionRunHook 子類實例的列表。用於在驗證內部執行。
checkpoint_path: 用於驗證的檢查點路徑。若是是 None, 則使用 model_dir 中最新的檢查點。
name:驗證的名字。使用者能夠針對不一樣的數據集運行多個驗證操做,好比訓練集 vs 測試集。不一樣驗證的結果被保存在不一樣的文件夾中,且分別出如今 tensorboard 中。
返回:
返回一個字典,包括 model_fn 中指定的評價指標、global_step(包含驗證進行的全局步數)
異常:
ValueError:若是 step 小於等於0
ValueError:若是 model_dir 指定的模型沒有被訓練,或者指定的 checkpoint_path 爲空。
四、predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None, yield_single_examples=True)
對給出的特徵進行預測
參數:
input_fn:一個函數,構造特徵。預測一直進行下去,直到 input_fn 拋出了越界異常(OutOfRangeError 或 StopIteration)。函數必須返回如下之一:
一、一個 'tf.data.Dataset'對象:Dataset的輸出和如下的限制相同。
二、features:一個 Tensor 或者一個字典(特徵名爲 Tensor)。features 被 model_fn 所使用,應該符合 model_fn 輸入的要求。
三、一個元組,其中第一項爲 features。
predict_keys:字符串列表,要預測的鍵值。當 EstimatorSpec.predictions 是一個 dict 時使用。若是使用了 predict_keys, 那麼剩下的預測值會從字典中過濾掉。若是是 None,則返回所有。
hooks:SessionRunHook 子類實例的列表。用於在預測內部回調。
checkpoint_path: 用於預測的檢查點路徑。若是是 None, 則使用 model_dir 中最新的檢查點。
yield_single_examples:If False, yield the whole batch as returned by the model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.
返回:
predictions tensors 的值
異常:
ValueError:model_dir 中找不到訓練好的模型。
ValueError:預測值的 batch 長度不一樣,且 yield_single_examples 爲 True。
ValueError:predict_keys 和 predictions 之間有衝突。例如,predict_keys 不是 None,可是 EstimatorSpec.predictions 不是一個 dict。