SVHN的Keras實現

Abstract

SVHN是街景數字的數據集,Google在2013年發表的論文「Multi-digit Number Recognition from Street View Imagery using Deep Convolutional Neural Networks」提供瞭解決方法,並聲稱該方法能夠破解全部的驗證碼。python

本篇博客將簡要的總結這篇論文,並使用Keras實現模型並訓練SVHN數據集。git

這篇文章的方法主要做爲訓練SVHN數據集的一個baseline。做者說他的方法能達到百分之96以上的準確率。github

任務要求

這裏先看一下數據集的樣子,其實也就是Number from Street View(街景數字)。bash

過去的作法

Traditional approaches to solve this problem typically separate out the localization, segmentation, and recognition steps.(過去的作法要經歷三步:定位,分割,而後識別)網絡

做者的作法

Propose a unified approach that integrates these three steps via the use of a deep convolutional neural network that operates directly on the image pixels.(做者把這三個步驟經過一個深度的卷積網絡就能完成)app

做者的貢獻

  • (a) A unified model to localize, segment, and recognize multi-digit numbers from street level photographs
  • (b) A new kind of output layer, providing a conditional probabilistic model of sequences
  • (c) Empirical results that show this model performing best with a deep architecture
  • (d) Reaching human level performance at specific operating thresholds.

問題描述

圖片中的數字:每張圖片的數字是一個字符串序列: s = s_1 , s_2 , . . . , s_n,如上面的第一張圖片結果爲"379",s_1=3, s_2=7, s_3=9學習

字符的長度:定義爲n,絕大多數的長度小於5。做者這裏假設字符的長度最大爲5。測試

實現方法

做者的方法是對於圖片的label,訓練一個機率模型。這裏做者定義:ui

  • S:輸出序列,也就是訓練數據的label。
  • X:輸入的圖片。

這裏的目標也就是經過最大化log P (S | X ),來學習模型P (S | X )this

X其實就是輸入的圖片,這裏看一下S,S是:圖片的數字序列S_1,...,S_N + 數字序列的長度L的一個集合。好比上面的"379"是圖片的數字序列,序列的長度len("379")爲3。那麼S就是"3"+"379",也就是"3379"。

這裏P (S | X )能夠定義爲:字符長度的機率再乘以每一個字符取值的機率。(每一個字符取值是獨立的)。

P(S=s|X)=P(L=n|X)\prod_{i=1}^nP(S_i =s_i |X)

上面的變量都是離散的,L的取值有七種:0,1,2,....,5,比5大;S_i有10種:10個數字。

訓練這個模型,就是在訓練集上最大化log P (S | X ),做者這裏每一個參數都使用一個Softmax層。

s = (l,s_1,...,s_l) = \arg\max_{L,S_1,...,S_L}logP(S | X).

模型結構

下面,看下做者在論文裏面發表的模型。

  • 輸入圖片X是一個128x128x3的圖片。
  • 而後通過一系列的CNN層進行特徵提取,變成了一個含有4096個特徵的向量。
  • 以後根據這4096個特徵,分別讓LS_1S_2S_3S_4S_5分別通過一個Softmax層P(S_i|H)=softmax(W_{S_i}H+b_{S_i})
  • 對於每一個變量,s_i = \arg\max_{S_i}logP(S_i | H).

Keras的實現

代碼見:github.com/nladuo/ml-s…

環境依賴

  • python 3.x
  • TensorFlow 1.11
  • Keras 2.x
  • Pillow
  • h5py

數據下載

首先去ufldl.stanford.edu/housenumber…下載Format 1的數據。

wget http://ufldl.stanford.edu/housenumbers/test.tar.gz
wget http://ufldl.stanford.edu/housenumbers/train.tar.gz
複製代碼

而後對數據進行解壓,解壓後發現會多兩個文件夾:test和train。

tar zvxf test.tar.gz
tar zvxf train.tar.gz
複製代碼

構建數據集

這個數據集可使用h5py讀取。

其中bbox存了圖片數字的框,而name則是圖片的文件名。好比說讀取出來是下面這樣的圖片:
而後經過框框把圖片裁剪一下。

網絡模型

網絡模型這裏分爲卷積層+全鏈接層部分,代碼以下。

卷積層

這裏就是三個卷積層。爲了讓神經網絡接受了參數符合同一分佈,這裏使用了Batch Normalization層,對ConvNet的輸入進行批歸一化

全鏈接層

卷積層最後通過Flatten以後,進入了全鏈接層。全鏈接層最後,輸出了到6個softmax層中,分別表明:字符的長度、第一個字符、第二個字符、第三個字符、第四個字符、第四個字符。

注意:這裏字符的類別是0-10,一共11種,10表明不存在。

訓練與測試

接下來調用fit方法訓練就行了,這裏一共有7個loss和6個accuracy。loss的話是每一個softmax輸出層都有一個,還有個總的。accuracy就是6個softmax層的accuracy了。

最後咱們evaluate一下,能夠6個accuracy都達到了85%以上的準確率了。若是想提升的話,可使用VGG16等結構,網上說能夠提高到百分之97,不過訓練的話估計就要很慢了。

擴展案例:微博驗證碼

另外,筆者這裏還使用該方法給出了新浪微博的登陸驗證碼識別的實現代碼,見:captcha-break/weibo.com

不過不知道如今新浪微博的驗證碼變了沒有,我當時用雲打碼的時候它的驗證碼長下面這個樣子:

參考

相關文章
相關標籤/搜索