[譯] TensorFlow 教程 #04 - 保存 & 恢復

本篇主要介紹如何保存和恢復神經網絡變量以及Early-Stopping優化策略。
其中有大段以前教程的文字及代碼,若是看過的朋友能夠快速翻到下文Saver相關的部分。python

01 - 簡單線性模型 | 02 - 卷積神經網絡 | 03 - PrettyTensorgit

by Magnus Erik Hvass Pedersen / GitHub / Videos on YouTube
中文翻譯 thrillerist / Githubgithub

若有轉載,請附上本文連接。數組


介紹

這篇教程展現瞭如何保存以及恢復神經網絡中的變量。在優化的過程當中,當驗證集上分類準確率提升時,保存神經網絡的變量。若是通過1000次迭代還不能提高性能時,就終止優化。而後咱們從新載入在驗證集上表現最好的變量。網絡

這種策略稱爲Early-Stopping。它用來避免神經網絡的過擬合。(過擬合)會在神經網絡訓練時間太長時出現,此時神經網絡開始學習訓練集中的噪聲,將致使它誤分類新的圖像。session

這篇教程主要是用神經網絡來識別MNIST數據集中的手寫數字,過擬合在這裏並非什麼大問題。但本教程展現了Early Stopping的思想。ide

本文基於上一篇教程,你須要瞭解基本的TensorFlow和附加包Pretty Tensor。其中大量代碼和文字與以前教程類似,若是你已經看過就能夠快速地瀏覽本文。函數

流程圖

下面的圖表直接顯示了以後實現的卷積神經網絡中數據的傳遞。網絡有兩個卷積層和兩個全鏈接層,最後一層是用來給輸入圖像分類的。關於網絡和卷積的更多細節描述見教程 #02 。oop

from IPython.display import Image
Image('images/02_network_flowchart.png')複製代碼

導入

%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix
import time
from datetime import timedelta
import math
import os

# Use PrettyTensor to simplify Neural Network construction.
import prettytensor as pt複製代碼

使用Python3.5.2(Anaconda)開發,TensorFlow版本是:post

tf.__version__複製代碼

'0.12.0-rc0'

PrettyTensor 版本:

pt.__version__複製代碼

'0.7.1'

載入數據

MNIST數據集大約12MB,若是沒在給定路徑中找到就會自動下載。

from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets('data/MNIST/', one_hot=True)複製代碼

Extracting data/MNIST/train-images-idx3-ubyte.gz
Extracting data/MNIST/train-labels-idx1-ubyte.gz
Extracting data/MNIST/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/t10k-labels-idx1-ubyte.gz

如今已經載入了MNIST數據集,它由70,000張圖像和對應的標籤(好比圖像的類別)組成。數據集分紅三份互相獨立的子集。咱們在教程中只用訓練集和測試集。

print("Size of:")
print("- Training-set:\t\t{}".format(len(data.train.labels)))
print("- Test-set:\t\t{}".format(len(data.test.labels)))
print("- Validation-set:\t{}".format(len(data.validation.labels)))複製代碼

Size of:
-Training-set: 55000
-Test-set: 10000
-Validation-set: 5000

類型標籤使用One-Hot編碼,這意外每一個標籤是長爲10的向量,除了一個元素以外,其餘的都爲零。這個元素的索引就是類別的數字,即相應圖片中畫的數字。咱們也須要測試數據集類別數字的整型值,用下面的方法來計算。

data.test.cls = np.argmax(data.test.labels, axis=1)
data.validation.cls = np.argmax(data.validation.labels, axis=1)複製代碼

數據維度

在下面的源碼中,有不少地方用到了數據維度。它們只在一個地方定義,所以咱們能夠在代碼中使用這些數字而不是直接寫數字。

# We know that MNIST images are 28 pixels in each dimension.
img_size = 28

# Images are stored in one-dimensional arrays of this length.
img_size_flat = img_size * img_size

# Tuple with height and width of images used to reshape arrays.
img_shape = (img_size, img_size)

# Number of colour channels for the images: 1 channel for gray-scale.
num_channels = 1

# Number of classes, one class for each of 10 digits.
num_classes = 10複製代碼

用來繪製圖片的幫助函數

這個函數用來在3x3的柵格中畫9張圖像,而後在每張圖像下面寫出真實類別和預測類別。

