import lmdb
import cv2
import numpy as np
import os


def checkImageIsValid(imageBin):
    """Return True when ``imageBin`` decodes to an image with non-zero area.

    Args:
        imageBin: raw encoded image bytes (e.g. the content of a JPEG/PNG
            file), or None.

    Returns:
        bool: False for None/empty input, undecodable data, or an image whose
        height * width is 0; True otherwise.
    """
    if not imageBin:
        return False
    try:
        # np.fromstring is deprecated for binary data; frombuffer is the
        # supported zero-copy equivalent.
        imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    except (cv2.error, TypeError, ValueError):
        return False
    if img is None:  # imdecode signals an undecodable buffer with None
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0


def _to_bytes(value):
    """Encode ``str`` as UTF-8; pass any other value (e.g. bytes) through."""
    return value.encode('utf-8') if isinstance(value, str) else value


def writeCache(env, cache):
    """Write every key/value pair of ``cache`` to ``env`` in one transaction.

    Args:
        env: an opened LMDB environment.
        cache (dict): pending entries; ``str`` keys/values are UTF-8 encoded
            because LMDB stores bytes only.
    """
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(_to_bytes(k), _to_bytes(v))


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """Create LMDB dataset for CRNN training.

    Entries are stored under the keys ``image-%09d``, ``label-%09d`` and
    (optionally) ``lexicon-%09d``, numbered from 1, plus ``num-samples``.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path entries; only the first
                        whitespace-separated token of each entry is used and
                        '.' is prepended to it before opening the file
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image

    Raises:
        AssertionError: when the two lists differ in length.
    """
    assert (len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)
    try:
        cache = {}
        cnt = 1
        for i in range(nSamples):
            # split() already drops the trailing newline / carriage return.
            imagePath = ''.join(imagePathList[i]).split()[0]
            label = ''.join(labelList[i])
            print(label)
            # Images are binary: text mode would corrupt or fail to decode them.
            with open('.' + imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid and not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

            cache['image-%09d' % cnt] = imageBin
            cache['label-%09d' % cnt] = label
            if lexiconList:
                cache['lexicon-%09d' % cnt] = ' '.join(lexiconList[i])
            # Flush periodically so the in-memory cache stays bounded.
            if cnt % 1000 == 0:
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        print(cnt)
        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples)
        writeCache(env, cache)
    finally:
        env.close()
    print('Created dataset with %d samples' % nSamples)


OUT_PATH = '../crnn_train_lmdb'
IN_PATH = './train.txt'

if __name__ == '__main__':
    outputPath = OUT_PATH
    os.makedirs(OUT_PATH, exist_ok=True)
    # Each non-blank line is "<image path> <label>".
    with open(IN_PATH) as imgdata:
        imagePathList = [line for line in imgdata if line.strip()]
    labelList = [line.split()[1] for line in imagePathList]
    createDataset(outputPath, imagePathList, labelList)
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        # NOTE: 0 is reserved for 'blank' required by wrap_ctc
        self.dict = {char: i + 1 for i, char in enumerate(alphabet)}

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str, bytes, or iterable of str/bytes): texts to convert.
                Bytes items are decoded as UTF-8.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            KeyError: when a character is not part of the alphabet.
        """
        # A lone string is a batch of one, not a batch of its characters.
        if isinstance(text, (str, bytes)):
            text = [text]

        length = []
        result = []
        for item in text:
            if isinstance(item, bytes):
                item = item.decode('utf-8', 'strict')
            if self._ignore_case:
                item = item.lower()
            length.append(len(item))
            for char in item:
                try:
                    result.append(self.dict[char])
                except KeyError:
                    raise KeyError(
                        'character {!r} is not in the alphabet'.format(char)) from None
        return (torch.IntTensor(result), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]): encoded texts.
            length (torch.IntTensor [n]): length of each text.
            raw (bool): when True, map every index to its character (blank
                becomes '-'); when False, drop blanks and collapse runs of
                the same index.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): decoded text; a list when ``length``
            holds anything other than exactly one element.
        """
        if length.numel() == 1:
            n = int(length.reshape(-1)[0])
            assert t.numel() == n, "text with length: {} does not match declared length: {}".format(t.numel(), n)
            codes = t.tolist()  # plain ints: cheap comparisons and indexing
            if raw:
                return ''.join(self.alphabet[c - 1] for c in codes)
            char_list = [
                self.alphabet[c - 1]
                for i, c in enumerate(codes)
                if c != 0 and (i == 0 or codes[i - 1] != c)
            ]
            return ''.join(char_list)

        # batch mode
        lengths = length.tolist()
        assert t.numel() == sum(lengths), "texts with length: {} does not match declared length: {}".format(
            t.numel(), sum(lengths))
        texts = []
        index = 0
        for n in lengths:
            texts.append(
                self.decode(t[index:index + n], torch.IntTensor([n]), raw=raw))
            index += n
        return texts
