當咱們訓練一個deep learning模型時,怎麼樣判斷當前是過擬合,仍是欠擬合等狀態呢?實踐中,咱們經常會將數據集分爲三部分:train、validation、test。訓練過程當中,咱們讓模型盡力擬合train數據集,在validation數據集上測試擬合程度。當訓練過程結束後,咱們在test集上測試模型最終效果。有經驗的煉丹師每每會經過模型在train和validation上的表現,來判斷當前是不是過擬合,是不是欠擬合。這個時候,TensorBoard就派上了大用場!python
有沒有覺的一目瞭然呢?我強烈推薦你們使用TensorBoard,使用後煉丹功力顯著提高!api
下面,我來說一下如何使用TensorBoard。要使用,也要優雅!
若是你喜歡本身梳理知識,本身嘗試,那麼不妨閱讀官方文檔:戳這裏查看官方文檔
否則的話,就隨着老夫玩轉TensorBoard吧 ^0^瀏覽器
熟悉一個新知識的時候,應該將沒必要要的東西最精簡化,將注意力集中到咱們最關注的地方,因此,我寫了一個最簡單的模型,在這個模型的基礎上對TensorBoard進行探索。網絡
首先看一下這個極簡的線性模型:session
import tensorflow as tf import random class Model(object): def __init__(self): self.input_x = tf.placeholder(dtype=tf.float32, shape=[None, ], name='x') self.input_y = tf.placeholder(dtype=tf.float32, shape=[None, ], name='y') W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), dtype=tf.float32) b = tf.Variable(tf.random_uniform([1], -1.0, 1.0), dtype=tf.float32) y_predict = self.input_x * W + b self.loss = tf.reduce_sum(tf.abs(y_predict - self.input_y))
相信這個模型你們很快就能看懂,因此就很少說了。接下來看構造數據的代碼:app
x_all = [] y_all = [] random.seed(10) for i in range(3000): x = random.random() y = 0.3 * x + 0.1 + random.random() x_all.append(x) y_all.append(y) x_all = np.array(x_all) y_all = np.array(y_all) shuffle_indices = np.random.permutation(np.arange(len(x_all))) x_shuffled = x_all[shuffle_indices] y_shuffled = y_all[shuffle_indices] bound = int(len(x_all) / 10 * 7) x_train = x_shuffled[:bound] y_train = y_shuffled[:bound] x_val = x_shuffled[bound:] y_val = y_shuffled[bound:]
這段代碼裏作了三件事:dom
下面是對數據按batch取出:ide
def batch_iter(data, batch_size, num_epochs, shuffle=True): """ Generates a batch iterator for a dataset. """ data = np.array(data) data_size = len(data) num_batches_per_epoch = int((len(data)-1)/batch_size) + 1 for epoch in range(num_epochs): # Shuffle the data at each epoch if shuffle: shuffle_indices = np.random.permutation(np.arange(data_size)) shuffled_data = data[shuffle_indices] else: shuffled_data = data for batch_num in range(num_batches_per_epoch): start_index = batch_num * batch_size end_index = min((batch_num + 1) * batch_size, data_size) yield shuffled_data[start_index:end_index]
而後就到了比較本篇博客的核心部分:
首先我來描述一下關鍵的函數(大部分同窗心裏必定是拒絕的 2333,因此建議先看下面的代碼,而後再反過頭來看函數的介紹):函數
tf.summary.scalar(name, tensor, collections=None, family=None),調用這個函數來觀察Tensorflow的Graph中某個節點測試
tf.summary.merge(inputs, collections=None, name=None)
tf.summary.FileWriter,在給定的目錄中建立一個事件文件(event file),將summraies保存到該文件夾中。
__init__(logdir, graph=None, max_queue=10, flush_secs=120, graph_def=None, filename_suffix=None)
add_summary(summary, global_step=None)
with tf.Graph().as_default(): sess = tf.Session() with sess.as_default(): m = model.Model() global_step = tf.Variable(0, name='global_step', trainable=False) optimizer = tf.train.AdamOptimizer(1e-2) grads_and_vars = optimizer.compute_gradients(m.loss) train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars, global_step=global_step) loss_summary = tf.summary.scalar('loss', m.loss) train_summary_op = tf.summary.merge([loss_summary]) train_summary_writer = tf.summary.FileWriter('./summary/train', sess.graph) dev_summary_op = tf.summary.merge([loss_summary]) dev_summary_writer = tf.summary.FileWriter('./summary/dev', sess.graph) def train_step(x_batch, y_batch): feed_dict = {m.input_x: x_batch, m.input_y: y_batch} _, step, summaries, loss = sess.run( [train_op, global_step, train_summary_op, m.loss], feed_dict) train_summary_writer.add_summary(summaries, step) def dev_step(x_batch, y_batch): feed_dict = {m.input_x: x_batch, m.input_y: y_batch} step, summaries, loss = sess.run( [global_step, dev_summary_op, m.loss], feed_dict) dev_summary_writer.add_summary(summaries, step) sess.run(tf.global_variables_initializer()) batches = batch_iter(list(zip(x_train, y_train)), 100, 100) for batch in batches: x_batch, y_batch = zip(*batch) train_step(x_batch, y_batch) current_step = tf.train.global_step(sess, global_step) if current_step % 3 == 0: print('\nEvaluation:') dev_step(x_val, y_val)
如今咱們就可使用TensorBoard查看訓練過程了~~
在terminal中輸入以下命令:
tensorboard --logdir=summary
TensorBoard 0.4.0rc3 at http://liudaoxing-Lenovo-Rescuer-15ISK:6006 (Press CTRL+C to quit)
沒錯!這就是咱們train和validation過程當中loss的狀況。
點擊GRAPHS,就能夠看到網絡的結構
麻雀雖小,五臟俱全。但願你們有收穫~