from IPython.display import Image, display
# Display the flowchart that illustrates the transfer-learning pipeline.
Image('images/08_transfer_learning_flowchart.png')
導入
%matplotlib inline
import os
import time
from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np
# We use Pretty Tensor to define the new classifier.
import prettytensor as pt
import tensorflow as tf

# Functions and classes for loading and using the Inception model.
import inception
使用Python3.5.2(Anaconda)開發,TensorFlow版本是:
# Show the TensorFlow version this notebook was developed with.
tf.__version__
'0.12.0-rc0'
PrettyTensor 版本:
# Show the PrettyTensor version.
pt.__version__
'0.7.1'
載入CIFAR-10數據
# Project module used below to download and load the CIFAR-10 data-set.
import cifar10
cifar10模塊中已經定義好了數據維度,所以我們需要時只要導入就行。
# Number of classes in the data-set, as defined by the cifar10 module.
from cifar10 import num_classes
設置電腦上保存數據集的路徑。
# Uncomment to change where the CIFAR-10 data-set is stored on disk.
# cifar10.data_path = "data/CIFAR-10/"
CIFAR-10數據集大概有163MB,如果給定路徑沒有找到文件的話,將會自動下載。
# Download and unpack the CIFAR-10 data-set if it is not already on disk.
cifar10.maybe_download_and_extract()
Data has apparently already been downloaded and unpacked.
def plot_images(images, cls_true, cls_pred=None, smooth=True):
    """Plot up to 9 images in a 3x3 grid labelled with their classes.

    Args:
        images: Sequence of images; only the first 9 are drawn.
        cls_true: True class-number for each image (indexes ``class_names``).
        cls_pred: Optional predicted class-number for each image. When given,
            every label shows both the true and the predicted class name.
        smooth: If True use 'spline16' interpolation, otherwise 'nearest'.

    NOTE(review): relies on a module-level ``class_names`` list that is not
    defined in the visible part of the notebook.
    """
    assert len(images) == len(cls_true)

    # Create figure with sub-plots.
    fig, axes = plt.subplots(3, 3)

    # Two-line labels (true + predicted) need more vertical spacing.
    hspace = 0.3 if cls_pred is None else 0.6
    fig.subplots_adjust(hspace=hspace, wspace=0.3)

    # Interpolation type.
    interpolation = 'spline16' if smooth else 'nearest'

    for i, ax in enumerate(axes.flat):
        # There may be less than 9 images, ensure it doesn't crash.
        if i < len(images):
            # Plot image.
            ax.imshow(images[i], interpolation=interpolation)

            # Name of the true class.
            cls_true_name = class_names[cls_true[i]]

            # Show true and predicted classes.
            if cls_pred is None:
                xlabel = "True: {0}".format(cls_true_name)
            else:
                # Name of the predicted class.
                cls_pred_name = class_names[cls_pred[i]]
                xlabel = "True: {0}\nPred: {1}".format(cls_true_name, cls_pred_name)

            # Show the classes as the label on the x-axis.
            ax.set_xlabel(xlabel)

        # Remove ticks from every sub-plot, including unused ones.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()
繪製幾張圖像看看數據是否正確
# Get the first images from the test-set.
images = images_test[0:9]

# Get the true classes for those images.
cls_true = cls_test[0:9]

# Plot the images and labels using our helper-function above.
plot_images(images=images, cls_true=cls_true, smooth=False)
下載Inception模型
從網上下載Inception模型。這是你保存數據文件的默認文件夾。如果文件夾不存在就自動創建。
# Uncomment to change where the Inception model files are stored on disk.
# inception.data_dir = 'inception/'
如果文件夾中不存在Inception模型,就自動下載。它有85MB。
更多詳情見教程#07。
# Download the Inception model if it is not already in the data directory.
inception.maybe_download()
Downloading Inception v3 Model ... Data has apparently already been downloaded and unpacked.
print("Processing Inception transfer-values for training-images ...")

# Scale images because Inception needs pixels to be between 0 and 255,
# while the CIFAR-10 functions return pixels between 0.0 and 1.0
images_scaled = images_train * 255.0

# If transfer-values have already been calculated then reload them,
# otherwise calculate them and save them to a cache-file.
transfer_values_train = transfer_values_cache(cache_path=file_path_cache_train,
                                              images=images_scaled,
                                              model=model)