def plot_images(images, cls_true, cls_pred=None):
    assert len(images) == len(cls_true) == 9

    # Create figure with 3x3 sub-plots.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for i, ax in enumerate(axes.flat):
        # Plot image.
        ax.imshow(images[i].reshape(img_shape), cmap='binary')

        # Show true and predicted classes.
        if cls_pred is None:
            xlabel = "True: {0}".format(cls_true[i])
        else:
            xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])

        # Show the classes as the label on the x-axis.
        ax.set_xlabel(xlabel)

        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()複製代碼

繪製幾張圖像來看看數據是否正確

# Get the first images from the test-set.
images = data.test.images[0:9]

# Get the true classes for those images.
cls_true = data.test.cls[0:9]

# Plot the images and labels using our helper-function above.
plot_images(images=images, cls_true=cls_true)複製代碼

TensorFlow圖

TensorFlow的所有目的就是使用一個稱之爲計算圖(computational graph)的東西,它會比直接在Python中進行相同計算量要高效得多。TensorFlow比Numpy更高效,由於TensorFlow瞭解整個須要運行的計算圖,然而Numpy只知道某個時間點上惟一的數學運算。

TensorFlow也可以自動地計算須要優化的變量的梯度,使得模型有更好的表現。這是因爲圖是簡單數學表達式的結合,所以整個圖的梯度能夠用鏈式法則推導出來。

TensorFlow還能利用多核CPU和GPU,Google也爲TensorFlow製造了稱爲TPUs(Tensor Processing Units)的特殊芯片,它比GPU更快。

一個TensorFlow圖由下面幾個部分組成,後面會詳細描述:

  • 佔位符變量(Placeholder)用來改變圖的輸入。
  • 模型變量(Model)將會被優化,使得模型表現得更好。
  • 模型本質上就是一些數學函數,它根據Placeholder和模型的輸入變量來計算一些輸出。
  • 一個cost度量用來指導變量的優化。
  • 一個優化策略會更新模型的變量。

另外,TensorFlow圖也包含了一些調試狀態,好比用TensorBoard打印log數據,本教程不涉及這些。

佔位符 (Placeholder)變量

Placeholder是做爲圖的輸入,咱們每次運行圖的時候均可能改變它們。將這個過程稱爲feeding placeholder變量,後面將會描述這個。

首先咱們爲輸入圖像定義placeholder變量。這讓咱們能夠改變輸入到TensorFlow圖中的圖像。這也是一個張量(tensor),表明一個多維向量或矩陣。數據類型設置爲float32,形狀設爲[None, img_size_flat]None表明tensor可能保存着任意數量的圖像,每張圖象是一個長度爲img_size_flat的向量。

x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')複製代碼

卷積層但願x被編碼爲4維張量,所以咱們須要將它的形狀轉換至[num_images, img_height, img_width, num_channels]。注意img_height == img_width == img_size,若是第一維的大小設爲-1, num_images的大小也會被自動推導出來。轉換運算以下:

x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])複製代碼

接下來咱們爲輸入變量x中的圖像所對應的真實標籤訂義placeholder變量。變量的形狀是[None, num_classes],這表明着它保存了任意數量的標籤,每一個標籤是長度爲num_classes的向量,本例中長度爲10。

y_true = tf.placeholder(tf.float32, shape=[None, 10], name='y_true')複製代碼

咱們也能夠爲class-number提供一個placeholder,但這裏用argmax來計算它。這裏只是TensorFlow中的一些操做,沒有執行什麼運算。

y_true_cls = tf.argmax(y_true, dimension=1)複製代碼

神經網絡

這一節用PrettyTensor實現卷積神經網絡,這要比直接在TensorFlow中實現來得簡單,詳見教程 #03。

基本思想就是用一個Pretty Tensor object封裝輸入張量x_image,它有一個添加新卷積層的幫助函數,以此來建立整個神經網絡。Pretty Tensor負責變量分配等等。

x_pretty = pt.wrap(x_image)複製代碼

如今咱們已經將輸入圖像裝到一個PrettyTensor的object中,再用幾行代碼就能夠添加捲積層和全鏈接層。

注意,在with代碼塊中,pt.defaults_scope(activation_fn=tf.nn.relu)activation_fn=tf.nn.relu看成每一個的層參數,所以這些層都用到了 Rectified Linear Units (ReLU) 。defaults_scope使咱們能更方便地修改全部層的參數。

with pt.defaults_scope(activation_fn=tf.nn.relu):
    y_pred, loss = x_pretty.\
        conv2d(kernel=5, depth=16, name='layer_conv1').\
        max_pool(kernel=2, stride=2).\
        conv2d(kernel=5, depth=36, name='layer_conv2').\
        max_pool(kernel=2, stride=2).\
        flatten().\
        fully_connected(size=128, name='layer_fc1').\
        softmax_classifier(num_classes=num_classes, labels=y_true)複製代碼

