近日,背景調查公司 Onfido 研究主管 Peter Roelants 在 Medium 上發表了一篇題爲《Higher-Level APIs in TensorFlow》的文章,經過實例詳細介紹瞭如何使用 TensorFlow 中的高級 API(Estimator、Experiment 和 Dataset)訓練模型。值得一提的是 Experiment 和 Dataset 能夠獨立使用。這些高級 API 已被最新發布的 TensorFlow1.3 版收錄。
TensorFlow 中有許多流行的庫,如 Keras、TFLearn 和 Sonnet,它們可讓你輕鬆訓練模型,而無需接觸哪些低級別函數。目前,Keras API 正傾向於直接在 TensorFlow 中實現,TensorFlow 也在提供愈來愈多的高級構造,其中的一些已經被最新發布的 TensorFlow1.3 版收錄。
在本文中,咱們將經過一個例子來學習如何使用一些高級構造,其中包括 Estimator、Experiment 和 Dataset。閱讀本文須要預先了解有關 TensorFlow 的基本知識。
Experiment、Estimator 和 DataSet 框架和它們的相互做用(如下將對這些組件進行說明)
在本文中,咱們使用 MNIST 做爲數據集。它是一個易於使用的數據集,能夠經過 TensorFlow 訪問。你能夠在這個 gist 中找到完整的示例代碼。使用這些框架的一個好處是咱們不須要直接處理圖形和會話
Estimator(評估器)類表明一個模型,以及這些模型被訓練和評估的方式。咱們能夠這樣構建一個評估器:
爲了構建一個 Estimator,咱們須要傳遞一個模型函數,一個參數集合以及一些配置。
-
參數應該是模型超參數的集合,它能夠是一個字典,但咱們將在本示例中將其表示爲 HParams 對象,用做 namedtuple。
-
該配置指定如何運行訓練和評估,以及如何存出結果。這些配置經過 RunConfig 對象表示,該對象傳達 Estimator 須要瞭解的關於運行模型的環境的全部內容。
-
模型函數是一個 Python 函數,它構建了給定輸入的模型(見後文)。
模型函數是一個 Python 函數,它做爲第一級函數傳遞給 Estimator。稍後咱們就會看到,TensorFlow 也會在其餘地方使用第一級函數。模型表示爲函數的好處在於模型能夠經過實例化函數不斷從新構建。該模型能夠在訓練過程當中被不一樣的輸入不斷建立,例如:在訓練期間運行驗證測試。
模型函數將輸入特徵做爲參數,相應標籤做爲張量。它還有一種模式來標記模型是否正在訓練、評估或執行推理。模型函數的最後一個參數是超參數的集合,它們與傳遞給 Estimator 的內容相同。模型函數須要返回一個 EstimatorSpec 對象——它會定義完整的模型。
EstimatorSpec 接受預測,損失,訓練和評估幾種操做,所以它定義了用於訓練,評估和推理的完整模型圖。因爲 EstimatorSpec 採用常規 TensorFlow Operations,所以咱們可使用像 TF-Slim 這樣的框架來定義本身的模型。
Experiment(實驗)類是定義如何訓練模型,並將其與 Estimator 進行集成的方式。咱們能夠這樣建立一個實驗類:
-
一個 Estimator(例如上面定義的那個)。
-
訓練和評估數據做爲第一級函數。這裏用到了和前述模型函數相同的概念,經過傳遞函數而非操做,若有須要,輸入圖能夠被重建。咱們會在後面繼續討論這個概念。
-
訓練和評估鉤子(hooks)。這些鉤子能夠用於監視或保存特定內容,或在圖形和會話中進行一些操做。例如,咱們將經過操做來幫助初始化數據加載器。
-
不一樣參數解釋了訓練時間和評估時間。
-
一旦咱們定義了 experiment,咱們就能夠經過 learn_runner.run 運行它來訓練和評估模型:
與模型函數和數據函數同樣,函數中的學習運算符將建立 experiment 做爲參數。
咱們將使用 Dataset 類和相應的 Iterator 來表示咱們的訓練和評估數據,並建立在訓練期間迭代數據的數據饋送器。在本示例中,咱們將使用 TensorFlow 中可用的 MNIST 數據,並在其周圍構建一個 Dataset 包裝器。例如,咱們把訓練的輸入數據表示爲:
調用這個 get_train_inputs 會返回一個一級函數,它在 TensorFlow 圖中建立數據加載操做,以及一個 Hook 初始化迭代器。
本示例中,咱們使用的 MNIST 數據最初表示爲 Numpy 數組。咱們建立一個佔位符張量來獲取數據,再使用佔位符來避免數據被複制。接下來,咱們在 from_tensor_slices 的幫助下建立一個切片數據集。咱們將確保該數據集運行無限長時間(experiment 能夠考慮 epoch 的數量),讓數據獲得清晰,並分紅所需的尺寸。
爲了迭代數據,咱們須要在數據集的基礎上建立迭代器。由於咱們正在使用佔位符,因此咱們須要在 NumPy 數據的相關會話中初始化佔位符。咱們能夠經過建立一個可初始化的迭代器來實現。建立圖形時,咱們將建立一個自定義的 IteratorInitializerHook 對象來初始化迭代器:
IteratorInitializerHook 繼承自 SessionRunHook。一旦建立了相關會話,這個鉤子就會調用 call after_create_session,並用正確的數據初始化佔位符。這個鉤子會經過 get_train_inputs 函數返回,並在建立時傳遞給 Experiment 對象。
train_inputs 函數返回的數據加載操做是 TensorFlow 操做,每次評估時都會返回一個新的批處理。
如今咱們已經定義了全部的東西,咱們能夠用如下命令運行代碼:
若是你不傳遞參數,它將使用文件頂部的默認標誌來肯定保存數據和模型的位置。訓練將在終端輸出全局步長、損失、精度等信息。除此以外,實驗和估算器框架將記錄 TensorBoard 能夠顯示的某些統計信息。若是咱們運行:
咱們就能夠看到全部訓練統計數據,如訓練損失、評估準確性、每步時間和模型圖。
在 TensorFlow 中,有關 Estimator、Experiment 和 Dataset 框架的示例不多,這也是本文存在的緣由。但願這篇文章能夠向你們介紹這些架構工做的原理,它們應該採用哪些抽象方法,以及如何使用它們。若是你對它們很感興趣,如下是其餘相關文檔。
關於 Estimator、Experiment 和 Dataset 的註釋
-
論文《TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks》:https://terrytangyuan.github.io/data/papers/tf-estimators-kdd-paper.pdf
-
Using the Dataset API for TensorFlow Input Pipelines:https://www.tensorflow.org/versions/r1.3/programmers_guide/datasets
-
tf.estimator.Estimator:https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator
-
tf.contrib.learn.RunConfig:https://www.tensorflow.org/api_docs/python/tf/contrib/learn/RunConfig
-
tf.estimator.DNNClassifier:https://www.tensorflow.org/api_docs/python/tf/estimator/DNNClassifier
-
tf.estimator.DNNRegressor:https://www.tensorflow.org/api_docs/python/tf/estimator/DNNRegressor
-
Creating Estimators in tf.estimator:https://www.tensorflow.org/extend/estimators
-
tf.contrib.learn.Head:https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Head
-
本文用到的 Slim 框架:https://github.com/tensorflow/models/tree/master/slim
在訓練模型後,咱們能夠運行 estimateator.predict 來預測給定圖像的類別。可以使用如下代碼示例。
原文連接:https://medium.com/onfido-tech/higher-level-apis-in-tensorflow-67bfb602e6c0