{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.chdir('../../')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GPU 设置"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"GPUID='0'##调用GPU序号\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = GPUID"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from glob import glob\n",
"from PIL import Image\n",
"import cv2\n",
"Input =tf.keras.layers.Input\n",
"Lambda = tf.keras.layers.Lambda\n",
"load_model = tf.keras.models.load_model\n",
"Model = tf.keras.models.Model\n",
"\n",
"from apphelper.image import get_box_spilt,read_voc_xml,resize_im,read_singLine_for_yolo\n",
"from text.keras_yolo3 import preprocess_true_boxes, yolo_text\n",
"from train.text.utils import get_random_data_ as get_random_data\n",
"\n",
"\n",
"def data_generator(roots, anchors, num_classes,splitW):\n",
"    '''Endless generator of single-image batches for ``fit_generator``.\n",
"\n",
"    roots: non-empty list of image paths (jpg/png). Each image is expected to\n",
"        have an annotation file with the same basename and a ``.xml``\n",
"        extension. The list is shuffled in place once, then cycled forever.\n",
"    anchors, num_classes: forwarded unchanged to ``preprocess_true_boxes``.\n",
"    splitW: forwarded to ``get_box_spilt``.\n",
"\n",
"    Yields ``([image_data, *y_true], [np.zeros(1)]*4)``. Samples for which\n",
"    ``get_box_spilt`` returns no boxes are skipped.\n",
"    '''\n",
"    n = len(roots)\n",
"    np.random.shuffle(roots)\n",
"    scales = [416,608,608,608]  # multi-scale training; 608 is used 3 steps out of 4\n",
"    maxN = 128  # passed to get_random_data as max_boxes\n",
"    i = 0\n",
"    j = 0\n",
"    m = len(scales)\n",
"    while True:\n",
"        # Round-robin over both the image list and the scale list.\n",
"        root = roots[i]\n",
"        i = (i+1) % n\n",
"        scale = scales[j]\n",
"        j = (j+1) % m\n",
"\n",
"        xmlP = os.path.splitext(root)[0]+'.xml'\n",
"        boxes = read_voc_xml(xmlP)\n",
"        im = Image.open(root)\n",
"\n",
"        w,h = resize_im(im.size[0],im.size[1], scale=scale, max_scale=None)\n",
"        if max(w,h)>2048:\n",
"            # Recompute the target size with the long side capped.\n",
"            w,h = resize_im(im.size[0],im.size[1], scale=scale, max_scale=2048)\n",
"\n",
"        input_shape = (h,w)\n",
"        # Rotation augmentation is disabled: the degree is pinned to 0.\n",
"        rorateDegree = 0\n",
"        newBoxes,newIm = get_box_spilt(boxes, im, w,h,splitW=splitW, isRoate=True, rorateDegree=rorateDegree)\n",
"        newBoxes = np.array(newBoxes)\n",
"        if len(newBoxes)==0:\n",
"            continue\n",
"\n",
"        if np.random.randint(0,100)>70:\n",
"            # Flip augmentation. The flip must be applied to newIm, the image\n",
"            # that is actually passed on below; flipping only `im` (as the\n",
"            # previous code did) left mirrored boxes on an unflipped image.\n",
"            # Boxes are mirrored within (w, h), the size given to get_box_spilt.\n",
"            if np.random.randint(0,100)>50:\n",
"                # horizontal flip\n",
"                newBoxes[:,[0,2]] = w-newBoxes[:,[2,0]]\n",
"                flipped = cv2.flip(np.array(newIm),1)\n",
"            else:\n",
"                # vertical flip\n",
"                newBoxes[:,[1,3]] = h-newBoxes[:,[3,1]]\n",
"                flipped = cv2.flip(np.array(newIm),0)\n",
"            # Hand the flipped image on in the same container type it arrived in.\n",
"            newIm = Image.fromarray(flipped) if isinstance(newIm, Image.Image) else flipped\n",
"\n",
"        image, box = get_random_data(newIm,newBoxes, input_shape,max_boxes=maxN)\n",
"\n",
"        image_data = np.array([image])\n",
"        box_data = np.array([box])\n",
"        y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes)\n",
"        yield [image_data, *y_true], [np.zeros(1)]*4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 加载训练数据集;XML 标注工具参考 <https://github.com/cgvict/roLabelImg.git>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_split = 0.1\n",
"# NOTE: inside [...] glob treats '|' as a literal character, so this pattern\n",
"# matches any extension starting with 'j', 'p' or 'J' (or '|').\n",
"root='train/data/text/*/*.[j|p|J]*'\n",
"# Sort so the list order does not depend on the filesystem; with a seeded\n",
"# NumPy RNG the shuffle (and therefore the train/val split) is reproducible.\n",
"jpgPath = sorted(glob(root))\n",
"print('total:',len(jpgPath))\n",
"\n",
"# Drop images that have no annotation file (same basename, .xml extension).\n",
"delPaths = [p for p in jpgPath if not os.path.exists(os.path.splitext(p)[0]+'.xml')]\n",
"unlabeled = set(delPaths)\n",
"jpgPath = [p for p in jpgPath if p not in unlabeled]\n",
"print('total:',len(jpgPath))\n",
"assert jpgPath, 'no labeled images found for pattern: ' + root\n",
"np.random.shuffle(jpgPath)\n",
"\n",
"\n",
"num_val = int(len(jpgPath)*val_split)\n",
"num_train = len(jpgPath) - num_val\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 定义anchors及加载训练模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## 计算训练集anchors\n",
"from train.text.gen_anchors import YOLO_Kmeans## anchors生产\n",
"splitW = 8##文本分割最小宽度\n",
"#cluster = YOLO_Kmeans(cluster_number=9, root=root, scales=[416, 512, 608, 608, 608, 768, 960, 1024], splitW=splitW)\n",
"#8,9, 8,18, 8,31, 8,59, 8,124, 8,351, 8,509, 8,605, 8,800\n",
"#print(cluster.anchors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Data example: helpers that visualize the annotations and the split boxes\n",
"from apphelper.image import xy_rotate_box,box_rotate\n",
"def plot_boxes(img,angle, result,color=(0,0,0)):\n",
"    '''Draw rotated boxes and their running index on a copy of ``img``.\n",
"\n",
"    img: PIL image (``img.size`` is read as ``(width, height)``).\n",
"    angle: rotation in degrees; for 90/270 the width and height handed to\n",
"        ``box_rotate`` are swapped.\n",
"    result: iterable of dicts with keys ``cx``, ``cy``, ``w``, ``h``, ``angle``.\n",
"    color: colour tuple used for both the outlines and the index labels.\n",
"\n",
"    Returns a new PIL image built from the annotated array.\n",
"    '''\n",
"    tmp = np.array(img)\n",
"    c = color\n",
"    imgWidth,imgHeight = img.size\n",
"    thick = int((imgHeight + imgWidth) / 300)\n",
"    if angle in [90,270]:\n",
"        imgW,imgH = imgHeight,imgWidth\n",
"    else:\n",
"        imgW,imgH = imgWidth,imgHeight\n",
"\n",
"    for i,line in enumerate(result):\n",
"        cx = line['cx']\n",
"        cy = line['cy']\n",
"        degree = line['angle']\n",
"        # Box size gets its own names: the previous code reused `w`/`h` here and\n",
"        # thereby overwrote the image size that the font scale below relies on.\n",
"        boxW = line['w']\n",
"        boxH = line['h']\n",
"        x1,y1,x2,y2,x3,y3,x4,y4 = xy_rotate_box(cx, cy, boxW, boxH, degree)\n",
"        x1,y1,x2,y2,x3,y3,x4,y4 = box_rotate([x1,y1,x2,y2,x3,y3,x4,y4],angle=(360-angle)%360,imgH=imgH,imgW=imgW)\n",
"        # Label anchor: centre of the four rotated corners.\n",
"        cx = np.mean([x1,x2,x3,x4])\n",
"        cy = np.mean([y1,y2,y3,y4])\n",
"        corners = [(int(x1),int(y1)),(int(x2),int(y2)),(int(x3),int(y3)),(int(x4),int(y4))]\n",
"        for k in range(4):\n",
"            # Connect each corner to the next, closing the polygon at the end.\n",
"            cv2.line(tmp,corners[k],corners[(k+1)%4],c,1)\n",
"        mess = str(i)\n",
"        # Font scale follows the image height, not the height of the current box.\n",
"        cv2.putText(tmp, mess, (int(cx), int(cy)),0, 1e-3 * imgHeight, c, thick // 2)\n",
"    return Image.fromarray(tmp)\n",
"\n",
"def plot_box(img,boxes):\n",
"    '''Return a PIL image: a copy of ``img`` with every box outlined.\n",
"\n",
"    Each box is read as ``box[0..3]`` and drawn as the rectangle spanned by\n",
"    ``(box[0], box[1])`` and ``(box[2], box[3])`` with a 1px black line.\n",
"    '''\n",
"    outline = (0, 0, 0)\n",
"    canvas = np.copy(img)\n",
"    for box in boxes:\n",
"        first_corner = (int(box[0]), int(box[1]))\n",
"        opposite_corner = (int(box[2]), int(box[3]))\n",
"        cv2.rectangle(canvas, first_corner, opposite_corner, outline, 1)\n",
"    return Image.fromarray(canvas)\n",
"\n",
"def show(p,scale=608):\n",
"    '''Visualize one sample before and after box splitting.\n",
"\n",
"    p: image path; the annotation is read from the same path with its\n",
"        extension replaced by ``.xml`` (same rule as ``data_generator``, so\n",
"        every extension matched by the dataset glob works, not just\n",
"        ``.jpg``/``.png``).\n",
"    scale: forwarded to ``resize_im`` (with ``max_scale=4096``).\n",
"\n",
"    Returns ``(annotated_original, split_preview, newBoxes)`` where the first\n",
"    two are the images produced by ``plot_boxes`` and ``plot_box``.\n",
"    '''\n",
"    xmlP = os.path.splitext(p)[0]+'.xml'\n",
"    boxes = read_voc_xml(xmlP)\n",
"    im = Image.open(p)\n",
"    w,h = resize_im(im.size[0],im.size[1], scale=scale, max_scale=4096)\n",
"\n",
"    # Rotation is disabled here as well: the degree is pinned to 0.\n",
"    rorateDegree = 0\n",
"    newBoxes,newIm = get_box_spilt(boxes, im, sizeW=w, SizeH=h, splitW=splitW, isRoate=True, rorateDegree=rorateDegree)\n",
"    return plot_boxes(im,0, boxes,color=(0,0,0)),plot_box(newIm,newBoxes),newBoxes\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a,b,newBoxes = show(jpgPath[9])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Anchors are hard-coded here instead of being taken from cluster.anchors.\n",
"# The flat comma-separated values are parsed as floats and reshaped into rows of two.\n",
"anchors = np.array(\n",
"    [float(value) for value in '8,9, 8,18, 8,31, 8,59, 8,124, 8,351, 8,509, 8,605, 8,800'.split(',')]\n",
").reshape(-1, 2)\n",
"num_anchors = len(anchors)\n",
"\n",
"class_names = ['none', 'text']\n",
"num_classes = len(class_names)\n",
"\n",
"# Build the model in training mode; pretrained weights are loaded in the next cell.\n",
"textModel = yolo_text(num_classes, anchors, train=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"textModel.load_weights('models/text.h5')  # load the pretrained model weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainLoad = data_generator(jpgPath[:num_train], anchors, num_classes,splitW)\n",
"testLoad = data_generator(jpgPath[num_train:], anchors, num_classes,splitW)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"adam = tf.keras.optimizers.Adam(lr=0.0005)\n",
"\n",
"# Each named loss ignores y_true and returns y_pred unchanged; the generator\n",
"# supplies zero placeholders as targets for these four outputs.\n",
"textModel.compile(\n",
"    optimizer=adam,\n",
"    loss={\n",
"        name: (lambda y_true, y_pred: y_pred)\n",
"        for name in ('xy_loss', 'wh_loss', 'confidence_loss', 'class_loss')\n",
"    },\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# One generator step is one image, so an epoch is num_train steps and a\n",
"# validation pass is num_val steps.\n",
"history = textModel.fit_generator(generator=trainLoad,\n",
"                                  steps_per_epoch=num_train,\n",
"                                  epochs=2,\n",
"                                  verbose=2,\n",
"                                  callbacks=None,\n",
"                                  validation_data=testLoad,\n",
"                                  validation_steps=num_val)\n",
"\n",
"# Persist the trained weights: the evaluation cell further down loads exactly\n",
"# this file, and nothing else in the notebook writes it.\n",
"textModel.save_weights('/tmp/textModel.h5')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from text.keras_yolo3 import yolo_text,box_layer,K\n",
"from config import kerasTextModel,IMGSIZE,keras_anchors,class_names\n",
"from apphelper.image import resize_im,letterbox_image\n",
"from PIL import Image\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"graph = tf.get_default_graph()##解决web.py 相关报错问题\n",
"\n",
"anchors = [float(x) for x in keras_anchors.split(',')]\n",
"anchors = np.array(anchors).reshape(-1, 2)\n",
"num_anchors = len(anchors)\n",
"\n",
"num_classes = len(class_names)\n",
"textModelTest = yolo_text(num_classes,anchors)\n",
"kerasTextModel = '/tmp/textModel.h5'\n",
"textModelTest.load_weights(kerasTextModel)\n",
"\n",
"\n",
"sess = K.get_session()\n",
"image_shape = K.placeholder(shape=(2, ))##图像原尺寸:h,w\n",
"input_shape = K.placeholder(shape=(2, ))##图像resize尺寸:h,w\n",
"box_score = box_layer([*textModelTest.output,image_shape,input_shape],anchors, num_classes)\n",
"\n",
"\n",
"\n",
"def text_detect(img,prob = 0.05):\n",
"    '''Run the text detector on one image and keep boxes scoring above ``prob``.\n",
"\n",
"    img: array accepted by ``Image.fromarray``.\n",
"    prob: score threshold; only rows with ``score > prob`` are returned.\n",
"\n",
"    Returns ``(box, scores)``. The first four box columns are limited to the\n",
"    original image: negative values become 0, x values >= width become\n",
"    width-1 and y values >= height become height-1.\n",
"\n",
"    Relies on the module-level ``graph``, ``sess``, ``box_score``,\n",
"    ``textModelTest``, ``input_shape``/``image_shape`` placeholders and\n",
"    ``IMGSIZE``.\n",
"    '''\n",
"    im = Image.fromarray(img)\n",
"    scale = IMGSIZE[0]\n",
"    w,h = im.size\n",
"    # Network input size: chosen by resize_im from `scale`, with max_scale=2048.\n",
"    w_,h_ = resize_im(w,h, scale=scale, max_scale=2048)\n",
"    boxed_image = im.resize((w_,h_), Image.BICUBIC)\n",
"    image_data = np.array(boxed_image, dtype='float32')\n",
"    image_data /= 255.\n",
"    image_data = np.expand_dims(image_data, 0)  # add the batch dimension\n",
"\n",
"    # Run inside the stored graph (works around web.py-related errors).\n",
"    with graph.as_default():\n",
"        box,scores = sess.run(\n",
"            [box_score],\n",
"            feed_dict={\n",
"                textModelTest.input: image_data,\n",
"                input_shape: [h_, w_],\n",
"                image_shape: [h, w],\n",
"                K.learning_phase(): 0\n",
"            })[0]\n",
"\n",
"    keep = np.where(scores>prob)\n",
"\n",
"    # Clamp coordinates in place; box[:, k] is a view, so the masked\n",
"    # assignments below write through to `box`.\n",
"    box[:, 0:4][box[:, 0:4]<0] = 0\n",
"    for col,limit in ((0,w),(1,h),(2,w),(3,h)):\n",
"        column = box[:, col]\n",
"        column[column>=limit] = limit-1\n",
"\n",
"    box = box[keep[0]]\n",
"    scores = scores[keep[0]]\n",
"    return box,scores\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"p='./train/text/26BB94CA21C11AB38BC5FC2E08D140CD.jpg'\n",
"IMGSIZE=416,416\n",
"img = np.array(Image.open(p))\n",
"box,scores = text_detect(img,prob = 0.01)\n",
"plot_box(img,box)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chineseocr",
"language": "python",
"name": "chineseocr"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}