3. Variable-Length Text Image OCR with LSTM + CTC

The previous post implemented CNN multi-label image classification (a 4-character fixed-length captcha recognition task).

(link: juejin.im/post/5c073b…)

This post continues that work and tackles recognition of variable-length text images.

The next post will probably play around with GANs.

Baidu Netdisk download for the 100k variable-length captcha dataset used in this post (you can also generate it yourself with the code below): pan.baidu.com/s/11BzIvuT4…

Model trained with the code in this post (corresponds to the my-model folder in the project): pan.baidu.com/s/1AoKtZVys…

Project overview: install the dependencies first (pip install captcha==0.1.1, pip install opencv-python, pip install flask, and pip install tensorflow or pip install tensorflow-gpu). This post uses LSTM + CTC to OCR captcha images containing 1-10 characters (each generated captcha consists of 1-10 random uppercase letters). It is essentially a multi-label classification problem for a single image, where the number of labels per image is not fixed (sample data shown below).

0_PIY.png
1_BCAVDPXT.png
2_N.png

Overall training logic:
1. Feed the image into the LSTM to obtain a sequence and its length (roughly: the image width is treated as the LSTM time_step, and the image height as the size of the input tensor at each time_step).
2. Convert the ground-truth y_label into a sparse tensor (this SparseTensor is a key point; try printing y_train_tmp at line 153 of the code to see what it looks like, and see the toy sketch below).
3. Use tf.nn.ctc_loss as the loss function and train on the data from the two steps above, driving the loss as low as possible.
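As a minimal illustration of step 2 (the values here are toy data, not output from the project), the (indices, values, shape) triple that sequence2sparseTensor produces for a batch of two labels looks like this:

# Toy batch of two already-encoded labels: 'AB' -> [0, 1] and 'C' -> [2]
labels = [[0, 1], [2]]

indices, values = [], []
for n, seq in enumerate(labels):
    indices.extend(zip([n] * len(seq), range(len(seq))))  # (sample index, position in label)
    values.extend(seq)

# indices -> [(0, 0), (0, 1), (1, 0)]
# values  -> [0, 1, 2]
# shape   -> [2, 2]   i.e. (batch size, length of the longest label)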

You can look up the details of ctc_loss elsewhere; roughly speaking, it collapses a frame-level prediction such as AAABBBBCCDEE from the layer above into a shorter label sequence like ABBCDE. The core question is how many A's a run of AAA collapses into and how many B's BBBB collapses into, which is determined by where the network emits the blank symbol.
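A rough sketch of the collapse rule (illustrative only, with '-' standing in for the blank class; this is not the internals of tf.nn.ctc_loss):

def ctc_collapse(path, blank='-'):
    # merge adjacent repeats, then drop blanks
    out, prev = [], None
    for ch in path:
        if ch != prev and ch != blank:
            out.append(ch)
        prev = ch
    return ''.join(out)

print(ctc_collapse('AAABBBBCCDEE'))    # ABCDE  (all adjacent repeats merge)
print(ctc_collapse('AAAB-BBBCCDEE'))   # ABBCDE (the blank keeps two separate Bs)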

Overall prediction logic:
1. Feed the image into the LSTM to obtain a sequence and its length.
2. Pass the sequence and its length to tf.nn.ctc_beam_search_decoder, which predicts a sparse tensor.
3. Convert the sparse tensor from step 2 back into a sequence and finally decode it into the uppercase letters A-Z for output (a small alternative decoding sketch follows this list).
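As a hedged alternative to the hand-rolled sparseTensor2sequence conversion in the code below, the decoder output can also be densified and then filtered (a sketch only, using the same TF1 session style as the rest of the post):

# decoded[0] is a SparseTensor; pad missing positions with -1 so they can be filtered out
dense_pred = tf.sparse_tensor_to_dense(decoded[0], default_value=-1)

# inside a session, something like:
# for row in sess.run(dense_pred, feed_dict={self.x: im}):
#     text = ''.join(self.words[i] for i in row if i >= 0)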

Ideas for further optimization:
1. Run a CNN over the image to extract features before the LSTM (a rough sketch of this idea follows the list).
2. Replace TF's built-in ctc_loss with Baidu's open-source Warp-CTC.
3. For the occasional case where an image containing AAA is recognized as AA (one A is dropped), one could try expanding each character's single label into three, e.g. ['A', 'A', 'A'] becomes ['A-left', 'A-middle', 'A-right', 'A-left', 'A-middle', 'A-right', 'A-left', 'A-middle', 'A-right']. This is only a starting point; feel free to experiment.
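For idea 1, a minimal sketch of what a CNN front-end could look like (the layer sizes and the // 2 downsampling are illustrative assumptions, not part of this project's code):

# self.x has shape (batch, im_height, im_width); add a channel axis for conv2d
x_img = tf.expand_dims(self.x, axis=-1)                                    # (B, H, W, 1)
conv = tf.layers.conv2d(x_img, filters=64, kernel_size=3,
                        padding='same', activation=tf.nn.relu)
conv = tf.layers.max_pooling2d(conv, pool_size=2, strides=2)               # (B, H/2, W/2, 64)

# treat the reduced width as the time axis and flatten height x channels per step
feat = tf.transpose(conv, [0, 2, 1, 3])                                    # (B, W/2, H/2, 64)
feat = tf.reshape(feat, [self.batch_size, self.im_width // 2,
                         (self.im_height // 2) * 64])

# feat would then replace the raw pixel columns fed to tf.nn.dynamic_rnn,
# with seq_len shortened to im_width // 2 accordingly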

Pros and cons:
1. LSTM + CTC models the left-to-right sequential structure of a line of text, which gives it an edge over a plain CNN and makes variable-length OCR straightforward.
2. Precisely because the RNN models temporal dependencies, it is much more computationally expensive than a CNN and converges more slowly; if you can, use a decent GPU, it improves efficiency a lot.

Commands: Generate the captcha training set yourself (this post generates 100k images; change the self.im_total_num variable):

python LstmCtcOcr.py create_dataset

Train on the dataset: python LstmCtcOcr.py train

Test on a new image: python LstmCtcOcr.py test

Start as an HTTP service: python LstmCtcOcr.py start

The Flask framework wraps the whole project as a web service so it can be called over HTTP. After starting the service, hit the following URLs to test (a small client example follows them):

http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/0_PIY.png

http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/1_BCAVDPXT.png

http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/2_N.png
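For example, calling the service from Python (requires pip install requests; the response shown is the expected format, with the label taken from the test filename):

import requests

resp = requests.get('http://127.0.0.1:5050/captchaOcr',
                    params={'img_path': './dataset/test/0_PIY.png'})
print(resp.json())  # e.g. {"img_path": "./dataset/test/0_PIY.png", "ocr_ret": "PIY"}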

Project directory structure:

項目結構.png

After training for 200 epochs, the model's accuracy on the validation set reaches 84%; you can tune the learning rate and increase the number of epochs to improve it further (True means the prediction is correct; the left side is the prediction, the right side the ground-truth label):

lstm-ctc-199-epoch.png

The full code (LstmCtcOcr.py):

# coding:utf-8

from captcha.image import ImageCaptcha
import numpy as np
import cv2
import tensorflow as tf
import random, os, sys
import operator


from flask import request
from flask import Flask
import json
app = Flask(__name__)

class LstmCtcOcr:
    def __init__(self):
        self.epoch_max = 200  # maximum number of training epochs
        self.batch_size = 16  # images per training batch; reduce this if GPU memory is tight
        self.lr = 5e-5  # initial learning rate
        self.save_epoch = 5  # save the model every this many epochs
        self.n_hidden = 256  # number of LSTM hidden units

        self.im_width = 256
        self.im_height = 64
        self.im_total_num = 100000  # total number of captcha images to generate
        self.train_max_num = self.im_total_num  # maximum number of images read for training
        self.val_num = 30 * self.batch_size  # validation set size; must not exceed self.train_max_num
        self.words_max_num = 10  # maximum number of letters per captcha image
        self.words = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        self.n_classes = len(self.words) + 1  # 26 letters + blank
        self.x = None
        self.y = None

    def captchaOcr(self, img_path):
        """ 驗證碼識別 :param img_path: :return: """
        im = cv2.imread(img_path)
        im = cv2.resize(im, (self.im_width, self.im_height))
        im = np.array([im[:, :, 0]], dtype=np.float32)
        im -= 147
        pred = self.sess.run(self.pred, feed_dict={self.x: im})
        sequence = self.sparseTensor2sequence(pred)
        return ''.join(sequence[0])


    def test(self, img_path):
        """ 測試接口 :param img_path: :return: """
        self.batch_size = 1
        self.learning_rate = tf.placeholder(dtype=tf.float32)  # dynamic learning rate (unused at test time)
        self.weight = tf.Variable(tf.truncated_normal([self.n_hidden, self.n_classes], stddev=0.1))
        self.bias = tf.Variable(tf.constant(0., shape=[self.n_classes]))
        self.x = tf.placeholder(tf.float32, [None, self.im_height, self.im_width])
        logits, seq_len = self.rnnNet(self.x, self.weight, self.bias)
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
        self.pred = tf.cast(decoded[0], tf.int32)

        saver = tf.train.Saver()
        # tfconfig = tf.ConfigProto(allow_soft_placement=True)
        # tfconfig.gpu_options.per_process_gpu_memory_fraction = 0.3  # fraction of GPU memory to use
        # self.sess = tf.Session(config=tfconfig)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())  # initialize global TF variables

        # load the trained weights and biases
        saver.restore(self.sess, './my-model/LstmCtcOcr-200')
        im = cv2.imread(img_path)
        im = cv2.resize(im, (self.im_width, self.im_height))
        im = np.array([im[:, :, 0]], dtype=np.float32)
        im -= 147
        pred = self.sess.run(self.pred, feed_dict={self.x: im})
        sequence = self.sparseTensor2sequence(pred)
        print(''.join(sequence[0]))


    def train(self):
        """ 訓練 :return: """
        x_train_list, y_train_list, x_val_list, y_val_list = self.getTrainDataset()

        print('Converting to tensor queues')
        x_train_list_tensor = tf.convert_to_tensor(x_train_list, dtype=tf.string)
        y_train_list_tensor = tf.convert_to_tensor(y_train_list, dtype=tf.int32)

        x_val_list_tensor = tf.convert_to_tensor(x_val_list, dtype=tf.string)
        y_val_list_tensor = tf.convert_to_tensor(y_val_list, dtype=tf.int32)

        x_train_queue = tf.train.slice_input_producer(tensor_list=[x_train_list_tensor], shuffle=False)
        y_train_queue = tf.train.slice_input_producer(tensor_list=[y_train_list_tensor], shuffle=False)

        x_val_queue = tf.train.slice_input_producer(tensor_list=[x_val_list_tensor], shuffle=False)
        y_val_queue = tf.train.slice_input_producer(tensor_list=[y_val_list_tensor], shuffle=False)

        train_im, train_label = self.dataset_opt(x_train_queue, y_train_queue)
        train_batch = tf.train.batch(tensors=[train_im, train_label], batch_size=self.batch_size, num_threads=2)

        val_im, val_label = self.dataset_opt(x_val_queue, y_val_queue)
        val_batch = tf.train.batch(tensors=[val_im, val_label], batch_size=self.batch_size, num_threads=2)

        print('Preparing to train')
        self.learning_rate = tf.placeholder(dtype=tf.float32)  # dynamic learning rate
        self.weight = tf.Variable(tf.truncated_normal([self.n_hidden, self.n_classes], stddev=0.1))
        self.bias = tf.Variable(tf.constant(0., shape=[self.n_classes]))

        # self.global_step = tf.Variable(0, trainable=False)  # global step counter

        # im_width is treated as the LSTM time_step, im_height as the size of the input tensor at each time_step
        self.x = tf.placeholder(tf.float32, [None, self.im_height, self.im_width])
        # sparse placeholder required by ctc_loss
        self.y = tf.sparse_placeholder(tf.int32)

        logits, seq_len = self.rnnNet(self.x, self.weight, self.bias)

        # loss
        self.loss = tf.nn.ctc_loss(self.y, logits, seq_len)
        # cost
        self.cost = tf.reduce_mean(self.loss)
        # optimizer
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.cost)


        # ctc_beam_search_decoder keeps the top-K most probable paths over the per-step class distributions
        # the greedy alternative, tf.nn.ctc_greedy_decoder, keeps only the single most probable path (K=1)
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
        self.pred = tf.cast(decoded[0], tf.int32)
        self.distance = tf.reduce_mean(tf.edit_distance(self.pred, self.y))

        print('Starting training')
        saver = tf.train.Saver()  # saver for the TF model checkpoints
        with tf.Session() as self.sess:
            self.sess.run(tf.global_variables_initializer())
            coordinator = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=self.sess, coord=coordinator)

            batch_max = len(x_train_list) // self.batch_size
            print('batch:', batch_max)
            total_step = 0
            for epoch_num in range(self.epoch_max):
                lr_tmp = self.lr * (1 - (epoch_num / self.epoch_max) ** 2)  # decaying learning rate
                print('lr:', lr_tmp)
                for batch_num in range(batch_max):
                    # print(epoch_num, batch_num)
                    x_train_tmp, y_train_tmp = self.sess.run(train_batch)
                    y_train_tmp = self.sequence2sparseTensor(y_train_tmp)  # convert labels to a sparse tensor
                    self.sess.run(self.optimizer, feed_dict={self.x: x_train_tmp, self.y: y_train_tmp, self.learning_rate: lr_tmp})

                    if total_step % 100 == 0 or total_step == 0:
                        print('epoch:%d/%d batch:%d/%d total_step:%d lr:%.10f' % (epoch_num, self.epoch_max, batch_num, batch_max, total_step, lr_tmp))
                        # training metrics
                        train_loss, train_distance = self.sess.run([self.cost, self.distance], feed_dict={self.x: x_train_tmp, self.y: y_train_tmp})

                        # validation metrics
                        val_loss_list, val_distance_list, val_acc_list = [], [], []
                        for i in range(int(self.val_num / self.batch_size)):
                            x_val_tmp, y_val_tmp_true = self.sess.run(val_batch)
                            y_val_tmp = self.sequence2sparseTensor(y_val_tmp_true)  # convert labels to a sparse tensor
                            val_loss, val_distance, val_pred = self.sess.run([self.cost, self.distance, self.pred], feed_dict={self.x: x_val_tmp, self.y: y_val_tmp})
                            val_loss_list.append(val_loss)
                            val_distance_list.append(val_distance)
                            val_sequence = self.sparseTensor2sequence(val_pred)
                            ok = 0.
                            for idx, val_seq in enumerate(val_sequence):
                                val_pred_tmp = [self.words.find(x) if self.words.find(x) > -1 else 26 for x in val_seq]
                                val_y_true_tmp = [x for x in y_val_tmp_true[idx] if x != 26]

                                is_eq = operator.eq(val_pred_tmp, val_y_true_tmp)

                                if idx == 0:
                                    print(is_eq, [self.words[n] for n in val_pred_tmp], '<<==>>', [self.words[n] for n in val_y_true_tmp])

                                if is_eq:
                                    ok += 1
                            val_acc_list.append(ok / len(val_sequence))
                        val_acc_list = np.array(val_acc_list, dtype=np.float32)

                        print('train_loss:%.10f train_distance:%.10f' % (train_loss, train_distance))
                        print(' val_loss:%.10f val_distance:%.10f val_acc:%.10f' % (np.mean(val_loss_list), np.mean(val_distance_list), np.mean(val_acc_list)))
                        print()
                        print()

                    total_step += 1

                # save the model
                if (epoch_num + 1) % self.save_epoch == 0:
                    saver.save(self.sess, './my-model/LstmCtcOcr', global_step=(epoch_num + 1))

            coordinator.request_stop()
            coordinator.join(threads)


    def rnnNet(self, inputs, weight, bias):
        """ 獲取LSTM網絡結構 :param inputs: :param weight: :param bias: :return: """
        # 對於tf.nn.dynamic_rnn,默認time_major=false,此時inputs的shape=[batch_size, max_time_steps, features]
        # (batch_size, im_height, im_width) ==> (batch_size, im_width, im_height)
        inputs = tf.transpose(inputs, [0, 2, 1])

        # 變長序列的最大值
        # seq_len = np.ones(self.batch_size) * self.im_width
        seq_len = np.ones(self.batch_size) * self.im_width

        cell = tf.nn.rnn_cell.LSTMCell(self.n_hidden, forget_bias=0.8, state_is_tuple=True)

        # dynamic_rnn supports variable-length input
        outputs1, _ = tf.nn.dynamic_rnn(cell, inputs, seq_len, dtype=tf.float32)

        # (batch_size * im_width, n_hidden)
        outputs = tf.reshape(outputs1, [-1, self.n_hidden])

        logits = tf.matmul(outputs, weight) + bias  # w * x + b
        logits = tf.reshape(logits, [self.batch_size, -1, self.n_classes])
        logits = tf.transpose(logits, (1, 0, 2))  # (im_width, batch_size, n_classes), time-major as required by ctc_loss
        return logits, seq_len


    def sequence2sparseTensor(self, sequences, dtype=np.int32):
        """ 序列 轉化爲 稀疏矩陣 :param sequences: :param dtype: :return: """
        values, indices= [], []
        for n, seq in enumerate(sequences):
            indices.extend(zip([n] * len(seq), range(len(seq))))
            values.extend(seq)
        indices = np.asarray(indices, dtype=np.int64)
        values = np.asarray(values, dtype=dtype)
        shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)
        return indices, values, shape


    def sparseTensor2sequence(self, sparse_tensor):
        """ 稀疏矩陣 轉化爲 序列 :param sparse_tensor: :return: """
        decoded_indexes = list()
        current_i = 0
        current_seq = []
        for offset, i_and_index in enumerate(sparse_tensor[0]):
            i = i_and_index[0]
            if i != current_i:
                decoded_indexes.append(current_seq)
                current_i = i
                current_seq = list()
            current_seq.append(offset)
        decoded_indexes.append(current_seq)
        result = []
        for index in decoded_indexes:
            result.append(self.sequence2words(index, sparse_tensor))
        return result


    def sequence2words(self, indexes, spars_tensor):
        """ 序列 轉化爲 文本 :param indexes: :param spars_tensor: :return: """
        decoded = []
        for m in indexes:
            str_tmp = self.words[spars_tensor[1][m]]
            decoded.append(str_tmp)
        return decoded


    def dataset_opt(self, x_train_queue, y_train_queue):
        """ 處理圖片和標籤 :param queue: :return: """
        queue = x_train_queue[0]
        contents = tf.read_file('./dataset/train/' + queue)
        im = tf.image.decode_jpeg(contents)
        tf.image.rgb_to_grayscale(im)  # note: the return value is not assigned; the single-channel slice below is what is actually used
        im = tf.image.resize_images(images=im, size=[self.im_height, self.im_width])
        im = tf.reshape(im[:, :, 0], tf.stack([self.im_height, self.im_width]))
        im -= 147  # subtract an approximate mean
        return im, y_train_queue[0]


    def getTrainDataset(self):
        train_data_list = os.listdir('./dataset/train/')
        print('%d training images in total, reading %d:' % (len(train_data_list), self.train_max_num))
        random.shuffle(train_data_list)  # shuffle the order

        y_val_list, y_train_list = [], []
        x_val_list = train_data_list[:self.val_num]
        for x_val in x_val_list:
            words_tmp = x_val.split('.')[0].split('_')[1]
            words_tmp = words_tmp + '?' * (self.words_max_num - len(words_tmp))
            y_val_list.append([self.words.find(x) if self.words.find(x) > -1 else 26 for x in words_tmp])

        x_train_list = train_data_list[self.val_num:self.train_max_num]
        for x_train in x_train_list:
            words_tmp = x_train.split('.')[0].split('_')[1]
            words_tmp = words_tmp + '?' * (self.words_max_num - len(words_tmp))
            y_train_list.append([self.words.find(x) if self.words.find(x) > -1 else 26 for x in words_tmp])

        return x_train_list, y_train_list, x_val_list, y_val_list


    def createCaptchaDataset(self):
        """ 生成訓練用圖片數據集 :return: """
        image = ImageCaptcha(width=self.im_width, height=self.im_height, font_sizes=(56,))
        for i in range(self.im_total_num):
            words_tmp = ''
            for j in range(random.randint(1, self.words_max_num)):
                words_tmp = words_tmp + random.choice(self.words)
            print(words_tmp, type(words_tmp))
            im_path = './dataset/train/%d_%s.png' % (i, words_tmp)
            print(im_path)
            image.write(words_tmp, im_path)




if __name__ == '__main__':
    opt_type = sys.argv[1:][0]

    instance = LstmCtcOcr()

    if opt_type == 'create_dataset':
        instance.createCaptchaDataset()
    elif opt_type == 'train':
        instance.train()
    elif opt_type == 'test':
        instance.test('./dataset/test/0_PIY.png')
    elif opt_type == 'start':
        # build the graph and keep the session in memory
        instance.test('./dataset/test/0_PIY.png')

        # start the web service
        # http://127.0.0.1:5050/captchaOcr?img_path=./dataset/test/1_BCAVDPXT.png
        @app.route('/captchaOcr', methods=['GET'])
        def captchaOcr():
            img_path = request.args.to_dict().get('img_path')
            print(img_path)
            ret = instance.captchaOcr(img_path)
            print(ret)
            return json.dumps({'img_path': img_path, 'ocr_ret': ret})

        app.run(host='0.0.0.0', port=5050, debug=False)