master
/ train / text / text-train.ipynb

text-train.ipynb @masterview markup · raw · history · blame

Notebook
In [ ]:
import os

# The notebook lives in train/text/, while the paths used further down
# ('train/data/text/...', 'models/text.h5', './train/text/...') are relative
# to the project root. Only step up when the current directory is not already
# the root, so re-running this cell does not keep climbing the directory tree.
if not os.path.isdir(os.path.join('train', 'text')):
    os.chdir('../../')

GPU 设置

In [ ]:
# Index of the GPU made visible to this process. It is exported before
# TensorFlow is imported in the following cell.
GPUID = '0'
os.environ.update(CUDA_VISIBLE_DEVICES=GPUID)
In [ ]:
import numpy as np
import tensorflow as tf
from glob import glob
from PIL import Image
import cv2
Input =tf.keras.layers.Input
Lambda = tf.keras.layers.Lambda
load_model = tf.keras.models.load_model
Model = tf.keras.models.Model

from apphelper.image import get_box_spilt,read_voc_xml,resize_im,read_singLine_for_yolo
from text.keras_yolo3 import  preprocess_true_boxes, yolo_text
from train.text.utils import get_random_data_ as get_random_data


def _flip_keep_type(image, flip_code):
    '''Flip ``image`` with ``cv2.flip`` and return it in the type it came in.

    A PIL image comes back as a PIL image; anything else comes back as the
    numpy array produced by ``cv2.flip``.
    '''
    flipped = cv2.flip(np.array(image), flip_code)
    if isinstance(image, Image.Image):
        return Image.fromarray(flipped)
    return flipped


def data_generator(roots, anchors, num_classes,splitW):
    '''Endless single-image batch generator for ``fit_generator``.

    Args:
        roots: list of image paths (jpg/png). Each image needs an annotation
            file with the same base name and a ``.xml`` extension next to it.
            The list is shuffled in place.
        anchors: anchors forwarded to ``preprocess_true_boxes``.
        num_classes: class count forwarded to ``preprocess_true_boxes``.
        splitW: split width forwarded to ``get_box_spilt``.

    Yields:
        ``([image_data, *y_true], [np.zeros(1)] * 4)`` where ``image_data``
        and the box array passed to ``preprocess_true_boxes`` both carry a
        leading batch axis of size 1.

    Raises:
        ValueError: on the first ``next()`` if ``roots`` is empty.
    '''
    n = len(roots)
    if n == 0:
        # Fail with a clear message instead of a bare IndexError below.
        raise ValueError('data_generator needs at least one image path')
    np.random.shuffle(roots)
    scales = [416,608,608,608]  # multi-scale training: cycled one per sample
    m = len(scales)
    maxN = 128  # max_boxes handed to get_random_data
    i = 0
    j = 0
    while True:
        root = roots[i]
        i = (i + 1) % n
        scale = scales[j]
        j = (j + 1) % m

        xmlP  = os.path.splitext(root)[0]+'.xml'
        boxes = read_voc_xml(xmlP)
        im    = Image.open(root)

        w,h   = resize_im(im.size[0],im.size[1], scale=scale, max_scale=None)
        if max(w,h)>2048:
            # Too large without a cap: recompute with the long side limited.
            w,h   = resize_im(im.size[0],im.size[1], scale=scale, max_scale=2048)

        input_shape = (h,w)
        # Random rotation is disabled: the degree is fixed at 0 (the original
        # drew a value in [-5, 5) and immediately overwrote it with 0).
        isRoate=True
        rorateDegree = 0
        newBoxes,newIm = get_box_spilt(boxes, im, w,h,splitW=splitW, isRoate=isRoate, rorateDegree=rorateDegree)
        newBoxes = np.array(newBoxes)
        if len(newBoxes)==0:
            # Nothing to learn from this sample; move on to the next image.
            continue
        if np.random.randint(0,100)>70:
            # Flip augmentation. The boxes live in the coordinate frame of
            # newIm (the w x h image built by get_box_spilt), so newIm is the
            # image that has to be flipped. The original flipped `im`, which
            # is not used after this point, leaving boxes and pixels out of
            # sync.
            if np.random.randint(0,100)>50:
                # horizontal flip
                newBoxes[:,[0,2]] = w-newBoxes[:,[2,0]]
                newIm = _flip_keep_type(newIm, 1)
            else:
                # vertical flip
                newBoxes[:,[1,3]] = h-newBoxes[:,[3,1]]
                newIm = _flip_keep_type(newIm, 0)

        image, box = get_random_data(newIm,newBoxes, input_shape,max_boxes=maxN)

        # Add the batch axis expected by the model / preprocess_true_boxes.
        image_data = np.array([image])
        box_data = np.array([box])
        y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes)
        yield [image_data, *y_true], [np.zeros(1)]*4
        

