When using the Estimator API in TensorFlow 1.X, you will often run into an error like `ValueError: GraphDef cannot be larger than 2GB`. The likely cause is that the data is too large to be written into the graph.
A typical way to construct the input data looks like this:
```python
import numpy as np
import tensorflow as tf

def input_fn():
    features, labels = (np.random.sample((100, 2)),
                        np.random.sample((100, 1)))
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.shuffle(100000).repeat().batch(batch_size)
    return dataset

...
estimator.train(input_fn)
```
When TensorFlow reads the data this way, it also writes the data into the Graph (as constants), so once the dataset gets large you hit this error. I ran into it in earlier multi-GPU experiments as well, even with the batch size turned way down. There are two ways around it: one is to not save the graph at all, and the other is to build the input pipeline with `feed_dict`.
My environment is TensorFlow 1.14, so I will walk through it with that version.
First, a summary of how the Estimator works (assuming a single GPU), taking `estimator.train` as the example (`evaluate` and `predict` are similar). Its call chain is:
```python
class Estimator():
  ...
  def train(self, ...):
    ...
    loss = self._train_model(input_fn, hooks, saving_listeners)
    ...

  def _train_model(self, input_fn, hooks, saving_listeners):
    if self._train_distribution:
      return self._train_model_distributed(input_fn, hooks, saving_listeners)
    else:
      return self._train_model_default(input_fn, hooks, saving_listeners)

  def _train_model_default(self, input_fn, hooks, saving_listeners):
    ...
    return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                           hooks, global_step_tensor,
                                           saving_listeners)

  def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                 global_step_tensor, saving_listeners):
    ...
    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(tuple(chief_hooks) +
                          tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=save_summary_steps,
        config=self._session_config,
        max_wait_secs=self._config.session_creation_timeout_secs,
        log_step_count_steps=log_step_count_steps) as mon_sess:
```
Single-stepping through with a debugger shows that the Estimator sets up event-file writing when `MonitoredTrainingSession` is created, while the events are actually written when the hooks execute. For example, in my experiments I set `log_step_count_steps`, which prints the computation speed and the current loss every given number of steps. The class implementing this is `StepCounterHook`, defined in `tensorflow/tensorflow/python/training/basic_session_run_hooks.py`; part of its definition is:
```python
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(...):
    ...
    self._summary_writer = summary_writer

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)
```
So we just need to comment out every place where something like `self._summary_writer.add_summary` appears. The Estimator will then stop generating event files while it runs, and the 2GB problem goes away.
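If you would rather not patch the TensorFlow source, a gentler variant worth trying (a sketch, assuming TF 1.14 `tf.estimator.RunConfig` semantics; `model_fn` and the model directory are placeholders) is to switch the summary-writing hooks off through the run config:

```python
import tensorflow as tf

# Sketch: turn off the hooks that write summaries to the event file.
# save_summary_steps=None removes the SummarySaverHook, and
# log_step_count_steps=None removes the StepCounterHook shown above.
run_config = tf.estimator.RunConfig(
    model_dir="/tmp/model",       # placeholder directory
    save_summary_steps=None,      # no periodic summary writing
    log_step_count_steps=None,    # no steps/sec logging or summary
)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
```

Note that the `CheckpointSaverHook` can still serialize the graph when it creates its writer, so depending on where the 2GB serialization happens in your setup, editing the source as above may still be necessary.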
To keep using `Dataset` with large amounts of data, we can create the `Dataset` from a placeholder. The data is then not written into the graph directly; the graph only contains the placeholder. The catch is that the placeholder must be initialized with the data up front, which requires calling `sess.run(iter.initializer, feed_dict={x: data})`.
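Outside of the Estimator, the mechanism looks like this (a minimal, self-contained sketch; the small random array stands in for a large dataset):

```python
import numpy as np
import tensorflow as tf

data = np.random.sample((100, 2))  # stand-in for a large array

x = tf.placeholder(tf.float32, shape=[None, 2])
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(100).repeat().batch(4)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    # The data enters through feed_dict here rather than being
    # baked into the GraphDef as a constant.
    sess.run(iterator.initializer, feed_dict={x: data})
    print(sess.run(next_batch))
```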
But the Estimator does not expose an explicit session for us to call, so what do we do? We can use a `SessionRunHook` to solve this. The `tf.train.SessionRunHook()` class is defined in `tensorflow/python/training/session_run_hook.py`; for a detailed introduction, see the (reposted) article "tf.SessionRunHook使用方法".
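For orientation, the hook interface offers five callbacks; a do-nothing skeleton (a sketch of the TF 1.x interface, with `MyHook` as a made-up name) looks like this:

```python
import tensorflow as tf

class MyHook(tf.train.SessionRunHook):
    def begin(self):
        # Called once before the graph is finalized;
        # the last chance to add ops to the graph.
        pass

    def after_create_session(self, session, coord):
        # Called once right after the session is created --
        # the natural place to initialize an iterator.
        pass

    def before_run(self, run_context):
        # Called before every session.run(); may return
        # SessionRunArgs to add extra fetches/feeds.
        return None

    def after_run(self, run_context, run_values):
        # Called after every session.run() with the results
        # of the fetches requested in before_run().
        pass

    def end(self, session):
        # Called once when the session is about to close.
        pass
```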
A closer look at the Estimator's `train` and `evaluate` definitions shows that both accept a `hooks` argument, documented as: "List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the training loop." In other words, we can define our own `SessionRunHook` and simply pass it in through `hooks`.
```python
train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)
```
Since we want to initialize the dataset's placeholder before training starts, we implement the `after_create_session` member function of `SessionRunHook`:
```python
class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_fn = None

    def after_create_session(self, session, coord):
        del coord
        # Runs once the session exists: fills the placeholder
        # by running the iterator's initializer.
        self.iterator_initializer_fn(session)


def make_input_fn():
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        x = tf.placeholder(tf.float32, shape=[None, 2])
        dataset = tf.data.Dataset.from_tensor_slices(x)
        dataset = dataset.shuffle(100000).repeat().batch(batch_size)
        iter = dataset.make_initializable_iterator()
        data = np.random.sample((100, 2))
        # Capture the initialization in a closure; the hook will
        # call it with the session after the session is created.
        iterator_initializer_hook.iterator_initializer_fn = (
            lambda sess: sess.run(iter.initializer, feed_dict={x: data})
        )
        return iter.get_next()

    return input_fn, iterator_initializer_hook

...
input_fn, iterator_initializer_hook = make_input_fn()
estimator.train(input_fn, hooks=[iterator_initializer_hook])
```
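The same pattern carries over to `estimator.evaluate` and `estimator.predict`, which accept a `hooks` argument in the same way; just build a separate `input_fn` and initializer hook for each phase.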