爲了將特徵輸入到Recurrent Layers,做如下處理:
以上是理想訓練時的操作,但是CRNN論文提到的網絡輸入是歸一化好的100×32大小的灰度圖像,即高度統一爲32個像素。下面是CRNN的深度神經網絡結構圖,CNN採用了經典的VGG16,值得注意的是,在VGG16的第3、第4個max pooling層,CRNN採用的是1×2的矩形池化窗口(w×h),這有別於經典的VGG16的2×2的正方形池化窗口。這個改動是因爲文本圖像多數都是高較小而寬較長,所以其feature map也是這種高小寬長的矩形形狀,如果使用1×2的池化窗口則更適合英文字母識別(比如區分i和l)。VGG16部分還引入了BatchNormalization模塊,旨在加速模型收斂。還有值得注意的一點,CRNN的輸入是灰度圖像,即圖像深度爲1。CNN部分的輸出是512x1x16(c×h×w)的特徵向量。
接下來分析RNN層。RNN部分使用了雙向LSTM,隱藏層單元數爲256,CRNN採用了兩層BiLSTM來組成這個RNN層,RNN層的輸出維度將是(s,b,class_num) ,其中class_num爲文字類別總數。
值得注意的是:PyTorch裏的LSTM單元接受的輸入都必須是3維的張量(Tensors),每一維代表的意思不能弄錯。第一維體現的是序列(sequence)結構,第二維體現的是小批量(mini-batch)結構,第三維體現的是輸入的元素(elements of input)。如果在應用中不使用小批量結構,那麼可以將輸入的張量中該維度設爲1,但必須要體現出這個維度。
input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence. The input can also be a packed variable length sequence. input shape(a,b,c) a:seq_len -> 序列長度 b:batch c:input_size 輸入特徵數目
根據LSTM的輸入要求,我們要對CNN的輸出做些調整,即把CNN層的輸出調整爲[seq_len, batch, input_size]形式,下面爲具體操作:先使用squeeze函數移除h維度,再使用permute函數調整各維順序,即從原來的[b, c, w]調整爲[w, b, c],也就是[seq_len, batch, input_size],具體尺寸爲[16,batch,512],調整好之後即可將該矩陣送入RNN層。
# Excerpt from CRNN.forward: turn the CNN feature map into an LSTM sequence.
features = self.cnn(x)
b, c, h, w = features.size()  # feature map is (batch, channels, height, width)
assert h == 1  # the height of conv must be 1
# Drop the height axis, then reorder (b, c, w) -> (w, b, c),
# i.e. [seq_len, batch, input_size] as the LSTM expects.
x = features.squeeze(2).permute(2, 0, 1)
x = self.rnn(x)
RNN層輸出格式如下,因爲我們採用的是雙向BiLSTM,所以輸出維度將是hidden_unit * 2
Outputs: output, (h_n, c_n) output of shape (seq_len, batch, num_directions * hidden_size) h_n of shape (num_layers * num_directions, batch, hidden_size) c_n (num_layers * num_directions, batch, hidden_size)
然後我們再經過線性變換操作self.embedding1 = torch.nn.Linear(hidden_unit * 2, 512)
使得輸出維度再次變爲512,繼續送入第二個LSTM層。第二個LSTM層後繼續接線性操作torch.nn.Linear(hidden_unit * 2, class_num)
import torch
import torch.nn.functional as F


