【AI實戰】手把手教你實現文字識別模型(入門篇:驗證碼識別)

文字識別在現實生活中有着很是重要的應用,主要由文字檢測、內容識別兩個關鍵步驟組成,在本博客以前的文章中已介紹了文字檢測、內容識別的經典模型原理(見文章:大話文本檢測經典模型:CTPN大話文本識別經典模型:CRNN),本文主要從實戰的角度介紹如何實現文字識別模型。python

在以前的文章中,已經介紹過了跟文字識別相關的實戰內容:基於MNIST數據集識別手寫數字的實戰內容(見文章:訓練你的第一個AI模型:MNIST手寫數字識別模型),這個相對簡單。今天再介紹文字識別的另外一個經典應用:驗證碼識別,做爲文字識別的實戰入門篇。git

 

驗證碼在手機APP、WEB網站中很是廣泛,主要是爲了防止惡意登陸、刷票、灌水、爬蟲等異常行爲,也多是爲了緩解系統的後臺壓力(例如在秒殺、搶票時,強制要求輸入驗證碼)。本文主要介紹文本型驗證碼的識別,文本型驗證碼由數字、英文大小寫字母,甚至中文隨機組成,再進行變形扭曲、加干擾線、加背景噪音等操做,主要是爲了防止被光學字符識別(OCR)之類的程序自動識別出圖片上的文字而失去效果,以下圖:算法

因爲存在着比較強的干擾信息,所以,直接使用OCR進行識別,效果很不理想,而經過AI可很好地實現這種複雜信息的識別。目前百度等AI開放平臺,也提供了驗證碼識別的開放接口,但因爲驗證碼可由各APP、網站根據任意自定的規則隨機組合生成,所以,這些AI平臺的驗證碼識別開放接口在某些場景下效果很好,在某些場景下可能就失靈了。針對具體的場景,咱們經過本身訓練驗證碼識別的AI模型,能很好地解決該場景下的驗證碼識別問題。網絡

 

下面開始介紹使用Tensorflow構建驗證碼的識別模型,主要步驟以下:app

  • step 1. 獲取驗證碼圖片
  • step 2. 圖片標註
  • step 3. 訓練模型
  • step 4. 模型應用

 

一、獲取驗證碼圖片dom

(1)若是是本身練習的,可直接隨機生成驗證碼圖片做爲基礎數據集。在python裏面使用captcha庫來快速生成驗證碼圖片,經過pip install captcha進行安裝,或者手動下載captcha-0.3-py3-none-any.whl文件進行安裝。(注:anaconda沒法經過conda install 直接安裝captcha,但可以使用anaconda裏面的pip來安裝captcha),核心代碼以下:分佈式

from captcha.image import ImageCaptcha
import random

# 生成驗證碼的字符集
CHAR_SET = ['0','1','2','3','4','5','6','7','8','9']
CHAR_SET_LEN = len(CHAR_SET)

# 驗證碼長度
CAPTCHA_LEN  = 4

for i in range(CHAR_SET_LEN):
    for j in range(CHAR_SET_LEN):
        for k in range(CHAR_SET_LEN):
            for l in range(CHAR_SET_LEN):
                captcha_text = CHAR_SET[i] + CHAR_SET[j] + CHAR_SET[k] + CHAR_SET[l]
                image = ImageCaptcha()
                image.write(captcha_text, '/tmp/mydata/' + captcha_text + '.jpg')

生成的效果以下圖ide

(2)若是是要針對某個網站的驗證碼進行識別的,則可以使用一些工具將對應的驗證碼下載下來。通常網站登陸的界面以下:函數

其中,一般可直接點擊驗證碼圖片,或旁邊的「換一張」按鈕,更換驗證碼圖片。這時,可以使用像「按鍵精靈」之類的模擬鼠標操做的軟件,錄製一段腳本,而後在驗證碼圖片處模擬右鍵鼠標保存圖片,再點擊驗證碼圖片更換新的驗證碼,如此反覆,便可下載該網站的大量驗證碼圖片,用於訓練模型。至於這個下載驗證碼圖片的腳本嘛,爲了避免教壞你們,此處省略500字,嘿嘿~工具

 