Processing Inception transfer-values for training-images ...
- Data loaded from cache-file: data/CIFAR-10/inception_cifar10_train.pkl
print("Processing Inception transfer-values for test-images ...")

# Scale images because Inception needs pixels to be between 0 and 255,
# while the CIFAR-10 functions return pixels between 0.0 and 1.0
images_scaled = images_test * 255.0

# If transfer-values have already been calculated then reload them,
# otherwise calculate them and save them to a cache-file.
transfer_values_test = transfer_values_cache(cache_path=file_path_cache_test,
                                             images=images_scaled,
                                             model=model)
Processing Inception transfer-values for test-images ...
Data loaded from cache-file: data/CIFAR-10/inception_cifar10_test.pkl
def plot_transfer_values(i):
    """Show test-image ``i`` and a heat-map of its Inception transfer-values.

    The transfer-value vector is reshaped to a 32x64 grid for display, so it
    must contain exactly 2048 values.

    Args:
        i: Index into the test-set.
    """
    print("Input image:")

    # Plot the i'th image from the test-set.
    plt.imshow(images_test[i], interpolation='nearest')
    plt.show()

    print("Transfer-values for the image using Inception model:")

    # Transform the transfer-values into an image.
    img = transfer_values_test[i]
    img = img.reshape((32, 64))

    # Plot the image for the transfer-values.
    plt.imshow(img, interpolation='nearest', cmap='Reds')
    plt.show()
# Show transfer-values for test-image 16.
plot_transfer_values(i=16)
Input image:
Transfer-values for the image using Inception model:
# Show transfer-values for test-image 17.
plot_transfer_values(i=17)
Input image:
Transfer-values for the image using Inception model:
def plot_scatter(values, cls):
    """Scatter-plot 2-D points, coloured by class-number.

    Args:
        values: 2-D array; column 0 is plotted on the x-axis and column 1
            on the y-axis.
        cls: Integer class-number for each row of ``values``; selects one
            of ``num_classes`` colours.
    """
    # Create a color-map with a different color for each class.
    # (This import was previously swallowed by the comment above it,
    # which left ``cm`` undefined.)
    import matplotlib.cm as cm
    cmap = cm.rainbow(np.linspace(0.0, 1.0, num_classes))

    # Get the color for each sample.
    colors = cmap[cls]

    # Extract the x- and y-values.
    x = values[:, 0]
    y = values[:, 1]

    # Plot it.
    plt.scatter(x, y, color=colors)
    plt.show()
# Wrap the transfer-values as a Pretty Tensor object.
x_pretty = pt.wrap(x)

# New classifier on top of the transfer-values: a 1024-unit fully-connected
# layer (ReLU via defaults_scope) followed by a softmax classification layer
# that also returns the loss against the y_true labels.
with pt.defaults_scope(activation_fn=tf.nn.relu):
    y_pred, loss = (x_pretty
                    .fully_connected(size=1024, name='layer_fc1')
                    .softmax_classifier(num_classes=num_classes, labels=y_true))
def random_batch():
    """Return a random batch of training transfer-values and their labels.

    Draws ``train_batch_size`` distinct indices (sampling without
    replacement), so ``train_batch_size`` must not exceed the size of the
    training-set.

    Returns:
        (x_batch, y_batch): transfer-values and labels of the sampled rows.
    """
    # Number of images (transfer-values) in the training-set.
    num_images = len(transfer_values_train)

    # Create a random index.
    idx = np.random.choice(num_images, size=train_batch_size, replace=False)

    # Use the random index to select random x and y-values.
    # We use the transfer-values instead of images as x-values.
    x_batch = transfer_values_train[idx]
    y_batch = labels_train[idx]

    return x_batch, y_batch