class Vgg_16(torch.nn.Module):
    """VGG-style convolutional feature extractor for 1-channel images.

    For an input of shape (b, 1, 32, W) the output is (b, 512, 1, W // 4 - 3):
    the two square poolings halve both axes, the two non-square poolings
    halve the height (stride 2) while shrinking the width by one, and the
    final unpadded 2x2 convolution removes one more row and column.
    """

    def __init__(self):
        super().__init__()
        conv = torch.nn.Conv2d
        pool = torch.nn.MaxPool2d
        norm = torch.nn.BatchNorm2d
        self.convolution1 = conv(1, 64, 3, padding=1)
        self.pooling1 = pool(2, stride=2)
        self.convolution2 = conv(64, 128, 3, padding=1)
        self.pooling2 = pool(2, stride=2)
        self.convolution3 = conv(128, 256, 3, padding=1)
        self.convolution4 = conv(256, 256, 3, padding=1)
        # Non-square pooling: kernel (1, 2) with stride (2, 1).
        self.pooling3 = pool((1, 2), stride=(2, 1))
        self.convolution5 = conv(256, 512, 3, padding=1)
        self.BatchNorm1 = norm(512)
        self.convolution6 = conv(512, 512, 3, padding=1)
        self.BatchNorm2 = norm(512)
        self.pooling4 = pool((1, 2), stride=(2, 1))
        self.convolution7 = conv(512, 512, 2)

    @staticmethod
    def _activate(tensor):
        """Apply ReLU in place."""
        return F.relu(tensor, inplace=True)

    def forward(self, x):
        """Return the (b, 512, h, w) feature map for image batch ``x``."""
        act = self._activate
        x = self.pooling1(act(self.convolution1(x)))
        x = self.pooling2(act(self.convolution2(x)))
        x = act(self.convolution4(act(self.convolution3(x))))
        x = self.pooling3(x)
        x = act(self.BatchNorm1(self.convolution5(x)))
        x = act(self.BatchNorm2(self.convolution6(x)))
        x = self.pooling4(x)
        return act(self.convolution7(x))


class RNN(torch.nn.Module):
    """Two stacked stages of bidirectional LSTM followed by a linear layer.

    Args:
        class_num (int): size of the final per-step output.
        hidden_unit (int): hidden size of each LSTM direction.
    """

    def __init__(self, class_num, hidden_unit):
        super().__init__()
        lstm_width = hidden_unit * 2  # forward + backward directions
        self.Bidirectional_LSTM1 = torch.nn.LSTM(512, hidden_unit, bidirectional=True)
        self.embedding1 = torch.nn.Linear(lstm_width, 512)
        self.Bidirectional_LSTM2 = torch.nn.LSTM(512, hidden_unit, bidirectional=True)
        self.embedding2 = torch.nn.Linear(lstm_width, class_num)

    @staticmethod
    def _stage(lstm, linear, sequence):
        """Run ``lstm`` then ``linear`` on a (seq_len, batch, features) tensor."""
        output, _ = lstm(sequence)  # LSTM returns output, (h_n, c_n)
        steps, batch, width = output.size()
        # Flatten time and batch so the linear layer sees a 2-D matrix.
        projected = linear(output.view(steps * batch, width))
        return projected.view(steps, batch, -1)

    def forward(self, x):
        """Map (seq_len, batch, 512) features to (seq_len, batch, class_num)."""
        hidden = self._stage(self.Bidirectional_LSTM1, self.embedding1, x)
        return self._stage(self.Bidirectional_LSTM2, self.embedding2, hidden)


class CRNN(torch.nn.Module):
    """CNN feature extractor followed by the recurrent head.

    Args:
        class_num (int): number of output classes per time step.
        hidden_unit (int, default=256): hidden size of each LSTM direction.
    """

    def __init__(self, class_num, hidden_unit=256):
        super().__init__()
        self.cnn = torch.nn.Sequential()
        self.cnn.add_module('vgg_16', Vgg_16())
        self.rnn = torch.nn.Sequential()
        self.rnn.add_module('rnn', RNN(class_num, hidden_unit))

    def forward(self, x):
        """Return per-step class scores of shape (seq_len, batch, class_num)."""
        feature_map = self.cnn(x)
        _, _, height, _ = feature_map.size()  # (b, c, h, w)
        assert height == 1  # the height of conv must be 1
        # (b, c, 1, w) -> (b, c, w) -> (w, b, c) = [seq_len, batch, input_size]
        sequence = feature_map.squeeze(2).permute(2, 0, 1)
        return self.rnn(sequence)