獲取權重

下面,咱們要繪製神經網絡的權重。當使用Pretty Tensor來建立網絡時,層的全部變量都是由Pretty Tensoe間接建立的。所以咱們要從TensorFlow中獲取變量。

咱們用layer_conv1layer_conv2表明兩個卷積層。這也叫變量做用域(不要與上面描述的defaults_scope混淆了)。PrettyTensor會自動給它爲每一個層建立的變量命名,所以咱們能夠經過層的做用域名稱和變量名來取得某一層的權重。

函數實現有點笨拙,由於咱們不得不用TensorFlow函數get_variable(),它是設計給其餘用途的,建立新的變量或重用現有變量。建立下面的幫助函數很簡單。

def get_weights_variable(layer_name):
    # Retrieve an existing variable named 'weights' in the scope
    # with the given layer_name.
    # This is awkward because the TensorFlow function was
    # really intended for another purpose.

    with tf.variable_scope(layer_name, reuse=True):
        variable = tf.get_variable('weights')

    return variable複製代碼

藉助這個幫助函數咱們能夠獲取變量。這些是TensorFlow的objects。你須要相似的操做來獲取變量的內容: contents = session.run(weights_conv1) ,下面會提到這個。

weights_conv1 = get_weights_variable(layer_name='layer_conv1')
weights_conv2 = get_weights_variable(layer_name='layer_conv2')複製代碼

優化方法

PrettyTensor給咱們提供了預測類型標籤(y_pred)以及一個須要最小化的損失度量,用來提高神經網絡分類圖片的能力。

PrettyTensor的文檔並無說明它的損失度量是用cross-entropy仍是其餘的。但如今咱們用AdamOptimizer來最小化損失。

優化過程並非在這裏執行。實際上,還沒計算任何東西,咱們只是往TensorFlow圖中添加了優化器,以便後續操做。

optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)複製代碼

性能度量

咱們須要另一些性能度量,來向用戶展現這個過程。

首先咱們從神經網絡輸出的y_pred中計算出預測的類別,它是一個包含10個元素的向量。類別數字是最大元素的索引。

y_pred_cls = tf.argmax(y_pred, dimension=1)複製代碼

而後建立一個布爾向量,用來告訴咱們每張圖片的真實類別是否與預測類別相同。

correct_prediction = tf.equal(y_pred_cls, y_true_cls)複製代碼

上面的計算先將布爾值向量類型轉換成浮點型向量,這樣子False就變成0,True變成1,而後計算這些值的平均數,以此來計算分類的準確度。

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))複製代碼

Saver

爲了保存神經網絡的變量,咱們建立一個稱爲Saver-object的對象,它用來保存及恢復TensorFlow圖的全部變量。在這裏並未保存什麼東西,(保存操做)在後面的optimize()函數中完成。

saver = tf.train.Saver()複製代碼

因爲(保存操做)常間隔着寫在(代碼)中,所以保存的文件一般稱爲checkpoints。

這是用來保存或恢復數據的文件夾。

save_dir = 'checkpoints/'複製代碼

若是文件夾不存在則建立。

if not os.path.exists(save_dir):
    os.makedirs(save_dir)複製代碼

這是保存checkpoint文件的路徑。

save_path = os.path.join(save_dir, 'best_validation')複製代碼

運行TensorFlow

建立TensorFlow會話(session)

一旦建立了TensorFlow圖,咱們須要建立一個TensorFlow會話,用來運行圖。

session = tf.Session()複製代碼

初始化變量

變量weightsbiases在優化以前須要先進行初始化。咱們寫一個簡單的封裝函數,後面會再次調用。

def init_variables():
    session.run(tf.global_variables_initializer())複製代碼

運行函數來初始化變量。

init_variables()複製代碼

用來優化迭代的幫助函數

在訓練集中有50,000張圖。用這些圖像計算模型的梯度會花不少時間。所以咱們利用隨機梯度降低的方法,它在優化器的每次迭代裏只用到了一小部分的圖像。

若是內存耗盡致使電腦死機或變得很慢,你應該試着減小這些數量,但同時可能還須要更優化的迭代。

train_batch_size = 64複製代碼

每迭代100次下面的優化函數,會計算一次驗證集上的分類準確率。若是過了1000次迭代驗證準確率仍是沒有提高,就中止優化。咱們須要一些變量來跟蹤這個過程。