加载训练数据集。XML 标注工具参考:https://github.com/cgvict/roLabelImg.git

In [ ]:
val_split = 0.1
# NOTE(review): `[j|p|J]` is a glob character class, not an alternation. It
# matches extensions starting with 'j', 'p', 'J' or a literal '|', so e.g.
# '.PNG' files are not picked up. Confirm this is intended.
root = 'train/data/text/*/*.[j|p|J]*'
jpgPath = glob(root)

# Drop unlabeled images: those without a sibling .xml annotation file.
delPaths = [
    path for path in jpgPath
    if not os.path.exists(os.path.splitext(path)[0] + '.xml')
]

print('total:', len(jpgPath))
jpgPath = list(set(jpgPath).difference(delPaths))
print('total:', len(jpgPath))
np.random.shuffle(jpgPath)

# The last `num_val` entries of the shuffled list form the validation split.
num_val = int(val_split * len(jpgPath))
num_train = len(jpgPath) - num_val

定义anchors及加载训练模型

In [ ]:
## Compute the anchors for the training set
from train.text.gen_anchors import YOLO_Kmeans## anchor generation
splitW = 8## minimum width used when splitting text boxes
# The clustering run below is disabled; the line after it records anchor
# values, which match the hard-coded `anchors` string used later on.
#cluster = YOLO_Kmeans(cluster_number=9, root=root, scales=[416, 512, 608, 608, 608, 768, 960, 1024], splitW=splitW)
#8,9, 8,18, 8,31, 8,59, 8,124, 8,351, 8,509, 8,605, 8,800
#print(cluster.anchors)
In [ ]:
## Data examples
from apphelper.image import xy_rotate_box,box_rotate
def plot_boxes(img,angle, result,color=(0,0,0)):
    """Draw rotated boxes and their running index on a copy of ``img``.

    Args:
        img: PIL image (its ``size`` attribute is read).
        angle: image rotation in degrees; for 90 or 270 the width and height
            handed to ``box_rotate`` are swapped.
        result: iterable of dicts with the keys ``cx``, ``cy``, ``angle``,
            ``w`` and ``h`` describing each box.
        color: colour used for both the outlines and the index labels.

    Returns:
        A new PIL image with the boxes drawn on it.
    """
    canvas = np.array(img)
    imgW, imgH = img.size
    thick = int((imgH + imgW) / 300)
    if angle in [90,270]:
        imgW, imgH = imgH, imgW

    for idx, line in enumerate(result):
        cx = line['cx']
        cy = line['cy']
        degree = line['angle']
        boxW = line['w']
        boxH = line['h']
        x1,y1,x2,y2,x3,y3,x4,y4 = xy_rotate_box(cx, cy, boxW, boxH, degree)
        x1,y1,x2,y2,x3,y3,x4,y4 = box_rotate([x1,y1,x2,y2,x3,y3,x4,y4],angle=(360-angle)%360,imgH=imgH,imgW=imgW)

        # Outline the quadrilateral: 1->2, 2->3, 3->4, 4->1.
        corners = [(int(x1),int(y1)), (int(x2),int(y2)), (int(x3),int(y3)), (int(x4),int(y4))]
        for start, end in zip(corners, corners[1:] + corners[:1]):
            cv2.line(canvas, start, end, color, 1)

        # Label the box at the centroid of its corners. The font scale is
        # proportional to the height of the box itself.
        labelX = np.mean([x1,x2,x3,x4])
        labelY = np.mean([y1,y2,y3,y4])
        cv2.putText(canvas, str(idx), (int(labelX), int(labelY)),0, 1e-3 * boxH, color, thick // 2)
    return Image.fromarray(canvas)

def plot_box(img,boxes):
    """Draw one-pixel black rectangles for ``boxes`` on a copy of ``img``.

    Args:
        img: image accepted by ``np.copy`` (PIL image or array).
        boxes: iterable of sequences whose first four entries are the
            coordinates of two opposite rectangle corners.

    Returns:
        A new PIL image with the rectangles drawn on it.
    """
    canvas = np.copy(img)
    black = (0, 0, 0)
    for box in boxes:
        corner_a = (int(box[0]), int(box[1]))
        corner_b = (int(box[2]), int(box[3]))
        cv2.rectangle(canvas, corner_a, corner_b, black, 1)

    return Image.fromarray(canvas)

def show(p,scale=608):
    """Preview the annotations of one training image.

    Args:
        p: path of the image; its annotation is read from the file with the
            same base name and a ``.xml`` extension.
        scale: scale forwarded to ``resize_im``.

    Returns:
        A 3-tuple ``(annotated, split_preview, newBoxes)``: the original
        image with the annotated boxes drawn by ``plot_boxes``, the image
        returned by ``get_box_spilt`` with the split boxes drawn by
        ``plot_box``, and the split boxes themselves.

    Uses the notebook-level ``splitW``.
    """
    # Build the annotation path the same way the data generator does, so any
    # extension works (.jpeg, .JPG, ...) and a '.jpg'/'.png' substring inside
    # a directory name is never rewritten.
    xmlP  = os.path.splitext(p)[0]+'.xml'
    boxes = read_voc_xml(xmlP)
    im    = Image.open(p)
    w,h   = resize_im(im.size[0],im.size[1], scale=scale, max_scale=4096)

    # Rotation is disabled: the degree is fixed at 0.
    newBoxes,newIm = get_box_spilt(boxes, im, sizeW=w, SizeH=h, splitW=splitW, isRoate=True, rorateDegree=0)
    return plot_boxes(im,0, boxes,color=(0,0,0)),plot_box(newIm,newBoxes),newBoxes
In [ ]:
# Build the annotation previews for one sample (entry 9 of the shuffled path
# list; this raises IndexError when fewer than 10 labeled images exist).
a,b,newBoxes = show(jpgPath[9])
In [ ]:
# Display the preview with the split boxes drawn on it.
b
In [ ]:
# Hard-coded anchors, parsed from a flat comma-separated list into rows of two
# values. (Alternative, currently unused: anchors = cluster.anchors.)
anchors = np.array(
    list(map(float, '8,9, 8,18, 8,31, 8,59, 8,124, 8,351, 8,509, 8,605, 8,800'.split(',')))
).reshape(-1, 2)
num_anchors = anchors.shape[0]

class_names = ['none', 'text']
num_classes = len(class_names)

# Training variant of the detector; the pretrained weights are loaded in the
# next cell.
textModel = yolo_text(num_classes, anchors, train=True)
In [ ]:
textModel.load_weights('models/text.h5')## load the pretrained model weights
In [ ]:
# First `num_train` shuffled paths feed training, the remainder validation.
# Each slice is a fresh list, so the in-place shuffle inside data_generator
# does not touch jpgPath.
trainLoad, testLoad = [
    data_generator(paths, anchors, num_classes, splitW)
    for paths in (jpgPath[:num_train], jpgPath[num_train:])
]
In [ ]:
adam = tf.keras.optimizers.Adam(lr=0.0005)

# Every named output gets a pass-through loss that simply returns y_pred.
# Presumably the train=True model already emits the loss values as its
# outputs; confirm in yolo_text.
textModel.compile(
    optimizer=adam,
    loss={
        name: (lambda y_true, y_pred: y_pred)
        for name in ('xy_loss', 'wh_loss', 'confidence_loss', 'class_loss')
    },
)
In [ ]:
# Train for 2 epochs. The generators yield one image per step, so an epoch of
# `num_train` steps visits each training image once, and validation runs
# `num_val` steps.
# NOTE(review): no cell visible in this notebook saves the trained weights,
# yet the evaluation cell further down loads '/tmp/textModel.h5'. Confirm
# where that file is produced.
textModel.fit_generator(generator=trainLoad, 
                         steps_per_epoch=num_train, 
                         epochs=2,
                         verbose=2, 
                         callbacks=None,
                         validation_data=testLoad, 
                         validation_steps=num_val)
In [ ]:
 
In [ ]:
from text.keras_yolo3 import yolo_text,box_layer,K
from config import kerasTextModel,IMGSIZE,keras_anchors,class_names
from apphelper.image import resize_im,letterbox_image
from PIL import Image
import numpy as np
import tensorflow as tf
# Keep a handle on the default graph so text_detect can re-enter it.
graph = tf.get_default_graph()## works around errors seen when serving through web.py

# Anchors now come from the project config (keras_anchors) and replace the
# `anchors` array defined for training earlier in the notebook.
anchors = [float(x) for x in keras_anchors.split(',')]
anchors = np.array(anchors).reshape(-1, 2)
num_anchors = len(anchors)

num_classes = len(class_names)
# Inference variant of the detector (no train=True).
textModelTest = yolo_text(num_classes,anchors)
# Overrides the kerasTextModel path imported from config with a hard-coded
# temporary location.
kerasTextModel = '/tmp/textModel.h5'
textModelTest.load_weights(kerasTextModel)


sess = K.get_session()
image_shape = K.placeholder(shape=(2, ))## original image size: h,w
input_shape = K.placeholder(shape=(2, ))## resized image size: h,w
# Tensor(s) evaluated by text_detect; fed with the model outputs plus the two
# shape placeholders above.
box_score = box_layer([*textModelTest.output,image_shape,input_shape],anchors, num_classes)



def text_detect(img,prob = 0.05):
    """Run the text detector on one image and return the confident boxes.

    Args:
        img: image array accepted by ``Image.fromarray``.
        prob: score threshold; only detections with ``score > prob`` are kept.

    Returns:
        ``(box, scores)`` for the kept detections. Before filtering, negative
        coordinates are set to 0, values ``>= w`` in columns 0 and 2 are set
        to ``w - 1`` and values ``>= h`` in columns 1 and 3 are set to
        ``h - 1``, where ``w, h`` is the size of the input image.

    Relies on the notebook-level ``IMGSIZE``, ``graph``, ``sess``,
    ``box_score``, ``textModelTest``, ``input_shape`` and ``image_shape``.
    """
    pil_im = Image.fromarray(img)
    w, h = pil_im.size
    # Target size computed by resize_im from scale=IMGSIZE[0], max_scale=2048.
    w_, h_ = resize_im(w, h, scale=IMGSIZE[0], max_scale=2048)
    resized = pil_im.resize((w_, h_), Image.BICUBIC)

    # Scale pixels to [0, 1] and add the batch dimension.
    image_data = np.expand_dims(np.array(resized, dtype='float32') / 255., 0)

    # Re-enter the graph captured at setup time (web.py related workaround).
    with graph.as_default():
        box, scores = sess.run(
            [box_score],
            feed_dict={
                textModelTest.input: image_data,
                input_shape: [h_, w_],
                image_shape: [h, w],
                K.learning_phase(): 0
            })[0]

    kept = np.where(scores > prob)[0]

    # Clamp coordinates in place through views of `box`.
    coords = box[:, 0:4]
    coords[coords < 0] = 0
    for col, limit in ((0, w), (1, h), (2, w), (3, h)):
        column = box[:, col]
        column[column >= limit] = limit - 1

    return box[kept], scores[kept]
In [ ]:
# Run the detector on one sample image and draw the returned boxes.
p='./train/text/26BB94CA21C11AB38BC5FC2E08D140CD.jpg'
# Rebinds the IMGSIZE imported from config; text_detect reads IMGSIZE[0] as
# its resize scale.
IMGSIZE=416,416
img = np.array(Image.open(p))
# Low threshold so that weak detections are shown as well.
box,scores = text_detect(img,prob = 0.01)
plot_box(img,box)