{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CRNN OCR 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from torch.autograd import Variable\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import numpy as np\n",
    "from warpctc_pytorch import CTCLoss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建数据软链接"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Symlink the local OCR image folder to ../data/ocr/1 (relative to the notebook's working directory).\n",
    "# NOTE(review): the source path is machine-specific - adjust it before running this cell.\n",
    "!ln -s /home/lywen/data/ocr ../data/ocr/1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Move the working directory two levels up before importing the `train` package;\n",
    "# presumably this puts the repository root on the import path - TODO confirm.\n",
    "# NOTE(review): not idempotent - re-running this cell climbs two more levels.\n",
    "os.chdir('../../')\n",
    "from train.ocr.dataset import PathDataset,randomSequentialSampler,alignCollate\n",
    "from glob import glob\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Every .jpg found exactly one sub-directory below ./train/data/ocr/.\n",
    "roots = glob('./train/data/ocr/*/*.jpg')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练字符集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training character set: the ten digits followed by lower- and upper-case ASCII letters.\n",
    "# NOTE(review): despite the variable name, the set contains no Chinese characters.\n",
    "alphabetChinese = '1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 10% of the image paths for evaluation.\n",
    "trainP,testP = train_test_split(roots,test_size=0.1)##the split does not balance character frequencies\n",
    "traindataset = PathDataset(trainP,alphabetChinese)\n",
    "testdataset = PathDataset(testP,alphabetChinese)\n",
    "\n",
    "batchSize = 32\n",
    "workers = 1\n",
    "imgH = 32\n",
    "imgW = 280\n",
    "keep_ratio = True\n",
    "cuda = True\n",
    "ngpu = 1\n",
    "nh =256\n",
    "# NOTE(review): `sampler` is created here, but the DataLoader below is given sampler=None\n",
    "# together with shuffle=False, so the sampler is never used - confirm whether that is intended.\n",
    "sampler = randomSequentialSampler(traindataset, batchSize)\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    traindataset, batch_size=batchSize,\n",
    "    shuffle=False, sampler=None,\n",
    "    num_workers=int(workers),\n",
    "    collate_fn=alignCollate(imgH=imgH, imgW=imgW, keep_ratio=keep_ratio))\n",
    "\n",
    "# NOTE(review): the training cells build a fresh iterator at the start of every epoch,\n",
    "# so this module-level iterator looks unused.\n",
    "train_iter = iter(train_loader)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载预训练模型权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weights_init(m):\n",
    "    \"\"\"Re-initialise a single module in place, selected by its class name.\n",
    "\n",
    "    A class name containing 'Conv' gets weights drawn from a normal\n",
    "    distribution with mean 0.0 and std 0.02. Otherwise, a class name\n",
    "    containing 'BatchNorm' gets weights drawn with mean 1.0 and std 0.02\n",
    "    and a bias filled with zeros. Any other module is left untouched.\n",
    "    Written to be handed to `Module.apply`.\n",
    "    \"\"\"\n",
    "    layer_kind = type(m).__name__\n",
    "    if 'Conv' in layer_kind:\n",
    "        m.weight.data.normal_(0.0, 0.02)\n",
    "        return\n",
    "    if 'BatchNorm' in layer_kind:\n",
    "        m.weight.data.normal_(1.0, 0.02)\n",
    "        m.bias.data.fill_(0)\n",
    "        \n",
    "from crnn.models.crnn import CRNN\n",
    "from config import ocrModel,LSTMFLAG,GPU\n",
    "\n",
    "# Fresh network: the class count is the alphabet size plus one extra class\n",
    "# (presumably the CTC blank - TODO confirm against CRNN).\n",
    "model = CRNN(32, 1, len(alphabetChinese)+1, 256, 1,lstmFlag=LSTMFLAG)\n",
    "model.apply(weights_init)\n",
    "\n",
    "# Weights trained by the project; the map_location callback hands every storage back unchanged.\n",
    "preWeightDict = torch.load(ocrModel,map_location=lambda storage, loc: storage)\n",
    "modelWeightDict = model.state_dict()\n",
    "\n",
    "# Overlay the checkpoint tensors onto the fresh state dict.\n",
    "for ckptKey, ckptTensor in preWeightDict.items():\n",
    "    paramName = ckptKey.replace('module.','')  # drop the `module.` prefix\n",
    "    if 'rnn.1.embedding' in paramName:\n",
    "        continue  # the last layer is not loaded from the checkpoint\n",
    "    modelWeightDict[paramName] = ckptTensor\n",
    "\n",
    "model.load_state_dict(modelWeightDict)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bare expression: the notebook renders the model's repr as this cell's output.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##optimizer, label converter and loss\n",
    "from crnn.util import strLabelConverter\n",
    "lr = 0.1\n",
    "optimizer = optim.Adadelta(model.parameters(), lr=lr)\n",
    "# ''.join(...) applied to a str yields an equal string, so the converter receives the alphabet as-is.\n",
    "converter = strLabelConverter(''.join(alphabetChinese))\n",
    "criterion = CTCLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from train.ocr.dataset import resizeNormalize\n",
    "from crnn.util import loadData\n",
    "# Module-level buffers that trainBatch fills through loadData on every step.\n",
    "# NOTE(review): the initial sizes (3 channels, imgH x imgH, batchSize * 5) look like placeholders;\n",
    "# presumably loadData resizes each buffer to the incoming data - TODO confirm.\n",
    "image = torch.FloatTensor(batchSize, 3, imgH, imgH)\n",
    "text = torch.IntTensor(batchSize * 5)\n",
    "length = torch.IntTensor(batchSize)\n",
    "\n",
    "# With CUDA available, the model, the image buffer and the loss are moved to the GPU;\n",
    "# `text` and `length` are left where they were created.\n",
    "if torch.cuda.is_available():\n",
    "    model.cuda()\n",
    "    model = torch.nn.DataParallel(model, device_ids=[0])##wrap the model for multi-GPU training (only device 0 is listed)\n",
    "    image = image.cuda()\n",
    "    criterion = criterion.cuda()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def trainBatch(net, criterion, optimizer,cpu_images, cpu_texts):\n",
    "    \"\"\"Run one optimisation step on a single batch and return its cost.\n",
    "\n",
    "    The batch is copied into the module-level `image`, `text` and `length`\n",
    "    buffers via `loadData`; labels are encoded with the global `converter`.\n",
    "\n",
    "    Args:\n",
    "        net: network called on the image buffer.\n",
    "        criterion: loss called as criterion(preds, text, preds_size, length).\n",
    "        optimizer: optimizer stepped after the backward pass.\n",
    "        cpu_images: image batch; size(0) is used as the batch size.\n",
    "        cpu_texts: labels of the batch, passed to converter.encode.\n",
    "\n",
    "    Returns:\n",
    "        The criterion value divided by the batch size.\n",
    "    \"\"\"\n",
    "    n_samples = cpu_images.size(0)\n",
    "\n",
    "    # Stage the inputs and the encoded labels in the shared buffers.\n",
    "    loadData(image, cpu_images)\n",
    "    encoded_labels, label_lengths = converter.encode(cpu_texts)\n",
    "    loadData(text, encoded_labels)\n",
    "    loadData(length, label_lengths)\n",
    "\n",
    "    outputs = net(image)\n",
    "    # One entry per sample, each equal to the size of the output's first dimension.\n",
    "    output_lengths = Variable(torch.IntTensor([outputs.size(0)] * n_samples))\n",
    "    cost = criterion(outputs, text, output_lengths, length) / n_samples\n",
    "\n",
    "    net.zero_grad()\n",
    "    cost.backward()\n",
    "    optimizer.step()\n",
    "    return cost\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def predict(im):\n",
    "    \"\"\"Recognise the text in one image using the global `model`.\n",
    "\n",
    "    The image is converted to mode 'L', resized to a height of 32 with the\n",
    "    width scaled by the same factor, passed through `model`, and the\n",
    "    per-position argmax is decoded with the global `converter` (raw=False).\n",
    "\n",
    "    Args:\n",
    "        im: image object offering `.convert` and `.size` (width, height).\n",
    "\n",
    "    Returns:\n",
    "        Whatever `converter.decode` produces for the prediction.\n",
    "    \"\"\"\n",
    "    gray = im.convert('L')\n",
    "    src_w, src_h = gray.size[0], gray.size[1]\n",
    "    scale = src_h*1.0 / 32\n",
    "    target_w = int(src_w / scale)\n",
    "\n",
    "    tensor = resizeNormalize((target_w, 32))(gray)\n",
    "    if torch.cuda.is_available():\n",
    "        tensor = tensor.cuda()\n",
    "    # Prepend a batch dimension of size 1.\n",
    "    batch = Variable(tensor.view(1, *tensor.size()))\n",
    "\n",
    "    scores = model(batch)\n",
    "    best = scores.max(2)[1]\n",
    "    flat = best.transpose(1, 0).contiguous().view(-1)\n",
    "    flat_size = Variable(torch.IntTensor([flat.size(0)]))\n",
    "    return converter.decode(flat.data, flat_size.data, raw=False)\n",
    "   \n",
    "   \n",
    "def val(net, dataset, max_iter=100):\n",
    "    \"\"\"Estimate exact-match accuracy on randomly drawn samples of `dataset`.\n",
    "\n",
    "    Side effects: every parameter of `net` gets requires_grad=False and `net`\n",
    "    is switched to eval mode; the caller must undo both before training again.\n",
    "    Note that predictions come from `predict`, which uses the global `model`.\n",
    "\n",
    "    Args:\n",
    "        net: network to freeze and put in eval mode.\n",
    "        dataset: indexable collection of (image, label) pairs.\n",
    "        max_iter: maximum number of random draws (capped at len(dataset)).\n",
    "\n",
    "    Returns:\n",
    "        Fraction of evaluated samples whose stripped prediction equals the\n",
    "        label. Images wider than 1024 are skipped and do not count towards\n",
    "        the denominator; 0.0 is returned when nothing was evaluated.\n",
    "    \"\"\"\n",
    "    for p in net.parameters():\n",
    "        p.requires_grad = False\n",
    "    net.eval()\n",
    "\n",
    "    N = len(dataset)\n",
    "    max_iter = min(max_iter, N)\n",
    "    n_correct = 0\n",
    "    n_evaluated = 0\n",
    "    for _ in range(max_iter):\n",
    "        im,label = dataset[np.random.randint(0,N)]\n",
    "        if im.size[0]>1024:\n",
    "            continue  # too wide: skipped, and not counted as a miss\n",
    "\n",
    "        n_evaluated += 1\n",
    "        if predict(im).strip() == label:\n",
    "            n_correct += 1\n",
    "\n",
    "    # Guard against an empty dataset or a run in which every draw was skipped.\n",
    "    if n_evaluated == 0:\n",
    "        return 0.0\n",
    "    return n_correct / float(n_evaluated)\n",
    "    \n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from train.ocr.generic_utils import Progbar\n",
    "##progress bar reference: https://github.com/keras-team/keras/blob/master/keras/utils/generic_utils.py\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 冻结预训练模型层参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nepochs = 10\n",
    "acc  = 0\n",
    "\n",
    "# Evaluate roughly twice per epoch; max(1, ...) keeps the modulo below valid\n",
    "# when the loader yields fewer than two batches.\n",
    "interval = max(1, len(train_loader)//2)\n",
    "\n",
    "for i in range(nepochs):\n",
    "    print('epoch:{}/{}'.format(i,nepochs))\n",
    "    n = len(train_loader)\n",
    "    pbar = Progbar(target=n)\n",
    "    train_iter = iter(train_loader)\n",
    "    loss = 0\n",
    "    for j in range(n):\n",
    "        # val() switches requires_grad off for every parameter, so the freeze mask is\n",
    "        # re-applied before each step: only parameters whose name contains\n",
    "        # 'rnn.1.embedding' are trained in this phase.\n",
    "        for paramName, param in model.named_parameters():\n",
    "            param.requires_grad = 'rnn.1.embedding' in paramName\n",
    "\n",
    "        model.train()\n",
    "        # The built-in next() works on any iterator; the .next() method does not exist on all of them.\n",
    "        cpu_images, cpu_texts = next(train_iter)\n",
    "        cost = trainBatch(model, criterion, optimizer,cpu_images, cpu_texts)\n",
    "\n",
    "        # .cpu() is a no-op for CPU tensors and makes .numpy() valid for CUDA ones.\n",
    "        loss += cost.data.cpu().numpy()\n",
    "\n",
    "        if (j+1)%interval==0:\n",
    "            curAcc = val(model, testdataset, max_iter=1024)\n",
    "            if curAcc>acc:\n",
    "                # Keep only the checkpoint with the best accuracy seen so far.\n",
    "                acc = curAcc\n",
    "                torch.save(model.state_dict(), 'train/ocr/modellstm.pth')\n",
    "\n",
    "        # trainBatch already divides the cost by the batch size, so the running\n",
    "        # mean is taken over the number of steps only.\n",
    "        pbar.update(j+1,values=[('loss',loss/(j+1)),('acc',acc)])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 释放模型层参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nepochs = 10\n",
    "startEpoch = 10  # epoch counter continues after the ten frozen-phase epochs\n",
    "# `acc` is deliberately not reset: it carries over from the previous phase, so the\n",
    "# checkpoint file is only overwritten by a model that beats the best accuracy so far.\n",
    "\n",
    "# Evaluate roughly twice per epoch; max(1, ...) keeps the modulo below valid\n",
    "# when the loader yields fewer than two batches.\n",
    "interval = max(1, len(train_loader)//2)\n",
    "\n",
    "for i in range(startEpoch,startEpoch+nepochs):\n",
    "    # Print the overall epoch count as the total, so the banner never reads e.g. 15/10.\n",
    "    print('epoch:{}/{}'.format(i,startEpoch+nepochs))\n",
    "    n = len(train_loader)\n",
    "    pbar = Progbar(target=n)\n",
    "    train_iter = iter(train_loader)\n",
    "    loss = 0\n",
    "    for j in range(n):\n",
    "        # val() switches requires_grad off for every parameter; turn it back on\n",
    "        # everywhere, since the whole network is trained in this phase.\n",
    "        for _, param in model.named_parameters():\n",
    "            param.requires_grad = True\n",
    "\n",
    "        model.train()\n",
    "        # The built-in next() works on any iterator; the .next() method does not exist on all of them.\n",
    "        cpu_images, cpu_texts = next(train_iter)\n",
    "        cost = trainBatch(model, criterion, optimizer,cpu_images, cpu_texts)\n",
    "\n",
    "        # .cpu() is a no-op for CPU tensors and makes .numpy() valid for CUDA ones.\n",
    "        loss += cost.data.cpu().numpy()\n",
    "\n",
    "        if (j+1)%interval==0:\n",
    "            curAcc = val(model, testdataset, max_iter=1024)\n",
    "            if curAcc>acc:\n",
    "                acc = curAcc\n",
    "                torch.save(model.state_dict(), 'train/ocr/modellstm.pth')\n",
    "\n",
    "        # trainBatch already divides the cost by the batch size, so the running\n",
    "        # mean is taken over the number of steps only.\n",
    "        pbar.update(j+1,values=[('loss',loss/(j+1)),('acc',acc)])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the model in evaluation mode before the demo prediction below.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one random test sample, run predict() on it and print the label next to the prediction.\n",
    "N  = len(testdataset)\n",
    "im,label = testdataset[np.random.randint(0,N)]\n",
    "pred = predict(im)\n",
    "print('true:{},pred:{}'.format(label,pred))\n",
    "# Bare expression: the notebook renders the sampled image as the cell output.\n",
    "im\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chineseocr",
   "language": "python",
   "name": "chineseocr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