# Best validation accuracy seen so far.
best_validation_accuracy = 0.0

# Iteration-number for last improvement to validation accuracy.
last_improvement = 0

# Stop optimization if no improvement found in this many iterations.
require_improvement = 1000複製代碼

函數用來執行必定數量的優化迭代,以此來逐漸改善網絡層的變量。在每次迭代中,會從訓練集中選擇新的一批數據,而後TensorFlow在這些訓練樣本上執行優化。每100次迭代會打印出(信息),同時計算驗證準確率,若是效果有提高的話會將它保存至文件。

# Counter for total number of iterations performed so far.
total_iterations = 0

def optimize(num_iterations):
    # Ensure we update the global variables rather than local copies.
    global total_iterations
    global best_validation_accuracy
    global last_improvement

    # Start-time used for printing time-usage below.
    start_time = time.time()

    for i in range(num_iterations):

        # Increase the total number of iterations performed.
        # It is easier to update it in each iteration because
        # we need this number several times in the following.
        total_iterations += 1

        # Get a batch of training examples.
        # x_batch now holds a batch of images and
        # y_true_batch are the true labels for those images.
        x_batch, y_true_batch = data.train.next_batch(train_batch_size)

        # Put the batch into a dict with the proper names
        # for placeholder variables in the TensorFlow graph.
        feed_dict_train = {x: x_batch,
                           y_true: y_true_batch}

        # Run the optimizer using this batch of training data.
        # TensorFlow assigns the variables in feed_dict_train
        # to the placeholder variables and then runs the optimizer.
        session.run(optimizer, feed_dict=feed_dict_train)

        # Print status every 100 iterations and after last iteration.
        if (total_iterations % 100 == 0) or (i == (num_iterations - 1)):

            # Calculate the accuracy on the training-batch.
            acc_train = session.run(accuracy, feed_dict=feed_dict_train)

            # Calculate the accuracy on the validation-set.
            # The function returns 2 values but we only need the first.
            acc_validation, _ = validation_accuracy()

            # If validation accuracy is an improvement over best-known.
            if acc_validation > best_validation_accuracy:
                # Update the best-known validation accuracy.
                best_validation_accuracy = acc_validation

                # Set the iteration for the last improvement to current.
                last_improvement = total_iterations

                # Save all variables of the TensorFlow graph to file.
                saver.save(sess=session, save_path=save_path)

                # A string to be printed below, shows improvement found.
                improved_str = '*'
            else:
                # An empty string to be printed below.
                # Shows that no improvement was found.
                improved_str = ''

            # Status-message for printing.
            msg = "Iter: {0:>6}, Train-Batch Accuracy: {1:>6.1%}, Validation Acc: {2:>6.1%} {3}"

            # Print it.
            print(msg.format(i + 1, acc_train, acc_validation, improved_str))

        # If no improvement found in the required number of iterations.
        if total_iterations - last_improvement > require_improvement:
            print("No improvement found in a while, stopping optimization.")

            # Break out from the for-loop.
            break

    # Ending time.
    end_time = time.time()

    # Difference between start and end-times.
    time_dif = end_time - start_time

    # Print the time-usage.
    print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))複製代碼

用來繪製錯誤樣本的幫助函數

函數用來繪製測試集中被誤分類的樣本。

def plot_example_errors(cls_pred, correct):
    # This function is called from print_test_accuracy() below.

    # cls_pred is an array of the predicted class-number for
    # all images in the test-set.

    # correct is a boolean array whether the predicted class
    # is equal to the true class for each image in the test-set.

    # Negate the boolean array.
    incorrect = (correct == False)

    # Get the images from the test-set that have been
    # incorrectly classified.
    images = data.test.images[incorrect]

    # Get the predicted classes for those images.
    cls_pred = cls_pred[incorrect]

    # Get the true classes for those images.
    cls_true = data.test.cls[incorrect]

    # Plot the first 9 images.
    plot_images(images=images[0:9],
                cls_true=cls_true[0:9],
                cls_pred=cls_pred[0:9])複製代碼

繪製混淆(confusion)矩陣的幫助函數

