Google 開發者大會 (Google Developer Days,簡稱 GDD) 是展現 Google 最新開發者產品和平臺的全球盛會,旨在幫助你快速開發優質應用,發展和留住活躍用戶羣,充分利用各類工具得到更多收益。2018 Google 開發者大會於 9 月 20 日和 21 日於上海舉辦。👉Google 開發者大會 2018 掘金專題python
GDD 2018 次日的 9 月 21 日 ,陳爽(Google Brain 軟件工程師)爲咱們帶來了《以 tf.data 優化訓練數據》,講解如何使用 tf.data 爲各種模型打造高性能的 TensorFlow 輸入渠道,本文將摘錄演講技術乾貨。git
圖中代碼分別對應 ETL 系統的三個步驟,使用 tf.data 便可輕鬆實現。github
files = tf.data.Dataset.list_files("training-*-of-1024.tfrecord")
dataset = tf.data.TFRecordDataset(files, num_parallel_reads=32)
複製代碼
dataset = dataset.apply(tf.contrib.data.shuffle_and_repaeat(10000, NUM_EPOCHS))
dataset = dataset.apply(tf.contrib.data.map_and_batch(lambda x: ..., BATCH_SIZE))
複製代碼
dataset = dataset.apply(tf.contrib.data.prefetch_to_device("/gpu:0"))
複製代碼
最終代碼以下圖所示,更多優化手段能夠參考 tf.data 性能指南:sql
如上圖,能夠用自定義的 map_fn 處理 TensorFlow 或兼容的函數,同時支持 AutoGraph 處理過的函數。數據庫
以下圖,使用 Python 自帶的 urllib 獲取服務器數據,存入 dataset:編程
如普通文件系統丶GCP 雲儲存丶其餘雲儲存丶SQL 數據庫等。服務器
讀取 Google 雲儲存的 TFRecord 文件示例:多線程
files = tf.contrib.data.TFRecordDataset(
"gs://path/to/file.tfrecord", num_parallel_reads=32)
複製代碼
使用自訂 SQL 數據庫示例:架構
files = tf.contrib.data.SqlDataset(
"sqllite", "/foo/db.sqlite", "SELECT name, age FROM people",
(tf.string, tf.int32))
複製代碼
tf.enable_eager_execution()
for batch in dataset:
train_model(batch)
複製代碼
上圖能夠簡單替換爲一個函數:app
dataset = tf.contrib.data.make_batched_features_dataset(
"training-*-of-1024.tfrecord",
BATCH_SIZE, features, num_epochs=NUM_EPOCHS)
複製代碼
使用 CSV 數據集的情境:
dataset = tf.contrib.data.make_csv_dataset(
"*.csv", BATCH_SIZE, num_epochs=NUM_EPOCHS)
複製代碼
能夠簡單的使用 AUTOTUNE 找到 prefetching 的最佳參數:
dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
複製代碼
對於 Keras,能夠將 dataset 直接傳遞使用;對於 Estimators 訓練函數,將 dataset 包裝至輸入函數並返回便可,以下示例:
def input_fn():
dataset = tf.contrib.data.make_csv_dataset(
"*.csv", BATCH_SIZE, num_epochs=NUM_EPOCHS)
return dataset
tf.estimator.Estimator(model_fn=train_model).train(input_fn=input_fn)
複製代碼
本場演講介紹了 tf.data 這個兼具高效丶靈活與易用的 API,同時瞭解如何運用管道化及其餘優化手段來增進運算效能,以及許多可能不曾發現的實用函數。