def optimize(num_iterations):
    """Run ``num_iterations`` optimizer steps on random training batches.

    Prints the training-batch accuracy whenever the global step is a
    multiple of 100 and after the final iteration, then prints the elapsed
    wall-clock time.
    """
    # Start-time used for printing time-usage below.
    start_time = time.time()

    for i in range(num_iterations):
        # Get a batch of training examples.
        # x_batch now holds a batch of images (transfer-values) and
        # y_true_batch are the true labels for those images.
        x_batch, y_true_batch = random_batch()

        # Put the batch into a dict with the proper names
        # for placeholder variables in the TensorFlow graph.
        feed_dict_train = {x: x_batch,
                           y_true: y_true_batch}

        # Run the optimizer using this batch of training data.
        # TensorFlow assigns the variables in feed_dict_train
        # to the placeholder variables and then runs the optimizer.
        # We also want to retrieve the global_step counter.
        i_global, _ = session.run([global_step, optimizer],
                                  feed_dict=feed_dict_train)

        # Print status to screen every 100 iterations (and last).
        if (i_global % 100 == 0) or (i == num_iterations - 1):
            # Calculate the accuracy on the training-batch.
            batch_acc = session.run(accuracy, feed_dict=feed_dict_train)

            # Print status.
            msg = "Global Step: {0:>6}, Training Batch Accuracy: {1:>6.1%}"
            print(msg.format(i_global, batch_acc))

    # Ending time.
    end_time = time.time()

    # Difference between start and end-times.
    time_dif = end_time - start_time

    # Print the time-usage.
    print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))
展現結果的幫助函數
繪製錯誤樣本的幫助函數
這個函數用來繪製測試集中被誤分類的樣本。
def plot_example_errors(cls_pred, correct):
    """Plot up to 9 test-set images that were mis-classified.

    Called from print_test_accuracy().

    Args:
        cls_pred: Array of predicted class-numbers for all test-set images.
        correct: Boolean array, True where the predicted class equals the
            true class for the corresponding test-set image.
    """
    # Negate the boolean array. np.logical_not also accepts a plain list,
    # unlike ``correct == False``, which collapses a list to a single bool.
    incorrect = np.logical_not(correct)

    # Get the images from the test-set that have been
    # incorrectly classified.
    images = images_test[incorrect]

    # Get the predicted classes for those images.
    cls_pred = cls_pred[incorrect]

    # Get the true classes for those images.
    cls_true = cls_test[incorrect]

    n = min(9, len(images))

    # Plot the first n images.
    plot_images(images=images[0:n],
                cls_true=cls_true[0:n],
                cls_pred=cls_pred[0:n])
繪製混淆(confusion)矩陣的幫助函數
# Import a function from sklearn to calculate the confusion-matrix.
from sklearn.metrics import confusion_matrix


def plot_confusion_matrix(cls_pred):
    """Print the test-set confusion matrix as text, one row per true class.

    Called from print_test_accuracy().

    Args:
        cls_pred: Array of predicted class-numbers for all test-set images.
    """
    # Get the confusion matrix using sklearn. Passing ``labels`` guarantees a
    # num_classes x num_classes matrix even if some class never occurs in
    # cls_test or cls_pred, so the row indexing below cannot go out of range.
    cm = confusion_matrix(y_true=cls_test,  # True class for test-set.
                          y_pred=cls_pred,  # Predicted class.
                          labels=list(range(num_classes)))

    # Print the confusion matrix as text.
    for i in range(num_classes):
        # Append the class-name to each line.
        class_name = "({}) {}".format(i, class_names[i])
        print(cm[i, :], class_name)

    # Print the class-numbers for easy reference.
    class_numbers = [" ({0})".format(i) for i in range(num_classes)]
    print("".join(class_numbers))
計算分類的幫助函數
這個函數用來計算圖像的預測類別,同時返回一個代表每張圖像分類是否正確的布爾數組。
由於計算可能會耗費太多內存,所以分批處理。如果你的電腦死機了,試著降低batch-size。
# Split the data-set in batches of this size to limit RAM usage.
batch_size = 256


def predict_cls(transfer_values, labels, cls_true):
    """Predict class-numbers for ``transfer_values`` in batches of ``batch_size``.

    Args:
        transfer_values: Inception transfer-values to classify.
        labels: Labels fed to the ``y_true`` placeholder with each batch.
        cls_true: True class-numbers, compared against the predictions.

    Returns:
        (correct, cls_pred): boolean array that is True where an image was
        classified correctly, and the predicted class-number per image.
    """
    # Number of images.
    num_images = len(transfer_values)

    # Allocate an array for the predicted classes, filled in batch by batch.
    # ``np.int`` was only an alias for the builtin ``int`` and was removed in
    # NumPy 1.24, so use ``int`` directly.
    cls_pred = np.zeros(shape=num_images, dtype=int)

    # Calculate the predicted classes one batch [start, end) at a time.
    for start in range(0, num_images, batch_size):
        end = min(start + batch_size, num_images)

        # Create a feed-dict with the images and labels for this batch.
        feed_dict = {x: transfer_values[start:end],
                     y_true: labels[start:end]}

        # Calculate the predicted class using TensorFlow.
        cls_pred[start:end] = session.run(y_pred_cls, feed_dict=feed_dict)

    # Create a boolean array whether each image is correctly classified.
    correct = (cls_true == cls_pred)

    return correct, cls_pred
