tensorflow LSTM+CTC使用詳解

  最近用tensorflow寫了個OCR的程序,在實現的過程當中,發現本身仍是跳了很多坑,在這裏作一個記錄,便於之後回憶。主要的內容有lstm+ctc具體的輸入輸出,以及TF中的CTC和百度開源的warpCTC在具體使用中的區別。python

正文

輸入輸出

由於我最後要最小化的目標函數就是ctc_loss,因此下面就從如何構造輸入輸出提及。git

tf.nn.ctc_loss

先從TF自帶的tf.nn.ctc_loss提及,官方給的定義以下,所以咱們須要作的就是將圖片的label(須要OCR出的結果),圖片,以及圖片的長度轉換爲label,input,和sequence_length。github

ctc_loss(
labels,
inputs,
sequence_length,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
time_major=True
)
input: 輸入(訓練)數據,是一個三維float型的數據結構 [max_time_step , batch_size , num_classes],當修改time_major = False時, [batch_size,max_time_step,num_classes]
整體的數據流:
image_batch
-> [batch_size,max_time_step,num_features]->lstm
-> [batch_size,max_time_step,cell.output_size]->reshape
-> [batch_size*max_time_step,num_hidden]->affine projection A*W+b
-> [batch_size*max_time_step,num_classes]->reshape
-> [batch_size,max_time_step,num_classes]->transpose
-> [max_time_step,batch_size,num_classes]
下面詳細解釋一下,
假如一張圖片有以下shape:[60,160,3],咱們若是讀取灰度圖則shape=[60,160],此時,咱們將其一列做爲feature,那麼共有60個features,160個time_step,這時假設一個batch爲64,那麼咱們此時得到到了一個 [batch_size,max_time_step,num_features] = [64,160,60]的訓練數據。
而後將該訓練數據送入 構建的lstm網絡中,(須要注意的是 dynamic_rnn的輸入數據在一個batch內的長度是固定的,可是不一樣batch之間能夠不一樣,咱們須要給他一個 sequence_length(長度爲batch_size的向量)來記錄本次batch數據的長度,對於OCR這個問題,sequence_length就是長度爲64,而值爲160的一維向量)
獲得形如 [batch_size,max_time_step,cell.output_size]的輸出,其中cell.output_size == num_hidden。
下面咱們須要作一個線性變換將其送入ctc_loos中進行計算,lstm中不一樣time_step之間共享權值,因此咱們只需定義 W的結構爲 [num_hidden,num_classes]b的結構爲[num_classes]。而 tf.matmul操做中,兩個矩陣相乘階數應當匹配,因此咱們將上一步的輸出reshape成 [batch_size*max_time_step,num_hidden](num_hidden爲本身定義的lstm的unit個數)記爲 A,而後將其作一個線性變換,因而 A*w+b獲得形如 [batch_size*max_time_step,num_classes]而後在reshape回來獲得 [batch_size,max_time_step,num_classes]最後因爲ctc_loss的要求,咱們再作一次轉置,獲得 [max_time_step,batch_size,num_classes]形狀的數據做爲input

labels: 標籤序列
因爲OCR的結果是不定長的,因此label其實是一個稀疏矩陣SparseTensor
其中:api

  • indices:二維int64的矩陣,表明非0的座標點
  • values:二維tensor,表明indice位置的數據值
  • dense_shape:一維,表明稀疏矩陣的大小
    好比有兩幅圖,分別是123,和4567那麼
    indecs = [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2],[1,3]]
    values = [1,2,3,4,5,6,7]
    dense_shape = [2,4]
    表明dense tensor:
    1
    2
    [[1,2,3,0]
    [4,5,6,7]]

seq_len: 在input一節中已經講過,一維數據,[time_step,…,time_step]長度爲batch_size,值爲time_step網絡

相關文章
相關標籤/搜索