TensorFlow Learning (5): Captcha Recognition with Multi-Task Learning


1. Generating Captchas

"""
驗證碼生成腳本(使用captcha包提供的ImageCaptcha方法)
"""

from captcha.image import ImageCaptcha

import os
import sys
import random
import numpy as np

"""
使用四位數字驗證碼,固然也能夠加入大小寫字母。四位驗證碼有10000種可能(0000~9999)
可是因爲生成過程具備隨機性,不免出現重複狀況,因此最終生成的驗證碼數量少於10000
"""
number = np.arange(0, 10)
number = [str(x) for x in number]

def random_captcha_text(char_set=number, captcha_size=4):
    # List holding the captcha characters
    captcha_text = []
    for i in range(captcha_size):
        c = random.choice(char_set)     # randomly pick one character
        captcha_text.append(c)          # append it to the list
    return captcha_text

def gen_captcha_text_and_image():
    image = ImageCaptcha()
    # Get a random captcha string
    captcha_text = random_captcha_text()
    # Join the character list into a string
    captcha_text = ''.join(captcha_text)
    # Render the captcha and save it (write() generates and writes in one step)
    image.write(captcha_text, 'captcha/images/' + captcha_text + '.jpg')


num = 10000
# Make sure the output directory exists before writing
if not os.path.exists('captcha/images'):
    os.makedirs('captcha/images')
for i in range(num):
    gen_captcha_text_and_image()
    sys.stdout.write('\r>> Creating image %d/%d' % (i + 1, num))
    sys.stdout.flush()
sys.stdout.write('\n')
sys.stdout.flush()
print('Done generating')
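Because duplicate codes simply overwrite each other, the number of files on disk ends up below num. A quick way to check the actual count (a small sketch, assuming the same output directory as above):

import os

print('unique captchas:', len(os.listdir('captcha/images')))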


The generated captchas are stored under "./captcha/images/", as shown:
[Image: listing of the captcha/images directory]
A sample captcha image:

[Image: a captcha image reading 0695]
Each image's label is its captcha digits; the captcha in this example reads 0695, so the file is named 0695.jpg.

2. Creating the TFRecord Files

1) About TFRecord files

TFRecords lets you convert arbitrary data into a format supported by TensorFlow, which makes it easier to match your dataset to the network architecture. A TFRecords file contains tf.train.Example protocol buffers, which in turn hold a Features field: you write a bit of code to read your data, fill it into an Example protocol buffer, serialize the buffer to a string, and write it to the TFRecords file through the tf.python_io.TFRecordWriter class.

The TFRecords format works well for image recognition because the binary image data and the label (the training class) can be stored in the same file. Images are usually converted to TFRecords in a preprocessing step before training; the format's biggest advantage is precisely that each input image and its associated label live in the same file. A TFRecords file is a binary file that does not compress the data (by default), so it can be loaded into memory quickly. The format does not support random access, which makes it a good fit for streaming large amounts of data but a poor fit for fast sharding or other non-sequential access.

For how to read and write TFRecord files, see: https://zhuanlan.zhihu.com/p/31992460
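Before the full script, here is a minimal sketch of that round trip (the file name demo.tfrecords is just for illustration): write one Example holding a single int64 feature, then read the serialized record back and parse it.

import tensorflow as tf

# Write one Example holding a single int64 feature
with tf.python_io.TFRecordWriter('demo.tfrecords') as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[7]))
    }))
    writer.write(example.SerializeToString())

# Iterate over the serialized records and parse each protocol buffer
for record in tf.python_io.tf_record_iterator('demo.tfrecords'):
    parsed = tf.train.Example.FromString(record)
    print(parsed.features.feature['label'].int64_list.value[0])  # prints 7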

2) Code

from PIL import Image
import tensorflow as tf
import numpy as np
import os
import random
import sys

_NUM_TEST = 500
_RANDOM_SEED = 0
DATASET_DIR = 'captcha/images'
TFRECORD_DIR = 'captcha/'


# Check whether the tfrecord files already exist
def _dataset_exists(dataset_dir):
    for split_name in ['train', 'test']:
        output_filename = os.path.join(dataset_dir, split_name + '.tfrecords')
        if not tf.gfile.Exists(output_filename):
            return False
    return True


def _get_filenames_and_classes(dataset_dir):
    photo_filenames = []
    for filename in os.listdir(dataset_dir):
        # Build the full path of the image file
        path = dataset_dir + '/' + filename
        photo_filenames.append(path)
    return photo_filenames


def bytes_feature(values):  # convert to a bytes-list Feature
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def int64_feature(values):  # convert to an int64-list Feature
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def image_to_tfexample(image_data, label0, label1, label2, label3):
    # Build the tf.train.Example protocol message for one sample
    return tf.train.Example(features=tf.train.Features(feature={
        'image': bytes_feature(image_data),
        'label0': int64_feature(label0),
        'label1': int64_feature(label1),
        'label2': int64_feature(label2),
        'label3': int64_feature(label3)
    }))


# Convert one split of the dataset into tfrecord format
def _convert_dataset(split_name, filenames, dataset_dir):
    assert split_name in ['train', 'test']

    # Path and name of the output tfrecord file
    output_filename = os.path.join(TFRECORD_DIR, split_name + '.tfrecords')
    # TFRecordOptions(1) selects ZLIB compression; the reader must use the same option
    with tf.python_io.TFRecordWriter(output_filename, options=tf.python_io.TFRecordOptions(1)) as tfrecord_writer:
        for i, filename in enumerate(filenames):
            try:
                sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                sys.stdout.flush()
                # Read the image
                image_data = Image.open(filename)
                # Resize to the input size expected by the network
                image_data = image_data.resize((224, 224))
                # Convert to grayscale
                image_data = np.array(image_data.convert('L'))
                # Convert the pixels to raw bytes
                image_data = image_data.tobytes()
                # Derive the four digit labels from the file name
                labels = filename.split('/')[-1][0:4]
                num_labels = []
                for j in range(4):
                    num_labels.append(int(labels[j]))
                # Build the Example protocol buffer
                example = image_to_tfexample(image_data, num_labels[0], num_labels[1],
                                             num_labels[2], num_labels[3])
                tfrecord_writer.write(example.SerializeToString())
            except IOError as e:
                print('Could not read:', filenames[i])
                print('Error:', e)
                print('Skip it\n')
    sys.stdout.write('\n')
    sys.stdout.flush()


# Only generate the tfrecord files if they do not exist yet
if _dataset_exists(TFRECORD_DIR):
    print('tfrecord files already exist')
else:
    # Collect all image paths
    photo_filenames = _get_filenames_and_classes(DATASET_DIR)
    # Shuffle, then split into training and test sets
    random.seed(_RANDOM_SEED)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[_NUM_TEST:]
    testing_filenames = photo_filenames[:_NUM_TEST]

    # Convert the data
    _convert_dataset('train', training_filenames, DATASET_DIR)
    _convert_dataset('test', testing_filenames, DATASET_DIR)
    print('tfrecord files generated')


Notes: DATASET_DIR is the dataset path, TFRECORD_DIR is where the tfrecord files are written, and _NUM_TEST is the size of the test set; the program splits all images into two parts, taking _NUM_TEST of them as the test set. In _convert_dataset() the images are preprocessed: grayscale conversion, resizing, and conversion to raw bytes. This makes it convenient both to write the data to file and to consume it during training.
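After conversion it is worth sanity-checking the two files by counting their records. A small sketch (the options value must match the ZLIB setting used by the writer):

import tensorflow as tf

# 1 = ZLIB, the same compression the writer used
opts = tf.python_io.TFRecordOptions(1)
for split in ['train', 'test']:
    path = 'captcha/' + split + '.tfrecords'
    count = sum(1 for _ in tf.python_io.tf_record_iterator(path, options=opts))
    # 'test' should hold _NUM_TEST records; 'train' holds the rest
    print(split, count)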

The generated files:
[Image: train.tfrecords and test.tfrecords in the captcha/ directory]

3. Training the Captcha Recognition Model

1) Recognition approach

Split the captcha label into four separate labels.

For example, for the captcha 0782 the split labels look like this (one-hot encoded, with the position of each digit set to 1):

Label0: 1000000000
Label1: 0000000100
Label2: 0000000010
Label3: 0010000000

The benefit: this lets us use multi-task learning, as sketched below.
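A tiny NumPy sketch of this label split (using the 0782 example from above):

import numpy as np

label = '0782'
one_hot = np.zeros((4, 10), dtype=np.float32)
for i, ch in enumerate(label):
    one_hot[i, int(ch)] = 1.0   # set the position of each digit to 1
print(one_hot)                  # row i is the one-hot vector for Label i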

2) What is multi-task learning?

[Image: multi-task network diagram - input X, shared layers, four task heads]
Here X is the input and the shared layers are a stack of convolution and pooling operations. Task1 through Task4 correspond to the four labels and produce four losses; the four losses are combined into a total loss, and the optimizer minimizes that total, driving down the loss of every label at once.
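As a conceptual sketch (the layer sizes and names here are illustrative stand-ins, not the project code): four heads share one trunk, each head contributes a cross-entropy loss, and a single optimizer minimizes their sum.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 224, 224, 1])
# Stand-in for the shared convolution/pooling trunk
shared = tf.layers.flatten(tf.layers.conv2d(x, 8, 3, strides=4))
labels = [tf.placeholder(tf.float32, [None, 10]) for _ in range(4)]
# One logits head per task (Task1..Task4)
logits = [tf.layers.dense(shared, 10) for _ in range(4)]
losses = [tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=lg, labels=lb))
          for lg, lb in zip(logits, labels)]
total_loss = tf.add_n(losses)   # sum of the four task losses
train_op = tf.train.AdamOptimizer(1e-3).minimize(total_loss)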

3) Getting the alexnet_v2 network provided by Google

Open GitHub and search for tensorflow/models:
[Image: GitHub search results for tensorflow/models]
Clone the models repository:
[Image: cloning the repository]
Once the clone finishes, find the nets folder under "/models/research/slim/" and copy it into the project directory; during training we call the code it provides (nets_factory.py).
[Image: the nets folder inside the project directory]

4) Modifying alexnet.py

The modified code is as follows:

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains a model definition for AlexNet.

This work was first described in:
  ImageNet Classification with Deep Convolutional Neural Networks
  Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton

and later refined in:
  One weird trick for parallelizing convolutional neural networks
  Alex Krizhevsky, 2014

Here we provide the implementation proposed in "One weird trick" and not
"ImageNet Classification", as per the paper, the LRN layers have been removed.

Usage:
  with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
    outputs, end_points = alexnet.alexnet_v2(inputs)

@@alexnet_v2
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim

slim = contrib_slim

# pylint: disable=g-long-lambda
trunc_normal = lambda stddev: tf.compat.v1.truncated_normal_initializer(
    0.0, stddev)


def alexnet_v2_arg_scope(weight_decay=0.0005):
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      activation_fn=tf.nn.relu,
                      biases_initializer=tf.compat.v1.constant_initializer(0.1),
                      weights_regularizer=slim.l2_regularizer(weight_decay)):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
      with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
        return arg_sc


def alexnet_v2(inputs,
               num_classes=1000,
               is_training=True,
               dropout_keep_prob=0.5,
               spatial_squeeze=True,
               scope='alexnet_v2',
               global_pool=False):
  """AlexNet version 2.

  Described in: http://arxiv.org/pdf/1404.5997v2.pdf
  Parameters from:
  github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
  layers-imagenet-1gpu.cfg

  Note: All the fully_connected layers have been transformed to conv2d layers.
        To use in classification mode, resize input to 224x224 or set
        global_pool=True. To use in fully convolutional mode, set
        spatial_squeeze to false.
        The LRN layers have been removed and change the initializers from
        random_normal_initializer to xavier_initializer.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: the number of predicted classes. If 0 or None, the logits layer
    is omitted and the input features to the logits layer are returned instead.
    is_training: whether or not the model is being trained.
    dropout_keep_prob: the probability that activations are kept in the dropout
      layers during training.
    spatial_squeeze: whether or not should squeeze the spatial dimensions of the
      logits. Useful to remove unnecessary dimensions for classification.
    scope: Optional scope for the variables.
    global_pool: Optional boolean flag. If True, the input to the classification
      layer is avgpooled to size 1x1, for any input size. (This is not part
      of the original AlexNet.)

  Returns:
    net0, net1, net2, net3: the outputs of the four logits heads (if
      num_classes is a non-zero integer), one head per captcha digit.
    end_points: a dict of tensors with intermediate activations.
  """
  with tf.compat.v1.variable_scope(scope, 'alexnet_v2', [inputs]) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    # Collect outputs for conv2d, fully_connected and max_pool2d.
    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                        outputs_collections=[end_points_collection]):
      net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
                        scope='conv1')
      net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
      net = slim.conv2d(net, 192, [5, 5], scope='conv2')
      net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
      net = slim.conv2d(net, 384, [3, 3], scope='conv3')
      net = slim.conv2d(net, 384, [3, 3], scope='conv4')
      net = slim.conv2d(net, 256, [3, 3], scope='conv5')
      net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')

      # Use conv2d instead of fully_connected layers.
      with slim.arg_scope(
          [slim.conv2d],
          weights_initializer=trunc_normal(0.005),
          biases_initializer=tf.compat.v1.constant_initializer(0.1)):
        net = slim.conv2d(net, 4096, [5, 5], padding='VALID',
                          scope='fc6')
        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='dropout6')
        net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
        # Convert end_points_collection into a end_point dict.
        end_points = slim.utils.convert_collection_to_dict(
            end_points_collection)
        if global_pool:
          net = tf.reduce_mean(
              input_tensor=net, axis=[1, 2], keepdims=True, name='global_pool')
          end_points['global_pool'] = net
        if num_classes:
          net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                             scope='dropout7')
          net0 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_0')
          net1 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_1')
          net2 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_2')
          net3 = slim.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              biases_initializer=tf.compat.v1.zeros_initializer(),
              scope='fc8_3')

          if spatial_squeeze:
            net0 = tf.squeeze(net0, [1, 2], name='fc8_0/squeezed')
          end_points[sc.name + '/fc8_0'] = net0
          if spatial_squeeze:
            net1 = tf.squeeze(net1, [1, 2], name='fc8_1/squeezed')
          end_points[sc.name + '/fc8_1'] = net1
          if spatial_squeeze:
            net2 = tf.squeeze(net2, [1, 2], name='fc8_2/squeezed')
          end_points[sc.name + '/fc8_2'] = net2
          if spatial_squeeze:
            net3 = tf.squeeze(net3, [1, 2], name='fc8_3/squeezed')
          end_points[sc.name + '/fc8_3'] = net3
      return net0, net1, net2, net3, end_points
alexnet_v2.default_image_size = 224


Notes: the convolutional and pooling layers are unchanged. The original network has a single net output; because this project splits each captcha into four labels, four outputs are needed, so the original fc8 head is expanded into the four heads net0 through net3.
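A quick way to confirm the modification (a sketch that assumes the edited file is saved as nets/alexnet.py): build the graph with a dummy grayscale batch and check that every head yields one 10-way output per image.

import tensorflow as tf
from nets import alexnet

slim = tf.contrib.slim
inputs = tf.placeholder(tf.float32, [25, 224, 224, 1])
with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
    net0, net1, net2, net3, end_points = alexnet.alexnet_v2(inputs, num_classes=10)
print(net0.shape, net3.shape)   # expected: (25, 10) for every head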

5) Training code

"""驗證碼識別
學習模式:多任務學習
網絡模型:alexnet_v2
完成時間:2020-5-1
"""

import tensorflow as tf
from nets import nets_factory


CHAR_SET_LEN = 10  # number of distinct characters
IMAGE_HEIGHT = 60  # height of the raw captcha image
IMAGE_WIDTH = 160  # width of the raw captcha image
BATCH_SIZE = 25
TFRECORD_FILE = 'D:/PycharmProject/StudyDemo/captcha/train.tfrecords'

# placeholder
x = tf.placeholder(tf.float32, [None, 224, 224])
y0 = tf.placeholder(tf.float32, [None])
y1 = tf.placeholder(tf.float32, [None])
y2 = tf.placeholder(tf.float32, [None])
y3 = tf.placeholder(tf.float32, [None])

_learn_rate = tf.Variable(0.003, dtype=tf.float32)


# Read and decode samples from the tfrecord file
def read_and_decode(filename):
    # Create a filename queue
    filename_queue = tf.train.string_input_producer([filename])
    # The reader options must match the writer's options (1 = ZLIB)
    reader = tf.TFRecordReader(options=tf.python_io.TFRecordOptions(1))
    # Returns the file name and the serialized example
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'image': tf.FixedLenFeature([], tf.string),
        'label0': tf.FixedLenFeature([], tf.int64),
        'label1': tf.FixedLenFeature([], tf.int64),
        'label2': tf.FixedLenFeature([], tf.int64),
        'label3': tf.FixedLenFeature([], tf.int64),
    })
    # Decode the raw image bytes
    image = tf.decode_raw(features['image'], tf.uint8)
    # tf.train.shuffle_batch requires a fully defined shape
    image = tf.reshape(image, [224, 224])
    # Preprocess: scale pixel values from [0, 255] to [-1, 1]
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    # Decode the labels
    label0 = tf.cast(features['label0'], tf.int32)
    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    return image, label0, label1, label2, label3


# Get image data and labels from the tfrecord file
image, label0, label1, label2, label3 = read_and_decode(TFRECORD_FILE)
# Use shuffle_batch to create randomly shuffled batches
image_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
    [image, label0, label1, label2, label3], batch_size=BATCH_SIZE,
    capacity=50000, min_after_dequeue=10000, num_threads=1
)

# Define the network structure
train_network_fn = nets_factory.get_network_fn('alexnet_v2',
                                               num_classes=CHAR_SET_LEN,
                                               weight_decay=0.0005,
                                               is_training=True)
with tf.Session() as sess:
    # Reshape the input into the [batch, height, width, channels] format alexnet_v2 expects
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # Feed the batch through the network to get the four head outputs
    logits0, logits1, logits2, logits3, _ = train_network_fn(X)

    # Convert the labels to one-hot form
    one_hot_labels0 = tf.one_hot(indices=tf.cast(y0, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels1 = tf.one_hot(indices=tf.cast(y1, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels2 = tf.one_hot(indices=tf.cast(y2, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels3 = tf.one_hot(indices=tf.cast(y3, tf.int32), depth=CHAR_SET_LEN)

    # Compute the four losses
    loss0 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits0, labels=one_hot_labels0))
    loss1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits1, labels=one_hot_labels1))
    loss2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits2, labels=one_hot_labels2))
    loss3 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits3, labels=one_hot_labels3))
    # Average the four losses into a total loss
    total_loss = (loss0 + loss1 + loss2 + loss3) / 4.0
    # Optimizer: minimize the combined loss
    optimizer = tf.train.AdamOptimizer(learning_rate=_learn_rate).minimize(total_loss)
    # Compute the per-digit accuracy
    correct_prediction0 = tf.equal(tf.argmax(one_hot_labels0, 1), tf.argmax(logits0, 1))
    accuracy0 = tf.reduce_mean(tf.cast(correct_prediction0, tf.float32))
    correct_prediction1 = tf.equal(tf.argmax(one_hot_labels1, 1), tf.argmax(logits1, 1))
    accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1, tf.float32))
    correct_prediction2 = tf.equal(tf.argmax(one_hot_labels2, 1), tf.argmax(logits2, 1))
    accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2, tf.float32))
    correct_prediction3 = tf.equal(tf.argmax(one_hot_labels3, 1), tf.argmax(logits3, 1))
    accuracy3 = tf.reduce_mean(tf.cast(correct_prediction3, tf.float32))

    # Saver for writing model checkpoints
    saver = tf.train.Saver()
    # Initialize variables
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # Create a coordinator to manage the threads
    coord = tf.train.Coordinator()
    # Start the QueueRunners; the filename queue starts filling now
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(6001):
        # Fetch one batch of data and labels
        b_image, b_label0, b_label1, b_label2, b_label3 = sess.run([image_batch,
                                                                    label_batch0,
                                                                    label_batch1,
                                                                    label_batch2,
                                                                    label_batch3])
        # Run one optimization step
        sess.run(optimizer, feed_dict={
            x: b_image,
            y0: b_label0,
            y1: b_label1,
            y2: b_label2,
            y3: b_label3
        })
        # Every 50 iterations, compute and print the loss and accuracies
        if i % 50 == 0:
            # Every 2000 iterations, lower the learning rate
            if i % 2000 == 0:
                sess.run(tf.assign(_learn_rate, _learn_rate / 3))
            acc0, acc1, acc2, acc3, loss_ = sess.run([accuracy0, accuracy1, accuracy2, accuracy3, total_loss],
                                                     feed_dict={
                                                         x: b_image,
                                                         y0: b_label0,
                                                         y1: b_label1,
                                                         y2: b_label2,
                                                         y3: b_label3
                                                     })
            learning_rate = sess.run(_learn_rate)
            print('Iter: %d  loss: %.3f  accuracy: %.2f,%.2f,%.2f,%.2f  learning_rate: %.4f'
                  % (i, loss_, acc0, acc1, acc2, acc3, learning_rate))
            # Stop training and save the model
            if i == 6000:   # global_step appends the iteration count to the checkpoint name
                saver.save(sess, './captcha/models/crack_captcha.model', global_step=i)
                break
    coord.request_stop()    # ask the other threads to stop
    coord.join(threads)     # returns only after the other threads have shut down


Code overview: read data and labels from train.tfrecords, shuffle them, feed the batches into the alexnet network, convert the labels to one-hot form, compute the four losses, average them into total_loss, minimize it with the optimizer, track the per-digit accuracy, train for 6000 iterations, and save the model.
Note: the data format used when writing and when reading the tfrecords files must correspond exactly, and the options of TFRecordWriter and TFRecordReader must be identical; otherwise read/write errors are easy to hit, so check this carefully.
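Concretely, the two ends of the pipeline pair up like this (a sketch; the file name is illustrative):

import tensorflow as tf

options = tf.python_io.TFRecordOptions(1)    # 1 = ZLIB compression
writer = tf.python_io.TFRecordWriter('some.tfrecords', options=options)
reader = tf.TFRecordReader(options=options)  # the reader must be built with the same options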

The saved model files:
[Image: checkpoint files under captcha/models/]
Tip: training is slow. On an NVIDIA 940MX with its 2 GB of video memory fully used, training took about 13 hours in total, with the final accuracy reaching 99%.

4. Testing the Model

The code is similar to the training code:

import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from nets import nets_factory

# Number of distinct characters
CHAR_SET_LEN = 10
# Image height and width
IMAGE_HEIGHT = 60
IMAGE_WIDTH = 160
# Batch size
BATCH_SIZE = 1
# Path of the tfrecord file
TFRECORD_FILE = 'captcha/test.tfrecords'

# placeholder
x = tf.placeholder(tf.float32, [None, 224, 224])


# Read and decode samples from the tfrecord file
def read_and_decode(filename):
    # Create a filename queue
    filename_queue = tf.train.string_input_producer([filename])
    # The reader options must match the writer's options (1 = ZLIB)
    reader = tf.TFRecordReader(options=tf.python_io.TFRecordOptions(1))
    # Returns the file name and the serialized example
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'image': tf.FixedLenFeature([], tf.string),
        'label0': tf.FixedLenFeature([], tf.int64),
        'label1': tf.FixedLenFeature([], tf.int64),
        'label2': tf.FixedLenFeature([], tf.int64),
        'label3': tf.FixedLenFeature([], tf.int64),
    })
    # Decode the raw image bytes
    image = tf.decode_raw(features['image'], tf.uint8)
    # Keep an unprocessed grayscale copy for display
    image_raw = tf.reshape(image, [224, 224])
    # tf.train.shuffle_batch requires a fully defined shape
    image = tf.reshape(image, [224, 224])
    # Preprocess: scale pixel values from [0, 255] to [-1, 1]
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    # Decode the labels
    label0 = tf.cast(features['label0'], tf.int32)
    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    return image, image_raw, label0, label1, label2, label3


# Get image data and labels from the tfrecord file
image, image_raw, label0, label1, label2, label3 = read_and_decode(TFRECORD_FILE)
# Create shuffled batches
image_batch, image_raw_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
    [image, image_raw, label0, label1, label2, label3], batch_size=BATCH_SIZE,
    capacity=50000, min_after_dequeue=10000, num_threads=1
)

# Define the network structure
train_network_fn = nets_factory.get_network_fn('alexnet_v2',
                                               num_classes=CHAR_SET_LEN,
                                               weight_decay=0.0005,
                                               is_training=False)
with tf.Session() as sess:
    # inputs must have shape [batch_size, height, width, channels]
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # Feed the data through the network to get the four head outputs
    logits0, logits1, logits2, logits3, _ = train_network_fn(X)
    # Predictions: argmax over the 10 classes of each head
    predict0 = tf.reshape(logits0, [-1, CHAR_SET_LEN])
    predict0 = tf.argmax(predict0, 1)

    predict1 = tf.reshape(logits1, [-1, CHAR_SET_LEN])
    predict1 = tf.argmax(predict1, 1)

    predict2 = tf.reshape(logits2, [-1, CHAR_SET_LEN])
    predict2 = tf.argmax(predict2, 1)

    predict3 = tf.reshape(logits3, [-1, CHAR_SET_LEN])
    predict3 = tf.argmax(predict3, 1)

    # Initialize variables
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # Load the trained model
    saver = tf.train.Saver()
    saver.restore(sess, './captcha/models/crack_captcha.model-6000')
    # Create a coordinator to manage the threads
    coord = tf.train.Coordinator()
    # Start the QueueRunners; the filename queue starts filling now
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(10):
        # Fetch one batch of data and labels
        b_image, b_image_raw, b_label0, b_label1, b_label2, b_label3 = sess.run([image_batch,
                                                                                 image_raw_batch,
                                                                                 label_batch0,
                                                                                 label_batch1,
                                                                                 label_batch2,
                                                                                 label_batch3])
        # Display the image
        img = Image.fromarray(b_image_raw[0], 'L')
        plt.imshow(img, cmap='gray')
        plt.axis('off')
        plt.show()
        # Print the ground-truth labels
        print('label:', b_label0, b_label1, b_label2, b_label3)
        # Run the prediction
        label0, label1, label2, label3 = sess.run([predict0, predict1, predict2, predict3],
                                                  feed_dict={x: b_image})
        # Print the predicted values
        print('predict:', label0, label1, label2, label3)

    # Ask the other threads to stop
    coord.request_stop()
    coord.join(threads)


Run results:
[Image: two test captchas shown with their true labels and the model's predictions]
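For a quantitative number instead of eyeballing ten samples, the display loop above can be replaced with a counting loop. This sketch reuses sess, the batch tensors, and the predict ops defined in the script; with BATCH_SIZE = 1, 500 draws from the shuffled test queue give an estimate of the full-captcha accuracy:

correct = 0
draws = 500   # number of random draws from the test queue
for i in range(draws):
    b_image, b_l0, b_l1, b_l2, b_l3 = sess.run(
        [image_batch, label_batch0, label_batch1, label_batch2, label_batch3])
    p0, p1, p2, p3 = sess.run([predict0, predict1, predict2, predict3],
                              feed_dict={x: b_image})
    # A captcha counts as correct only if all four digits match
    if (p0[0], p1[0], p2[0], p3[0]) == (b_l0[0], b_l1[0], b_l2[0], b_l3[0]):
        correct += 1
print('full-captcha accuracy: %.2f%%' % (100.0 * correct / draws))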

END