def classification_accuracy(correct):
    """Return the classification accuracy and the number of correct predictions.

    Averaging a boolean array treats False as 0 and True as 1, so the mean
    is (number of True) / len(correct), i.e. the classification accuracy.

    Args:
        correct: Boolean array (or sequence of bools), True for each
            correctly classified image.

    Returns:
        (accuracy, num_correct)
    """
    # asarray lets plain Python lists work too; arrays pass through unchanged.
    correct = np.asarray(correct, dtype=bool)
    return correct.mean(), correct.sum()
def print_test_accuracy(show_example_errors=False, show_confusion_matrix=False):
    """Print the test-set classification accuracy, optionally with diagnostics.

    NOTE(review): relies on ``predict_cls_test()``, which is not defined in
    the visible part of the notebook.

    Args:
        show_example_errors: If True, also plot some mis-classified images.
        show_confusion_matrix: If True, also print the confusion matrix.
    """
    # For all the images in the test-set,
    # calculate the predicted classes and whether they are correct.
    correct, cls_pred = predict_cls_test()

    # Classification accuracy and the number of correct classifications.
    acc, num_correct = classification_accuracy(correct)

    # Number of images being classified.
    num_images = len(correct)

    # Print the accuracy.
    msg = "Accuracy on Test-Set: {0:.1%} ({1} / {2})"
    print(msg.format(acc, num_correct, num_images))

    # Plot some examples of mis-classifications, if desired.
    if show_example_errors:
        print("Example errors:")
        plot_example_errors(cls_pred=cls_pred, correct=correct)

    # Plot the confusion matrix, if desired.
    if show_confusion_matrix:
        print("Confusion Matrix:")
        plot_confusion_matrix(cls_pred=cls_pred)