剛剛完成了CNN層和RNN層的設計,現在開始設計轉錄層,即將RNN層輸出的結果翻譯成最終的識別文字結果,從而實現不定長的文字識別。pytorch沒有內置的CTC loss(注:PyTorch 1.0 起已內置 torch.nn.CTCLoss),所以只能去Github下載別人實現的CTC loss來完成損失函數部分的設計。安裝CTC-loss的方式如下:
# Build warp-ctc and install its PyTorch binding.
# NOTE(review): the clone URL, "setup.py" and the copied library name were
# stripped from the original text and are reconstructed here -- confirm them.
git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
cd ../pytorch_binding/
python setup.py install
cd ../build
cp libwarpctc.so ../../usr/lib
待安裝完畢後,我們可以直接調用CTC loss了,以一個小例子來說明ctc loss的用法。
import torch
from warpctc_pytorch import CTCLoss

# Minimal warp-ctc example: one sample, two time steps, five classes.
ctc_loss = CTCLoss()

# expected shape of seqLength x batchSize x alphabet_size
step_scores = [[0.1, 0.6, 0.1, 0.1, 0.1],
               [0.1, 0.1, 0.6, 0.1, 0.1]]
probs = torch.FloatTensor([step_scores]).transpose(0, 1).contiguous()
probs.requires_grad_(True)  # tells autograd to compute gradients for probs

labels = torch.IntTensor([1, 2])
label_sizes = torch.IntTensor([2])
probs_sizes = torch.IntTensor([2])

cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
CTCLoss(size_average=False, length_average=False) # size_average (bool): normalize the loss by the batch size (default: False) # length_average (bool): normalize the loss by the total number of frames in the batch. If True, supersedes size_average (default: False) forward(acts, labels, act_lens, label_lens) # acts: Tensor of (seqLength x batch x outputDim) containing output activations from network (before softmax) # labels: 1 dimensional Tensor containing all the targets of the batch in one large sequence # act_lens: Tensor of size (batch) containing size of each output sequence from the network # label_lens: Tensor of (batch) containing label length of each example
從上面的代碼可以看出,CTCLoss的輸入爲[probs, labels, probs_sizes, label_sizes],即預測結果、標籤、預測結果的長度和標籤的長度。那麼我們仿照這個例子開始設計CRNN的CTC LOSS。
# Excerpt of the training step: compute the CTC loss for one batch.
preds = net(image)
# Every sample contributes preds.size(0) time steps (the feature-map width).
seq_len = preds.size(0)
preds_size = Variable(torch.IntTensor([seq_len] * batch_size))
# `length` holds the label length of each text in the batch;
# dividing by batch_size gives the average loss.
cost = criterion(preds, text, preds_size, length) / batch_size
cost.backward()
def trainBatch(net, criterion, optimizer, train_iter):
    """Run one optimisation step on the next batch of ``train_iter``.

    Relies on the module-level buffers ``image``, ``text`` and ``length`` and
    on the module-level ``converter``; the batch is copied into those buffers
    before the forward pass.

    Args:
        net: the network being trained.
        criterion: CTC loss, called as
            ``criterion(preds, text, preds_size, length)``.
        optimizer: optimizer whose ``step()`` is applied after backward.
        train_iter: iterator yielding ``(images, texts)`` batches.

    Returns:
        The loss of this batch divided by the batch size.
    """
    # The iterator call was missing in the original text; next() is the
    # iterator-protocol spelling of the legacy train_iter.next().
    cpu_images, cpu_texts = next(train_iter)
    batch_size = cpu_images.size(0)
    lib.dataset.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    lib.dataset.loadData(text, t)
    lib.dataset.loadData(length, l)

    preds = net(image)
    # Every sample yields preds.size(0) time steps.
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    # `length` holds the label length of each text in the batch.
    cost = criterion(preds, text, preds_size, length) / batch_size

    net.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
# Training script: build the model, the reusable input buffers and the
# optimizer, then train with periodic logging, validation and checkpoints.
criterion = CTCLoss()
net = Net.CRNN(n_class)
print(net)
net.apply(lib.utility.weights_init)

# Pre-allocated buffers that each batch is copied into.
image = torch.FloatTensor(Config.batch_size, 3, Config.img_height, Config.img_width)
text = torch.IntTensor(Config.batch_size * 5)
length = torch.IntTensor(Config.batch_size)

