{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Run everything from the repository root (two levels above this notebook) so that\n",
    "# project packages (apphelper, text, train) and relative data paths resolve.\n",
    "# The starting directory is remembered on the first run, so re-executing this cell\n",
    "# does not keep climbing the directory tree.\n",
    "if 'NOTEBOOK_DIR' not in globals():\n",
    "    NOTEBOOK_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(NOTEBOOK_DIR, '..', '..'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GPU 设置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index of the GPU exposed to CUDA; set here, before TensorFlow is imported in the next cell.\n",
    "GPUID = '0'\n",
    "os.environ.update({\"CUDA_VISIBLE_DEVICES\": GPUID})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from glob import glob\n",
    "from PIL import Image\n",
    "import cv2\n",
    "Input =tf.keras.layers.Input\n",
    "Lambda = tf.keras.layers.Lambda\n",
    "load_model = tf.keras.models.load_model\n",
    "Model = tf.keras.models.Model\n",
    "\n",
    "from apphelper.image import get_box_spilt,read_voc_xml,resize_im,read_singLine_for_yolo\n",
    "from text.keras_yolo3 import  preprocess_true_boxes, yolo_text\n",
    "from train.text.utils import get_random_data_ as get_random_data\n",
    "\n",
    "\n",
    "def _random_flip(image, boxes):\n",
    "    '''Randomly flip ``image`` together with its ``boxes`` (augmentation).\n",
    "\n",
    "    A flip happens when ``np.random.randint(0, 100) > 70`` (about 29% of calls);\n",
    "    it is horizontal or vertical with roughly equal odds. Box coordinates are\n",
    "    mirrored using the flipped image's own width/height, so pixels and boxes\n",
    "    stay aligned.\n",
    "\n",
    "    @@image: PIL image or ndarray that ``boxes`` refer to\n",
    "    @@boxes: array whose columns 0-3 are x1, y1, x2, y2; modified in place\n",
    "    returns (image, boxes); the image keeps its input type (PIL in, PIL out)\n",
    "    '''\n",
    "    if np.random.randint(0, 100) <= 70:\n",
    "        return image, boxes\n",
    "    pixels = np.array(image)\n",
    "    imgH, imgW = pixels.shape[:2]\n",
    "    if np.random.randint(0, 100) > 50:\n",
    "        # horizontal flip: mirror x coordinates (columns 0 and 2)\n",
    "        boxes[:, [0, 2]] = imgW - boxes[:, [2, 0]]\n",
    "        pixels = cv2.flip(pixels, 1)\n",
    "    else:\n",
    "        # vertical flip: mirror y coordinates (columns 1 and 3)\n",
    "        boxes[:, [1, 3]] = imgH - boxes[:, [3, 1]]\n",
    "        pixels = cv2.flip(pixels, 0)\n",
    "    if isinstance(image, Image.Image):\n",
    "        return Image.fromarray(pixels), boxes\n",
    "    return pixels, boxes\n",
    "\n",
    "\n",
    "def data_generator(roots, anchors, num_classes, splitW, max_rotate=0):\n",
    "    '''Infinite single-sample generator for ``fit_generator``.\n",
    "\n",
    "    @@roots: image paths (jpg/png); each needs a sibling ``.xml`` annotation\n",
    "    @@anchors: anchor array passed to ``preprocess_true_boxes``\n",
    "    @@num_classes: class count passed to ``preprocess_true_boxes``\n",
    "    @@splitW: minimum split width passed to ``get_box_spilt``\n",
    "    @@max_rotate: if non-zero, each sample is rotated by a random angle drawn\n",
    "        from [-max_rotate, max_rotate]; 0 (default) disables rotation\n",
    "\n",
    "    Cycles through ``roots`` (reshuffled after every full pass) and through a\n",
    "    fixed list of scales, yielding ``([image_data, *y_true], targets)`` with a\n",
    "    batch of one image. ``targets`` are four dummy zero arrays, one per loss\n",
    "    output; the compiled losses return ``y_pred`` and ignore them.\n",
    "    Raises ValueError (on the first ``next()``) if ``roots`` is empty.\n",
    "    '''\n",
    "    if len(roots) == 0:\n",
    "        raise ValueError('data_generator needs at least one image path')\n",
    "    roots = list(roots)  # private copy: shuffling must not reorder the caller's list\n",
    "    scales = [416, 608, 608, 608]  # multi-scale training: one scale per sample, cycled\n",
    "    maxN = 128  # keep at most 128 boxes per image (max_boxes of get_random_data)\n",
    "    isRoate = True\n",
    "    n, m = len(roots), len(scales)\n",
    "    np.random.shuffle(roots)\n",
    "    i = j = 0\n",
    "    while True:\n",
    "        root = roots[i]\n",
    "        i += 1\n",
    "        if i >= n:\n",
    "            i = 0\n",
    "            np.random.shuffle(roots)  # fresh order for the next pass\n",
    "        scale = scales[j]\n",
    "        j = (j + 1) % m\n",
    "\n",
    "        xmlP = os.path.splitext(root)[0] + '.xml'\n",
    "        boxes = read_voc_xml(xmlP)\n",
    "        im = Image.open(root)\n",
    "\n",
    "        w, h = resize_im(im.size[0], im.size[1], scale=scale, max_scale=None)\n",
    "        if max(w, h) > 2048:\n",
    "            # too large without a cap: resize again with max_scale=2048\n",
    "            w, h = resize_im(im.size[0], im.size[1], scale=scale, max_scale=2048)\n",
    "        input_shape = (h, w)\n",
    "\n",
    "        rorateDegree = np.random.uniform(-max_rotate, max_rotate) if max_rotate else 0\n",
    "        newBoxes, newIm = get_box_spilt(boxes, im, w, h, splitW=splitW, isRoate=isRoate, rorateDegree=rorateDegree)\n",
    "        newBoxes = np.array(newBoxes)\n",
    "        if len(newBoxes) == 0:\n",
    "            continue  # nothing to learn from this image\n",
    "\n",
    "        # Flip the image that is actually fed to the network (newIm), so the\n",
    "        # mirrored boxes keep matching the pixels.\n",
    "        newIm, newBoxes = _random_flip(newIm, newBoxes)\n",
    "\n",
    "        image, box = get_random_data(newIm, newBoxes, input_shape, max_boxes=maxN)\n",
    "        image_data = np.array([image])\n",
    "        box_data = np.array([box])\n",
    "        y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes)\n",
    "        yield [image_data, *y_true], [np.zeros(1)] * 4\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载训练数据集，标注XML软件参考 [roLabelImg](https://github.com/cgvict/roLabelImg.git)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_split = 0.1  # fraction of labelled images held out for validation\n",
    "SPLIT_SEED = 0  # fixed seed so the train/val split is reproducible\n",
    "# Pattern kept unchanged because the (commented-out) YOLO_Kmeans call below uses it.\n",
    "# Note: the class [j|p|J] matches one of 'j', '|', 'p', 'J', so it also picks\n",
    "# up files such as .json or .pkl; the extension filter below removes those.\n",
    "root='train/data/text/*/*.[j|p|J]*'\n",
    "IMG_EXTS = ('.jpg', '.jpeg', '.png')\n",
    "# sorted(): make the listing independent of filesystem order.\n",
    "jpgPath = sorted(p for p in glob(root) if os.path.splitext(p)[1].lower() in IMG_EXTS)\n",
    "print('total:', len(jpgPath))\n",
    "\n",
    "# Drop unlabelled images, i.e. those without a sibling .xml annotation.\n",
    "delPaths = [p for p in jpgPath if not os.path.exists(os.path.splitext(p)[0] + '.xml')]\n",
    "delSet = set(delPaths)\n",
    "# Order-preserving filter: a set difference would make the order (and hence the\n",
    "# seeded split) depend on string hashing, which varies between interpreter runs.\n",
    "jpgPath = [p for p in jpgPath if p not in delSet]\n",
    "print('labelled:', len(jpgPath))\n",
    "np.random.RandomState(SPLIT_SEED).shuffle(jpgPath)\n",
    "\n",
    "num_val = int(len(jpgPath) * val_split)\n",
    "num_train = len(jpgPath) - num_val\n",
    "assert num_train > 0 and num_val > 0, 'need more labelled images for a train/val split, got %d' % len(jpgPath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义anchors及加载训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Anchors for the training set (k-means clustering via YOLO_Kmeans).\n",
    "from train.text.gen_anchors import YOLO_Kmeans\n",
    "\n",
    "splitW = 8  # minimum width for splitting text into boxes (passed to get_box_spilt)\n",
    "\n",
    "# Clustering is disabled. The values listed below match the anchors hard-coded\n",
    "# in the model cell further down. To recompute, uncomment:\n",
    "# cluster = YOLO_Kmeans(cluster_number=9, root=root, scales=[416, 512, 608, 608, 608, 768, 960, 1024], splitW=splitW)\n",
    "# print(cluster.anchors)  # 8,9, 8,18, 8,31, 8,59, 8,124, 8,351, 8,509, 8,605, 8,800"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 数据事例\n",
    "from apphelper.image import xy_rotate_box,box_rotate\n",
    "def plot_boxes(img, angle, result, color=(0, 0, 0)):\n",
    "    '''Draw numbered rotated boxes on a copy of ``img``.\n",
    "\n",
    "    @@img: PIL image (not modified)\n",
    "    @@angle: rotation of the image in degrees; boxes are mapped back with\n",
    "        box_rotate(angle=(360 - angle) % 360); 90/270 swap width and height\n",
    "    @@result: iterable of dicts with keys 'cx', 'cy', 'w', 'h', 'angle'\n",
    "    @@color: colour of the box outlines and index labels\n",
    "    returns a new PIL image\n",
    "    '''\n",
    "    canvas = np.array(img)\n",
    "    imgW, imgH = img.size\n",
    "    # Font scale comes from the image height (not a box height) so labels stay\n",
    "    # legible; thickness is kept at >= 1 for small images.\n",
    "    fontScale = 1e-3 * imgH\n",
    "    thick = max(1, int((imgW + imgH) / 300) // 2)\n",
    "    if angle in [90, 270]:\n",
    "        rotW, rotH = imgH, imgW\n",
    "    else:\n",
    "        rotW, rotH = imgW, imgH\n",
    "\n",
    "    for idx, line in enumerate(result):\n",
    "        x1, y1, x2, y2, x3, y3, x4, y4 = xy_rotate_box(line['cx'], line['cy'], line['w'], line['h'], line['angle'])\n",
    "        x1, y1, x2, y2, x3, y3, x4, y4 = box_rotate([x1, y1, x2, y2, x3, y3, x4, y4], angle=(360 - angle) % 360, imgH=rotH, imgW=rotW)\n",
    "        corners = [(int(x1), int(y1)), (int(x2), int(y2)), (int(x3), int(y3)), (int(x4), int(y4))]\n",
    "        for k in range(4):\n",
    "            cv2.line(canvas, corners[k], corners[(k + 1) % 4], color, 1)\n",
    "        cx = np.mean([x1, x2, x3, x4])\n",
    "        cy = np.mean([y1, y2, y3, y4])\n",
    "        cv2.putText(canvas, str(idx), (int(cx), int(cy)), 0, fontScale, color, thick)\n",
    "    return Image.fromarray(canvas)\n",
    "\n",
    "def plot_box(img, boxes, color=(0, 0, 0), thickness=1):\n",
    "    '''Draw axis-aligned rectangles on a copy of ``img``.\n",
    "\n",
    "    @@img: PIL image or ndarray (not modified)\n",
    "    @@boxes: iterable of boxes; values 0-3 are used as x1, y1, x2, y2\n",
    "    @@color: outline colour, black by default\n",
    "    @@thickness: outline thickness passed to cv2.rectangle\n",
    "    returns a new PIL image\n",
    "    '''\n",
    "    canvas = np.copy(img)\n",
    "    for box in boxes:\n",
    "        cv2.rectangle(canvas, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness)\n",
    "    return Image.fromarray(canvas)\n",
    "\n",
    "def show(p, scale=608, max_scale=4096):\n",
    "    '''Visualise one annotated image before and after box splitting.\n",
    "\n",
    "    @@p: image path; its annotation is the sibling ``.xml`` file\n",
    "    @@scale, max_scale: passed to resize_im to choose the split size\n",
    "    Uses the notebook-level ``splitW``.\n",
    "    returns (image with the numbered XML boxes, image returned by get_box_spilt\n",
    "             with the split boxes drawn, the split boxes)\n",
    "    '''\n",
    "    # splitext also handles .jpeg/.JPG and matches how data_generator finds annotations.\n",
    "    xmlP = os.path.splitext(p)[0] + '.xml'\n",
    "    boxes = read_voc_xml(xmlP)\n",
    "    im = Image.open(p)\n",
    "    w, h = resize_im(im.size[0], im.size[1], scale=scale, max_scale=max_scale)\n",
    "    # No rotation for visualisation (angle 0).\n",
    "    newBoxes, newIm = get_box_spilt(boxes, im, sizeW=w, SizeH=h, splitW=splitW, isRoate=True, rorateDegree=0)\n",
    "    return plot_boxes(im, 0, boxes, color=(0, 0, 0)), plot_box(newIm, newBoxes), newBoxes\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview one labelled sample (index 9, clamped so small datasets still work).\n",
    "a, b, newBoxes = show(jpgPath[min(9, len(jpgPath) - 1)])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the split image with its training boxes (second value returned by show).\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Anchors as listed by the YOLO_Kmeans cell (use cluster.anchors after recomputing).\n",
    "anchor_text = '8,9, 8,18, 8,31, 8,59, 8,124, 8,351, 8,509, 8,605, 8,800'\n",
    "anchors = np.array([float(v) for v in anchor_text.split(',')]).reshape(-1, 2)  # one anchor pair per row\n",
    "num_anchors = anchors.shape[0]\n",
    "\n",
    "class_names = ['none', 'text']\n",
    "num_classes = len(class_names)\n",
    "\n",
    "# Training graph (train=True); pretrained weights are loaded in the next cell.\n",
    "textModel = yolo_text(num_classes, anchors, train=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise the training model from the pretrained weights.\n",
    "PRETRAINED_WEIGHTS = 'models/text.h5'\n",
    "textModel.load_weights(PRETRAINED_WEIGHTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training / validation sample streams over the split made above.\n",
    "trainLoad = data_generator(roots=jpgPath[:num_train], anchors=anchors, num_classes=num_classes, splitW=splitW)\n",
    "testLoad = data_generator(roots=jpgPath[num_train:], anchors=anchors, num_classes=num_classes, splitW=splitW)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The four named outputs are treated as the loss terms themselves: each loss\n",
    "# returns y_pred unchanged and the dummy targets from data_generator are ignored.\n",
    "def _passthrough_loss(y_true, y_pred):\n",
    "    return y_pred\n",
    "\n",
    "adam = tf.keras.optimizers.Adam(lr=0.0005)  # `lr` kept for the TF 1.x Keras API used in this notebook\n",
    "textModel.compile(optimizer=adam,\n",
    "                  loss={name: _passthrough_loss\n",
    "                        for name in ('xy_loss', 'wh_loss', 'confidence_loss', 'class_loss')})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Persist the best weights (lowest validation loss); the inference section\n",
    "# below loads its weights from this same path.\n",
    "TRAINED_WEIGHTS = '/tmp/textModel.h5'\n",
    "checkpoint = tf.keras.callbacks.ModelCheckpoint(TRAINED_WEIGHTS, monitor='val_loss',\n",
    "                                                save_best_only=True, save_weights_only=True)\n",
    "textModel.fit_generator(generator=trainLoad,\n",
    "                        steps_per_epoch=num_train,  # batch size 1: one step per training image\n",
    "                        epochs=2,\n",
    "                        verbose=2,\n",
    "                        callbacks=[checkpoint],\n",
    "                        validation_data=testLoad,\n",
    "                        validation_steps=num_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from text.keras_yolo3 import yolo_text,box_layer,K\n",
    "from config import kerasTextModel,IMGSIZE,keras_anchors,class_names\n",
    "from apphelper.image import resize_im,letterbox_image\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "graph = tf.get_default_graph()  # captured for text_detect; works around web.py-related errors (NOTE(review): presumably a threading issue - confirm)\n",
    "\n",
    "# Parse the anchor list from config into one pair per row. NOTE: this (and\n",
    "# num_classes below) re-binds the notebook-level values used for training.\n",
    "anchors = [float(x) for x in keras_anchors.split(',')]\n",
    "anchors = np.array(anchors).reshape(-1, 2)\n",
    "num_anchors = len(anchors)\n",
    "\n",
    "num_classes = len(class_names)\n",
    "# Inference model (no train=True); the weights path below overrides the\n",
    "# kerasTextModel value imported from config.\n",
    "textModelTest = yolo_text(num_classes,anchors)\n",
    "kerasTextModel = '/tmp/textModel.h5'\n",
    "textModelTest.load_weights(kerasTextModel)\n",
    "\n",
    "\n",
    "# Post-processing graph: placeholders fed per call in text_detect.\n",
    "sess = K.get_session()\n",
    "image_shape = K.placeholder(shape=(2, ))  # original image size: h, w\n",
    "input_shape = K.placeholder(shape=(2, ))  # resized network input size: h, w\n",
    "box_score = box_layer([*textModelTest.output,image_shape,input_shape],anchors, num_classes)\n",
    "\n",
    "\n",
    "\n",
    "def text_detect(img, prob=0.05, scale=None):\n",
    "    '''Detect text boxes in one image.\n",
    "\n",
    "    @@img: image array, converted with Image.fromarray\n",
    "    @@prob: score threshold; only boxes with score > prob are returned\n",
    "    @@scale: scale passed to resize_im; defaults to the global IMGSIZE[0],\n",
    "        read at call time so reassigning IMGSIZE still takes effect\n",
    "    returns (box, scores): box columns 0-3 are x1, y1, x2, y2 clipped to the\n",
    "        original image bounds; scores are aligned row-for-row with box\n",
    "    '''\n",
    "    im = Image.fromarray(img)\n",
    "    if scale is None:\n",
    "        scale = IMGSIZE[0]\n",
    "    w, h = im.size\n",
    "    # Network input size chosen by resize_im for this scale (max_scale=2048).\n",
    "    w_, h_ = resize_im(w, h, scale=scale, max_scale=2048)\n",
    "    boxed_image = im.resize((w_, h_), Image.BICUBIC)\n",
    "    image_data = np.array(boxed_image, dtype='float32') / 255.\n",
    "    image_data = np.expand_dims(image_data, 0)  # add batch dimension\n",
    "\n",
    "    # Run inside the graph captured at setup (works around web.py-related errors).\n",
    "    with graph.as_default():\n",
    "        box, scores = sess.run(\n",
    "            [box_score],\n",
    "            feed_dict={\n",
    "                textModelTest.input: image_data,\n",
    "                input_shape: [h_, w_],\n",
    "                image_shape: [h, w],\n",
    "                K.learning_phase(): 0,\n",
    "            })[0]\n",
    "\n",
    "    # Clip to the image: negatives -> 0, coordinates >= width/height -> last pixel.\n",
    "    box[:, 0:4][box[:, 0:4] < 0] = 0\n",
    "    for col, limit in ((0, w), (1, h), (2, w), (3, h)):\n",
    "        box[:, col][box[:, col] >= limit] = limit - 1\n",
    "\n",
    "    keep = np.where(scores > prob)[0]\n",
    "    return box[keep], scores[keep]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test on one image with a smaller input scale and a low score threshold.\n",
    "p = './train/text/26BB94CA21C11AB38BC5FC2E08D140CD.jpg'\n",
    "IMGSIZE = (416, 416)  # overrides the config value; text_detect reads IMGSIZE at call time\n",
    "img = np.array(Image.open(p))\n",
    "box, scores = text_detect(img, prob=0.01)\n",
    "plot_box(img, box)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chineseocr",
   "language": "python",
   "name": "chineseocr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
