master
/ crnn / crnn_keras.py

crnn_keras.py @c1cd888 raw · history · blame

#coding:utf-8
from crnn.utils import strLabelConverter,resizeNormalize

from crnn.network_keras import keras_crnn as CRNN
import tensorflow as tf
# Capture the default TensorFlow graph at import time; crnnOcr re-enters it
# with graph.as_default() before calling model.predict. Original note: this
# works around errors seen when running under web.py (presumably because
# requests are served on threads where a different default graph would be
# active — NOTE(review): confirm).
graph = tf.get_default_graph()

from crnn import keys
from config import ocrModelKeras
import numpy as np
def crnnSource():
    """Build the Keras CRNN recognizer and its label converter.

    Uses the mixed Chinese/English alphabet from ``keys.alphabetChinese`` and
    loads weights from the ``ocrModelKeras`` path in the project config.

    Returns:
        A ``(model, converter)`` tuple.
    """
    charset = keys.alphabetChinese  # Chinese + English character model
    label_converter = strLabelConverter(charset)
    # One output class per character plus one extra (presumably the CTC
    # blank label — confirm against network_keras).
    class_count = len(charset) + 1
    net = CRNN(32, 1, class_count, 256, 1, lstmFlag=False)
    net.load_weights(ocrModelKeras)
    return net, label_converter

# Load the model once at import time; crnnOcr reuses these module-level
# objects on every call.
model, converter = crnnSource()

def crnnOcr(image):
    """Recognize the text in one line image with the CRNN model.

    Args:
        image: a PIL image already converted to grayscale, i.e. the result of
            ``PIL.Image.convert("L")``.

    Returns:
        Whatever ``converter.decode`` produces for the model's per-position
        best classes (the recognized text).
    """
    img_w, img_h = image.size
    # Rescale to the network's fixed input height of 32, keeping the aspect
    # ratio. Clamp to at least 1 pixel: a very tall, thin crop would otherwise
    # truncate to a width of 0, which is not a valid resize target.
    scale = img_h * 1.0 / 32
    w = max(1, int(img_w / scale))
    transformer = resizeNormalize((w, 32))
    image = transformer(image).astype(np.float32)
    # Add leading batch and channel axes (presumably giving (1, 1, 32, w)).
    image = np.array([[image]])
    # Predict inside the graph captured at import time; the module-level
    # ``graph`` is only read here, so no ``global`` statement is needed.
    with graph.as_default():
        preds = model.predict(image)
    # Drop the batch axis, then take the best class at each position and
    # flatten. NOTE(review): axis=2 is assumed to be the class axis of the
    # model output — confirm against network_keras.
    preds = np.argmax(preds[0], axis=2).reshape((-1,))
    return converter.decode(preds)