def plot_confusion_matrix(cls_pred):
    # This is called from print_test_accuracy() below.

    # cls_pred is an array of the predicted class-number for
    # all images in the test-set.

    # Get the true classifications for the test-set.
    cls_true = data.test.cls

    # Get the confusion matrix using sklearn.
    cm = confusion_matrix(y_true=cls_true,
                          y_pred=cls_pred)

    # Print the confusion matrix as text.
    print(cm)

    # Plot the confusion matrix as an image.
    plt.matshow(cm)

    # Make various adjustments to the plot.
    plt.colorbar()
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, range(num_classes))
    plt.yticks(tick_marks, range(num_classes))
    plt.xlabel('Predicted')
    plt.ylabel('True')

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()複製代碼

計算分類的幫助函數

這個函數用來計算圖像的預測類別,同時返回一個表明每張圖像分類是否正確的布爾數組。

因爲計算可能會耗費太多內存,就分批處理。若是你的電腦死機了,試着下降batch-size。

# Split the data-set in batches of this size to limit RAM usage.
batch_size = 256

def predict_cls(images, labels, cls_true):
    # Number of images.
    num_images = len(images)

    # Allocate an array for the predicted classes which
    # will be calculated in batches and filled into this array.
    cls_pred = np.zeros(shape=num_images, dtype=np.int)

    # Now calculate the predicted classes for the batches.
    # We will just iterate through all the batches.
    # There might be a more clever and Pythonic way of doing this.

    # The starting index for the next batch is denoted i.
    i = 0

    while i < num_images:
        # The ending index for the next batch is denoted j.
        j = min(i + batch_size, num_images)

        # Create a feed-dict with the images and labels
        # between index i and j.
        feed_dict = {x: images[i:j, :],
                     y_true: labels[i:j, :]}

        # Calculate the predicted class using TensorFlow.
        cls_pred[i:j] = session.run(y_pred_cls, feed_dict=feed_dict)

        # Set the start-index for the next batch to the
        # end-index of the current batch.
        i = j

    # Create a boolean array whether each image is correctly classified.
    correct = (cls_true == cls_pred)

    return correct, cls_pred複製代碼

計算測試集上的預測類別。

def predict_cls_test():
    return predict_cls(images = data.test.images,
                       labels = data.test.labels,
                       cls_true = data.test.cls)複製代碼

計算驗證集上的預測類別。

def predict_cls_validation():
    return predict_cls(images = data.validation.images,
                       labels = data.validation.labels,
                       cls_true = data.validation.cls)複製代碼

分類準確率的幫助函數

這個函數計算了給定布爾數組的分類準確率,布爾數組表示每張圖像是否被正確分類。好比, cls_accuracy([True, True, False, False, False]) = 2/5 = 0.4

def cls_accuracy(correct):
    # Calculate the number of correctly classified images.
    # When summing a boolean array, False means 0 and True means 1.
    correct_sum = correct.sum()

    # Classification accuracy is the number of correctly classified
    # images divided by the total number of images in the test-set.
    acc = float(correct_sum) / len(correct)

    return acc, correct_sum複製代碼

計算驗證集上的分類準確率。

def validation_accuracy():
    # Get the array of booleans whether the classifications are correct
    # for the validation-set.
    # The function returns two values but we only need the first.
    correct, _ = predict_cls_validation()

    # Calculate the classification accuracy and return it.
    return cls_accuracy(correct)複製代碼

展現性能的幫助函數

函數用來打印測試集上的分類準確率。

爲測試集上的全部圖片計算分類會花費一段時間,所以咱們直接從這個函數裏調用上面的函數,這樣就不用每一個函數都從新計算分類。

def print_test_accuracy(show_example_errors=False, show_confusion_matrix=False):

    # For all the images in the test-set,
    # calculate the predicted classes and whether they are correct.
    correct, cls_pred = predict_cls_test()

    # Classification accuracy and the number of correct classifications.
    acc, num_correct = cls_accuracy(correct)

    # Number of images being classified.
    num_images = len(correct)

    # Print the accuracy.
    msg = "Accuracy on Test-Set: {0:.1%} ({1} / {2})"
    print(msg.format(acc, num_correct, num_images))

    # Plot some examples of mis-classifications, if desired.
    if show_example_errors:
        print("Example errors:")
        plot_example_errors(cls_pred=cls_pred, correct=correct)

    # Plot the confusion matrix, if desired.
    if show_confusion_matrix:
        print("Confusion Matrix:")
        plot_confusion_matrix(cls_pred=cls_pred)複製代碼

繪製卷積權重的幫助函數