if cuda:
    net.cuda()
    image = image.cuda()
    criterion = criterion.cuda()

image = Variable(image)
text = Variable(text)
length = Variable(length)

loss_avg = lib.utility.averager()

# NOTE(review): the optimizer arguments were stripped from the original text;
# `lr=Config.lr` is reconstructed -- confirm against Config. The original also
# listed Adadelta and Adam (betas=(Config.beta1, 0.999)) as alternatives.
optimizer = optim.RMSprop(net.parameters(), lr=Config.lr)

for epoch in range(Config.epoch):
    train_iter = iter(train_loader)
    for i in range(1, len(train_loader) + 1):
        # Re-enable gradients and training mode on every step.
        for p in net.parameters():
            p.requires_grad = True
        net.train()

        cost = trainBatch(net, criterion, optimizer, train_iter)
        loss_avg.add(cost)

        if i % Config.display_interval == 0:
            print('[%d/%d][%d/%d] Loss: %f' %
                  (epoch, Config.epoch, i, len(train_loader), loss_avg.val()))
            loss_avg.reset()

        if i % Config.test_interval == 0:
            val(net, test_dataset, criterion)

        # do checkpointing
        if i % Config.save_interval == 0:
            torch.save(
                net.state_dict(),
                '{0}/netCRNN_{1}_{2}.pth'.format(Config.model_dir, epoch, i))
import time
import torch
import os
from torch.autograd import Variable
import lib.convert
import lib.dataset
from PIL import Image
# NOTE(review): the module path was stripped from the original text
# ("import as Net"); `Net.net` is reconstructed -- confirm the package layout.
import Net.net as Net
import alphabets
import sys
import Config

os.environ['CUDA_VISIBLE_DEVICES'] = "4"

crnn_model_path = './bs64_model/netCRNN_9_48000.pth'
IMG_ROOT = './test_images'
running_mode = 'gpu'
alphabet = alphabets.alphabet
nclass = len(alphabet) + 1  # one extra class for the CTC blank


def crnn_recognition(cropped_image, model):
    """Recognise the text in ``cropped_image`` with ``model`` and print it.

    Args:
        cropped_image: PIL image of a text line.
        model: CRNN network with weights already loaded.

    Returns:
        The decoded text, as produced by ``strLabelConverter.decode`` with
        ``raw=False``.
    """
    converter = lib.convert.strLabelConverter(alphabet)  # label <-> text conversion
    image = cropped_image.convert('L')  # convert to grayscale

    # The width is scaled by the fixed ratio Config.infer_img_w / 280 and the
    # height is set to Config.img_height.
    w = int(image.size[0] / (280 * 1.0 / Config.infer_img_w))
    transformer = lib.dataset.resizeNormalize((w, Config.img_height))
    image = transformer(image)

    # Follow the model's device: moving the image to the GPU whenever CUDA
    # exists breaks a model that was deliberately left on the CPU.
    device = next(model.parameters()).device
    image = image.to(device)
    image = image.view(1, *image.size())  # add the batch dimension
    image = Variable(image)

    model.eval()
    with torch.no_grad():  # inference only, no autograd graph needed
        preds = model(image)

    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = Variable(torch.IntTensor([preds.size(0)]))
    # Decode the predicted indices into text.
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
    print('results: {0}'.format(sim_pred))
    return sim_pred


if __name__ == '__main__':
    # crnn network
    model = Net.CRNN(nclass)
    # Loading a checkpoint differs between CPU and GPU, so handle each case.
    if running_mode == 'gpu' and torch.cuda.is_available():
        model = model.cuda()
        model.load_state_dict(torch.load(crnn_model_path))
    else:
        model.load_state_dict(torch.load(crnn_model_path, map_location='cpu'))
    print('loading pretrained model from {0}'.format(crnn_model_path))

    for file_name in sorted(os.listdir(IMG_ROOT)):  # process in file-name order
        started = time.time()
        full_path = os.path.join(IMG_ROOT, file_name)
        print("=============================================")
        print("ocr image is %s" % full_path)
        with Image.open(full_path) as image:
            crnn_recognition(image, model)
        finished = time.time()
        print('elapsed time: {0}'.format(finished - started))