Global Step: 100, Training Batch Accuracy: 82.8% Global Step: 200, Training Batch Accuracy: 90.6% Global Step: 300, Training Batch Accuracy: 90.6% Global Step: 400, Training Batch Accuracy: 95.3% Global Step: 500, Training Batch Accuracy: 85.9% Global Step: 600, Training Batch Accuracy: 84.4% Global Step: 700, Training Batch Accuracy: 90.6% Global Step: 800, Training Batch Accuracy: 93.8% Global Step: 900, Training Batch Accuracy: 92.2% Global Step: 1000, Training Batch Accuracy: 95.3% Global Step: 1100, Training Batch Accuracy: 93.8% Global Step: 1200, Training Batch Accuracy: 90.6% Global Step: 1300, Training Batch Accuracy: 95.3% Global Step: 1400, Training Batch Accuracy: 90.6% Global Step: 1500, Training Batch Accuracy: 90.6% Global Step: 1600, Training Batch Accuracy: 92.2% Global Step: 1700, Training Batch Accuracy: 90.6% Global Step: 1800, Training Batch Accuracy: 92.2% Global Step: 1900, Training Batch Accuracy: 84.4% Global Step: 2000, Training Batch Accuracy: 85.9% Global Step: 2100, Training Batch Accuracy: 87.5% Global Step: 2200, Training Batch Accuracy: 90.6% Global Step: 2300, Training Batch Accuracy: 92.2% Global Step: 2400, Training Batch Accuracy: 95.3% Global Step: 2500, Training Batch Accuracy: 89.1% Global Step: 2600, Training Batch Accuracy: 93.8% Global Step: 2700, Training Batch Accuracy: 87.5% Global Step: 2800, Training Batch Accuracy: 90.6% Global Step: 2900, Training Batch Accuracy: 92.2% Global Step: 3000, Training Batch Accuracy: 96.9% Global Step: 3100, Training Batch Accuracy: 96.9% Global Step: 3200, Training Batch Accuracy: 92.2% Global Step: 3300, Training Batch Accuracy: 95.3% Global Step: 3400, Training Batch Accuracy: 93.8% Global Step: 3500, Training Batch Accuracy: 89.1% Global Step: 3600, Training Batch Accuracy: 89.1% Global Step: 3700, Training Batch Accuracy: 95.3% Global Step: 3800, Training Batch Accuracy: 98.4% Global Step: 3900, Training Batch Accuracy: 89.1% Global Step: 4000, Training Batch Accuracy: 92.2% Global 
Step: 4100, Training Batch Accuracy: 96.9% Global Step: 4200, Training Batch Accuracy: 100.0% Global Step: 4300, Training Batch Accuracy: 100.0% Global Step: 4400, Training Batch Accuracy: 90.6% Global Step: 4500, Training Batch Accuracy: 95.3% Global Step: 4600, Training Batch Accuracy: 96.9% Global Step: 4700, Training Batch Accuracy: 96.9% Global Step: 4800, Training Batch Accuracy: 96.9% Global Step: 4900, Training Batch Accuracy: 92.2% Global Step: 5000, Training Batch Accuracy: 98.4% Global Step: 5100, Training Batch Accuracy: 93.8% Global Step: 5200, Training Batch Accuracy: 92.2% Global Step: 5300, Training Batch Accuracy: 98.4% Global Step: 5400, Training Batch Accuracy: 98.4% Global Step: 5500, Training Batch Accuracy: 100.0% Global Step: 5600, Training Batch Accuracy: 92.2% Global Step: 5700, Training Batch Accuracy: 98.4% Global Step: 5800, Training Batch Accuracy: 92.2% Global Step: 5900, Training Batch Accuracy: 92.2% Global Step: 6000, Training Batch Accuracy: 93.8% Global Step: 6100, Training Batch Accuracy: 95.3% Global Step: 6200, Training Batch Accuracy: 98.4% Global Step: 6300, Training Batch Accuracy: 98.4% Global Step: 6400, Training Batch Accuracy: 96.9% Global Step: 6500, Training Batch Accuracy: 95.3% Global Step: 6600, Training Batch Accuracy: 96.9% Global Step: 6700, Training Batch Accuracy: 96.9% Global Step: 6800, Training Batch Accuracy: 92.2% Global Step: 6900, Training Batch Accuracy: 96.9% Global Step: 7000, Training Batch Accuracy: 100.0% Global Step: 7100, Training Batch Accuracy: 95.3% Global Step: 7200, Training Batch Accuracy: 96.9% Global Step: 7300, Training Batch Accuracy: 96.9% Global Step: 7400, Training Batch Accuracy: 95.3% Global Step: 7500, Training Batch Accuracy: 95.3% Global Step: 7600, Training Batch Accuracy: 93.8% Global Step: 7700, Training Batch Accuracy: 93.8% Global Step: 7800, Training Batch Accuracy: 95.3% Global Step: 7900, Training Batch Accuracy: 95.3% Global Step: 8000, Training Batch Accuracy: 93.8% 
Global Step: 8100, Training Batch Accuracy: 95.3% Global Step: 8200, Training Batch Accuracy: 98.4% Global Step: 8300, Training Batch Accuracy: 93.8% Global Step: 8400, Training Batch Accuracy: 98.4% Global Step: 8500, Training Batch Accuracy: 96.9% Global Step: 8600, Training Batch Accuracy: 96.9% Global Step: 8700, Training Batch Accuracy: 98.4% Global Step: 8800, Training Batch Accuracy: 95.3% Global Step: 8900, Training Batch Accuracy: 98.4% Global Step: 9000, Training Batch Accuracy: 98.4% Global Step: 9100, Training Batch Accuracy: 98.4% Global Step: 9200, Training Batch Accuracy: 96.9% Global Step: 9300, Training Batch Accuracy: 100.0% Global Step: 9400, Training Batch Accuracy: 90.6% Global Step: 9500, Training Batch Accuracy: 92.2% Global Step: 9600, Training Batch Accuracy: 98.4% Global Step: 9700, Training Batch Accuracy: 96.9% Global Step: 9800, Training Batch Accuracy: 98.4% Global Step: 9900, Training Batch Accuracy: 98.4% Global Step: 10000, Training Batch Accuracy: 100.0% Time usage: 0:00:32
# This has been commented out in case you want to modify and experiment
# with the Notebook without having to restart it.
# model.close()
# session.close()