def plot_conv_weights(weights, input_channel=0):
    # Assume weights are TensorFlow ops for 4-dim variables
    # e.g. weights_conv1 or weights_conv2.

    # Retrieve the values of the weight-variables from TensorFlow.
    # A feed-dict is not necessary because nothing is calculated.
    w = session.run(weights)

    # Print mean and standard deviation.
    print("Mean: {0:.5f}, Stdev: {1:.5f}".format(w.mean(), w.std()))

    # Get the lowest and highest values for the weights.
    # This is used to correct the colour intensity across
    # the images so they can be compared with each other.
    w_min = np.min(w)
    w_max = np.max(w)

    # Number of filters used in the conv. layer.
    num_filters = w.shape[3]

    # Number of grids to plot.
    # Rounded-up, square-root of the number of filters.
    num_grids = math.ceil(math.sqrt(num_filters))

    # Create figure with a grid of sub-plots.
    fig, axes = plt.subplots(num_grids, num_grids)

    # Plot all the filter-weights.
    for i, ax in enumerate(axes.flat):
        # Only plot the valid filter-weights.
        if i<num_filters:
            # Get the weights for the i'th filter of the input channel.
            # The format of this 4-dim tensor is determined by the
            # TensorFlow API. See Tutorial #02 for more details.
            img = w[:, :, input_channel, i]

            # Plot image.
            ax.imshow(img, vmin=w_min, vmax=w_max,
                      interpolation='nearest', cmap='seismic')

        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()複製代碼

優化以前的性能

測試集上的準確度很低,這是因爲模型只作了初始化,並沒作任何優化,因此它只是對圖像作隨機分類。

print_test_accuracy()複製代碼

Accuracy on Test-Set: 8.5% (849 / 10000)

卷積權重是隨機的,但也很難把它與下面優化過的權重區分開來。這裏也展現了平均值和標準差,所以咱們能夠看看是否有差異。

plot_conv_weights(weights=weights_conv1)複製代碼

Mean: 0.00880, Stdev: 0.28635

10,000次優化迭代後的性能

如今咱們進行了10,000次優化迭代,而且,當通過1000次迭代驗證集上的性能卻沒有提高時就中止優化。

星號 * 表明驗證集上的分類準確度有提高。

optimize(num_iterations=10000)複製代碼

Iter: 100, Train-Batch Accuracy: 84.4%, Validation Acc: 85.2%
Iter: 200, Train-Batch Accuracy: 92.2%, Validation Acc: 91.5%

Iter: 300, Train-Batch Accuracy: 95.3%, Validation Acc: 93.7%
Iter: 400, Train-Batch Accuracy: 92.2%, Validation Acc: 94.3%

Iter: 500, Train-Batch Accuracy: 98.4%, Validation Acc: 94.7%
Iter: 600, Train-Batch Accuracy: 93.8%, Validation Acc: 94.7%
Iter: 700, Train-Batch Accuracy: 98.4%, Validation Acc: 95.6%

Iter: 800, Train-Batch Accuracy: 100.0%, Validation Acc: 96.3%
Iter: 900, Train-Batch Accuracy: 98.4%, Validation Acc: 96.4%

Iter: 1000, Train-Batch Accuracy: 100.0%, Validation Acc: 96.9%
Iter: 1100, Train-Batch Accuracy: 96.9%, Validation Acc: 97.0%

Iter: 1200, Train-Batch Accuracy: 93.8%, Validation Acc: 97.0%
Iter: 1300, Train-Batch Accuracy: 92.2%, Validation Acc: 97.2%

Iter: 1400, Train-Batch Accuracy: 100.0%, Validation Acc: 97.3%
Iter: 1500, Train-Batch Accuracy: 96.9%, Validation Acc: 97.4%

Iter: 1600, Train-Batch Accuracy: 100.0%, Validation Acc: 97.7%
Iter: 1700, Train-Batch Accuracy: 100.0%, Validation Acc: 97.8%

Iter: 1800, Train-Batch Accuracy: 98.4%, Validation Acc: 97.7%
Iter: 1900, Train-Batch Accuracy: 98.4%, Validation Acc: 98.1%
Iter: 2000, Train-Batch Accuracy: 95.3%, Validation Acc: 98.0%
Iter: 2100, Train-Batch Accuracy: 98.4%, Validation Acc: 97.9%
Iter: 2200, Train-Batch Accuracy: 100.0%, Validation Acc: 98.0%
Iter: 2300, Train-Batch Accuracy: 96.9%, Validation Acc: 98.1%
Iter: 2400, Train-Batch Accuracy: 93.8%, Validation Acc: 98.1%
Iter: 2500, Train-Batch Accuracy: 98.4%, Validation Acc: 98.2%

