Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refinement and refactor: model.py (ocr/model.py) #177

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 71 additions & 81 deletions ocr/model.py
Original file line number Diff line number Diff line change
@@ -1,122 +1,112 @@
# -*- coding: utf-8 -*-
## 修复K.ctc_decode bug 当大量测试时将GPU显存消耗完,导致错误,用decode 替代
###
import os,sys
parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parentdir)
# from PIL import Image
import keras.backend as K

import keys_ocr
import os
import numpy as np
from keras.layers import Flatten, BatchNormalization, Permute, TimeDistributed, Dense, Bidirectional, GRU
from keras.layers import Input, Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers import Lambda
from keras.layers import (Input, Conv2D, MaxPooling2D, ZeroPadding2D,
BatchNormalization, Permute, TimeDistributed,
Flatten, Bidirectional, GRU, Dense, Lambda)
from keras.models import Model
from keras.optimizers import SGD
import keras.backend as K
import keys_ocr


# from keras.models import load_model


# Define CTC loss function
def ctc_lambda_func(args):
y_pred, labels, input_length, label_length = args
y_pred = y_pred[:, 2:, :]
y_pred = y_pred[:, 2:, :] # Remove first two frames for CTC
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)


# Model architecture
def get_model(height, nclass):
rnnunit = 256
input = Input(shape=(height, None, 1), name='the_input')
m = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', name='conv1')(input)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(m)
m = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same', name='conv2')(m)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool2')(m)
m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv3')(m)
m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv4')(m)

m = ZeroPadding2D(padding=(0, 1))(m)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool3')(m)

m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv5')(m)
m = BatchNormalization(axis=1)(m)
m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv6')(m)
m = BatchNormalization(axis=1)(m)
m = ZeroPadding2D(padding=(0, 1))(m)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool4')(m)
m = Conv2D(512, kernel_size=(2, 2), activation='relu', padding='valid', name='conv7')(m)
# m的输出维度为HWC?
# 将输入的维度按照给定模式进行重排,例如,当需要将RNN和CNN网络连接时,可能会用到该层
# 将维度转成WHC
m = Permute((2, 1, 3), name='permute')(m)
m = TimeDistributed(Flatten(), name='timedistrib')(m)

m = Bidirectional(GRU(rnnunit, return_sequences=True), name='blstm1')(m)
m = Dense(rnnunit, name='blstm1_out', activation='linear')(m)
m = Bidirectional(GRU(rnnunit, return_sequences=True), name='blstm2')(m)
y_pred = Dense(nclass, name='blstm2_out', activation='softmax')(m)

basemodel = Model(inputs=input, outputs=y_pred)

input_tensor = Input(shape=(height, None, 1), name='the_input')

# CNN layers
x = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same')(input_tensor)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = ZeroPadding2D(padding=(0, 1))(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(2, 1))(x)
x = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = BatchNormalization(axis=1)(x)
x = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same')(x)
x = BatchNormalization(axis=1)(x)
x = ZeroPadding2D(padding=(0, 1))(x)
x = MaxPooling2D(pool_size=(2, 2), strides=(2, 1))(x)
x = Conv2D(512, kernel_size=(2, 2), activation='relu', padding='valid')(x)

# Reshape for RNN
x = Permute((2, 1, 3))(x)
x = TimeDistributed(Flatten())(x)

# RNN layers
x = Bidirectional(GRU(rnnunit, return_sequences=True))(x)
x = Dense(rnnunit, activation='linear')(x)
x = Bidirectional(GRU(rnnunit, return_sequences=True))(x)
y_pred = Dense(nclass, activation='softmax')(x)

# Create model for training
basemodel = Model(inputs=input_tensor, outputs=y_pred)

# Define inputs for CTC loss
labels = Input(name='the_labels', shape=[None, ], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])
model = Model(inputs=[input, labels, input_length, label_length], outputs=[loss_out])
model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=[loss_out])

# Compile model
sgd = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
# model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
# model.summary()
return model, basemodel

return model, basemodel

# Load model
characters = keys_ocr.alphabet[:]
modelPath = os.path.join(os.getcwd(), "ocr/ocr0.2.h5")
# modelPath = '/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/save_model/my_model_keras.h5'
height = 32
nclass=len(characters)+1
nclass = len(characters) + 1

if os.path.exists(modelPath):
model, basemodel = get_model(height, nclass)
basemodel.load_weights(modelPath)
# model.load_weights(modelPath)


def predict(im):
"""
输入图片,输出keras模型的识别结果
Input an image and return the recognized result from the keras model.
"""
im = im.convert('L')
scale = im.size[1] * 1.0 / 32
w = im.size[0] / scale
w = int(w)
im = im.convert('L') # Convert image to grayscale
scale = im.size[1] / 32.0
w = int(im.size[0] / scale)
im = im.resize((w, 32))

img = np.array(im).astype(np.float32) / 255.0
X = img.reshape((32, w, 1))
X = np.array([X])

# Predict
y_pred = basemodel.predict(X)
y_pred = y_pred[:, 2:, :]
out = decode(y_pred) ##
# out = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0])*y_pred.shape[1], )[0][0])[:, :]

# out = u''.join([characters[x] for x in out[0]])

if len(out) > 0:
while out[0] == u'。':
if len(out) > 1:
out = out[1:]
else:
break
y_pred = y_pred[:, 2:, :] # Remove first two frames
out = decode(y_pred)

# Clean output
out = clean_output(out)
return out


def decode(pred):
charactersS = characters + u' '
charactersS = characters + ' ' # Add space character
t = pred.argmax(axis=2)[0]
length = len(t)
char_list = []
n = len(characters)
for i in range(length):
if t[i] != n and (not (i > 0 and t[i - 1] == t[i])):

for i in range(len(t)):
if t[i] != n and (i == 0 or t[i] != t[i - 1]): # Avoid duplicates
char_list.append(charactersS[t[i]])
return u''.join(char_list)

return ''.join(char_list)

def clean_output(out):
while out and out[0] == '。':
out = out[1:] # Remove leading punctuation
return out