二、圖片標註

若是第1步是本身隨機生成驗證碼圖片的,那麼在保存圖片時,文件名即是該驗證碼圖片的文本內容,無須再進行標註。

若是第1步是下載了某個網站的驗證碼圖片的,那麼須要先人工對驗證碼圖片的文本內容進行標註,以方便接下來的模型訓練。可經過觀察,將驗證碼圖片的文本信息記在文件名中(重命名),經過這種方式進行圖片標註,也能夠單獨記錄在文本文件中。

 

三、訓練模型

(1)標籤one-hot編碼

爲了可以將驗證碼圖片的文本信息輸入到卷積神經網絡模型裏面去訓練,須要將文本信息向量化編碼。在這裏使用「熱獨編碼」(one-hot),即便用01編碼表示文本信息。本項目的驗證碼文本長度爲4位,驗證碼編碼由0至9的數字組成,例如驗證碼文本信息爲「1086」,則one-hot編碼時在相應的位置標爲1,其他爲0,以下圖

則「1086」經one-hot編碼後變爲[0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0] 。將驗證碼文本信息進行one-hot編碼的核心代碼以下:

def text2label(text):
    label = np.zeros(CAPTCHA_LEN * CHAR_SET_LEN)
    for i in range(len(text)):
        idx = i * CHAR_SET_LEN + CHAR_SET.index(text[i])
        label[idx] = 1
    return label

(2)讀取圖片文件

讀取驗證碼圖片、驗證碼文本內容(保存在文件名中),並編寫獲取下個批量數據的方法,主要函數以下:

# 獲取驗證碼圖片路徑及文本內容
def get_image_file_name(img_path):
    img_files = []
    img_labels = []
    for root, dirs, files in os.walk(img_path):
        for file in files:
            if os.path.splitext(file)[1] == '.jpg':
                img_files.append(root+'/'+file)
                img_labels.append(text2label(os.path.splitext(file)[0]))
    return img_files,img_labels

# 批量獲取數據
def get_next_batch(img_files,img_labels,batch_size):
    batch_x = np.zeros([batch_size, IMAGE_WIDTH*IMAGE_HEIGHT])
    batch_y = np.zeros([batch_size, CAPTCHA_LEN * CHAR_SET_LEN])

    for i in range(batch_size):
        idx = random.randint(0, len(img_files) - 1)
        file_path = img_files[idx]
        image = cv2.imread(file_path)
        image = cv2.resize(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
        image = image.astype(np.float32)
        image = np.multiply(image, 1.0 / 255.0)
        batch_x[i, :] = image
        batch_y[i, :] = img_labels[idx]

    return batch_x,batch_y

(3)構建CNN模型

因爲驗證碼的識別相對比較簡單,借鑑LeNet的網絡結構構建CNN模型,由3個卷積層和1個全鏈接層組成,網絡結構圖以下:

核心代碼以下:

# 圖像尺寸
IMAGE_HEIGHT = 60
IMAGE_WIDTH = 160

# 網絡相關變量
X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
Y = tf.placeholder(tf.float32, [None, CAPTCHA_LEN * CHAR_SET_LEN])
keep_prob = tf.placeholder(tf.float32)  # dropout

# 驗證碼 CNN 網絡
def crack_captcha_cnn_network (w_alpha=0.01, b_alpha=0.1):
    x = tf.reshape(X, shape=[-1, IMAGE_HEIGHT, IMAGE_WIDTH, 1])

    w_c1 = tf.Variable(w_alpha * tf.random_normal([3, 3, 1, 32]))
    b_c1 = tf.Variable(b_alpha * tf.random_normal([32]))
    conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1))
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv1 = tf.nn.dropout(conv1, keep_prob)

    w_c2 = tf.Variable(w_alpha * tf.random_normal([3, 3, 32, 64]))
    b_c2 = tf.Variable(b_alpha * tf.random_normal([64]))
    conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1, 1, 1, 1], padding='SAME'), b_c2))
    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv2 = tf.nn.dropout(conv2, keep_prob)

    w_c3 = tf.Variable(w_alpha * tf.random_normal([3, 3, 64, 64]))
    b_c3 = tf.Variable(b_alpha * tf.random_normal([64]))
    conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3))
    conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv3 = tf.nn.dropout(conv3, keep_prob)

    w_d = tf.Variable(w_alpha * tf.random_normal([8 * 20 * 64, 1024]))
    b_d = tf.Variable(b_alpha * tf.random_normal([1024]))
    dense = tf.reshape(conv3, [-1, w_d.get_shape().as_list()[0]])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))
    dense = tf.nn.dropout(dense, keep_prob)

    w_out = tf.Variable(w_alpha * tf.random_normal([1024, CAPTCHA_LEN * CHAR_SET_LEN]))
    b_out = tf.Variable(b_alpha * tf.random_normal([CAPTCHA_LEN * CHAR_SET_LEN]))
    out = tf.add(tf.matmul(dense, w_out), b_out)
    return out