Iter: 2600, Train-Batch Accuracy: 98.4%, Validation Acc: 98.0%
Iter: 2700, Train-Batch Accuracy: 98.4%, Validation Acc: 98.0%
Iter: 2800, Train-Batch Accuracy: 96.9%, Validation Acc: 98.1%
Iter: 2900, Train-Batch Accuracy: 96.9%, Validation Acc: 98.2%
Iter: 3000, Train-Batch Accuracy: 98.4%, Validation Acc: 98.2%
Iter: 3100, Train-Batch Accuracy: 100.0%, Validation Acc: 98.1%
Iter: 3200, Train-Batch Accuracy: 100.0%, Validation Acc: 98.3%
Iter: 3300, Train-Batch Accuracy: 98.4%, Validation Acc: 98.4%

Iter: 3400, Train-Batch Accuracy: 95.3%, Validation Acc: 98.0%
Iter: 3500, Train-Batch Accuracy: 98.4%, Validation Acc: 98.3%
Iter: 3600, Train-Batch Accuracy: 100.0%, Validation Acc: 98.5%
Iter: 3700, Train-Batch Accuracy: 98.4%, Validation Acc: 98.3%
Iter: 3800, Train-Batch Accuracy: 96.9%, Validation Acc: 98.1%
Iter: 3900, Train-Batch Accuracy: 96.9%, Validation Acc: 98.5%
Iter: 4000, Train-Batch Accuracy: 100.0%, Validation Acc: 98.4%
Iter: 4100, Train-Batch Accuracy: 100.0%, Validation Acc: 98.5%
Iter: 4200, Train-Batch Accuracy: 100.0%, Validation Acc: 98.3%
Iter: 4300, Train-Batch Accuracy: 100.0%, Validation Acc: 98.6%

Iter: 4400, Train-Batch Accuracy: 96.9%, Validation Acc: 98.4%
Iter: 4500, Train-Batch Accuracy: 98.4%, Validation Acc: 98.5%
Iter: 4600, Train-Batch Accuracy: 98.4%, Validation Acc: 98.5%
Iter: 4700, Train-Batch Accuracy: 98.4%, Validation Acc: 98.4%
Iter: 4800, Train-Batch Accuracy: 100.0%, Validation Acc: 98.8% *
Iter: 4900, Train-Batch Accuracy: 100.0%, Validation Acc: 98.8%
Iter: 5000, Train-Batch Accuracy: 98.4%, Validation Acc: 98.6%
Iter: 5100, Train-Batch Accuracy: 98.4%, Validation Acc: 98.6%
Iter: 5200, Train-Batch Accuracy: 100.0%, Validation Acc: 98.6%
Iter: 5300, Train-Batch Accuracy: 96.9%, Validation Acc: 98.5%
Iter: 5400, Train-Batch Accuracy: 98.4%, Validation Acc: 98.7%
Iter: 5500, Train-Batch Accuracy: 98.4%, Validation Acc: 98.6%
Iter: 5600, Train-Batch Accuracy: 100.0%, Validation Acc: 98.4%
Iter: 5700, Train-Batch Accuracy: 100.0%, Validation Acc: 98.6%
Iter: 5800, Train-Batch Accuracy: 100.0%, Validation Acc: 98.7%
No improvement found in a while, stopping optimization.
Time usage: 0:00:28

print_test_accuracy(show_example_errors=True,
                    show_confusion_matrix=True)複製代碼

Accuracy on Test-Set: 98.4% (9842 / 10000)
Example errors:

Confusion Matrix:
[[ 974 0 0 0 0 1 2 0 2 1]
[ 0 1127 2 2 0 0 1 0 3 0]
[ 4 4 1012 4 1 0 0 3 4 0]
[ 0 0 1 1005 0 2 0 0 2 0]
[ 1 0 1 0 961 0 2 0 3 14]
[ 2 0 1 6 0 880 1 0 1 1]
[ 4 2 0 1 3 4 942 0 2 0]
[ 1 1 8 6 1 0 0 994 1 16]
[ 6 0 1 4 1 1 1 2 952 6]
[ 3 3 0 3 2 2 0 0 1 995]]

如今卷積權重是通過優化的。將這些與上面的隨機權重進行對比。它們看起來基本相同。實際上,一開始我覺得程序有bug,由於優化先後的權重看起來差很少。

但保存圖像,並排着比較它們(你能夠右鍵保存)。你會發現二者有細微的不一樣。

平均值和標準差也有一點變化,所以優化過的權重確定是不同的。

plot_conv_weights(weights=weights_conv1)複製代碼

Mean: 0.02895, Stdev: 0.29949

再次初始化變量

再一次用隨機值來初始化全部神經網絡變量。

init_variables()複製代碼

這意味着神經網絡又是徹底隨機地對圖片進行分類,因爲只是隨機的猜想因此分類準確率很低。

print_test_accuracy()複製代碼

Accuracy on Test-Set: 13.4% (1341 / 10000)

卷積權重看起來應該與上面的不一樣。

plot_conv_weights(weights=weights_conv1)複製代碼

Mean: -0.01086, Stdev: 0.28023

恢復最好的變量

從新載入在優化過程當中保存到文件的全部變量。

saver.restore(sess=session, save_path=save_path)複製代碼

使用以前保存的那些變量,分類準確率又提升了。

注意,準確率與以前相比可能會有細微的上升或降低,這是因爲文件裏的變量是用來最大化驗證集上的分類準確率,但在保存文件以後,又進行了1000次的優化迭代,所以這是兩組有輕微不一樣的變量的結果。有時這會致使測試集上更好或更差的表現。

print_test_accuracy(show_example_errors=True,
                    show_confusion_matrix=True)複製代碼

Accuracy on Test-Set: 98.3% (9826 / 10000)
Example errors:

Confusion Matrix:
[[ 973 0 0 0 0 0 2 0 3 2]
[ 0 1124 2 2 0 0 3 0 4 0]
[ 2 1 1027 0 0 0 0 1 1 0]
[ 0 0 1 1005 0 2 0 0 2 0]
[ 0 0 3 0 968 0 1 0 3 7]
[ 2 0 1 9 0 871 3 0 3 3]
[ 4 2 1 0 3 3 939 0 6 0]
[ 1 3 19 11 2 0 0 972 2 18]
[ 6 0 3 5 1 0 1 2 951 5]
[ 3 3 0 1 4 1 0 0 1 996]]

卷積權重也與以前顯示的圖幾乎相同,一樣,因爲多作了1000次優化迭代,兩者並不是徹底同樣。

plot_conv_weights(weights=weights_conv1)複製代碼

Mean: 0.02792, Stdev: 0.29822

關閉TensorFlow會話

如今咱們已經用TensorFlow完成了任務,關閉session,釋放資源。

# This has been commented out in case you want to modify and experiment
# with the Notebook without having to restart it.
# session.close()複製代碼

總結

這篇教程描述了在TensorFlow中如何保存並恢復神經網絡的變量。它有許多用處。好比,當你用神經網絡來識別圖像的時候,只須要訓練網絡一次,而後能夠在其餘電腦上完成開發工做。

checkpoint的另外一個用處是,若是你有一個很是大的神經網絡和數據集,就可能會在中間保存一些checkpoints來避免電腦死機,這樣,你就能夠在最近的checkpoint開始優化而不是重頭開始。

本教程也展現瞭如何用驗證集來進行所謂的Early Stopping,若是沒有下降驗證錯誤優化就會終止。這在神經網絡出現過擬合以及開始學習訓練集中的噪聲時頗有用;不過這在本教程的神經網絡和MNIST數據集中並非什麼大問題。

還有一個有趣的現象,最優化時卷積權重(或者叫濾波)的變化很小,即便網絡的性能從隨機猜想提升到近乎完美的分類。奇怪的是隨機的權重好像已經足夠好了。你認爲爲何會有這種現象?

練習

下面使一些可能會讓你提高TensorFlow技能的一些建議練習。爲了學習如何更合適地使用TensorFlow,實踐經驗是很重要的。

在你對這個Notebook進行修改以前,可能須要先備份一下。

  • 在通過1000次迭代而性能沒有提高時,優化就終止了。這樣夠嗎?你能想出一個更好地進行Early Stopping的方法麼?試着實現它。
  • 若是checkpoint文件已經存在了,載入它而不是作優化。
  • 每100次優化迭代保存一次checkpoint。經過saver.latest_checkpoint()取回最新的(保存點)。爲何保存多個checkpoints而不是隻保存最近的一個?
  • 試着改變神經網絡,好比添加其餘層。當你從不一樣的網絡中從新載入變量會出現什麼問題?
  • plot_conv_weights()函數在優化先後畫出第二個卷積層的權重。它們幾乎相同的麼?
  • 你認爲優化過的卷積權重爲何與隨機初始化的(權重)幾乎相同?
  • 不看源碼,本身重寫程序。
  • 向朋友解釋程序如何工做。
相關文章
相關標籤/搜索