(4)訓練模型

經過設置好模型訓練的迭代輪次、批量獲取樣本數量、學習率等參數,讀取驗證碼圖片集,並隨機劃分出訓練集、測試集,再加載本項目的網絡模型進行訓練,每100步評估一次準確率和保存模型文件。核心代碼以下:

# 模型的相關參數
step_cnt = 200000  # 迭代輪數
batch_size = 16  # 批量獲取樣本數量
learning_rate = 0.0001  # 學習率

# 讀取驗證碼圖片集
img_path = '/tmp/mydata/'
img_files, img_labels = get_image_file_name(img_path)

# 劃分出訓練集、測試集
x_train,x_test,y_train,y_test=train_test_split(img_files,img_labels,test_size=0.2,random_state=33)

# 加載網絡結構
output = crack_captcha_cnn_network()

# 損失函數、優化器
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

# 評估準確率
predict = tf.reshape(output, [-1, CAPTCHA_LEN, CHAR_SET_LEN])
max_idx_p = tf.argmax(predict, 2)
max_idx_l = tf.argmax(tf.reshape(Y, [-1, CAPTCHA_LEN, CHAR_SET_LEN]), 2)
correct_pred = tf.equal(max_idx_p, max_idx_l)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

for step in range(step_cnt):
    # 訓練模型
        batch_x, batch_y = get_next_batch(x_train, y_train,batch_size)
        _, loss_ = sess.run([optimizer, loss], feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75})
        print('step:',step, 'loss:',loss_)

        # 每100步評估一次準確率
        if step % 100 == 0:
            batch_x_test, batch_y_test = get_next_batch(x_test, y_test,batch_size)
            acc = sess.run(accuracy, feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1.})
            print('step:',step,'acc:',acc)

            # 保存模型
            saver.save(sess, '/tmp/mymodel/crack_captcha.ctpk', global_step=step)

        step += 1

訓練的過程以下圖所示:

通過一段時間的訓練後,評估的準確率可達到99%以上,能很是準確地識別出驗證碼。

 

四、模型應用

經過加載訓練好後的模型文件,便可輸入圖片進行驗證碼識別,核心代碼以下:

# 加載網絡結構
output = crack_captcha_cnn_network()

saver = tf.train.Saver()
with tf.Session() as sess:
    model_path = '/tmp/mymodel/'
    saver.restore(sess, tf.train.latest_checkpoint(model_path))

    output_rate=tf.reshape(output, [-1, CAPTCHA_LEN, CHAR_SET_LEN])
    predict = tf.argmax(output_rate, 2)
    text_list,rate_list = sess.run([predict,output_rate], feed_dict={X: [captcha_image], keep_prob: 1})   # captcha_image 爲待識別的驗證碼圖片

    tmptext = text_list[0].tolist()
    text=''
    for i in range(len(tmptext)):
        text = text + CHAR_SET[tmptext[i]]

    print('識別結果:',text)

以上就是文字識別的入門實戰內容:驗證碼圖片文本識別。經過本次的學習,可瞭解簡單的文本識別的實現方式。

 

關注本人公衆號「大數據與人工智能Lab」(BigdataAILab),而後回覆「代碼」關鍵字可獲取 完整源代碼

 

推薦相關閱讀

相關文章
相關標籤/搜索