{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## crnn ocr 模型训练"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"import torch\n",
"from PIL import Image\n",
"import numpy as np\n",
"from torch.autograd import Variable\n",
"import torch.backends.cudnn as cudnn\n",
"import torch.optim as optim\n",
"import torch.utils.data\n",
"import numpy as np\n",
"from warpctc_pytorch import CTCLoss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 创建数据软连接"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!ln -s /home/lywen/data/ocr ../data/ocr/1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 加载数据集"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.chdir('../../')\n",
"from train.ocr.dataset import PathDataset,randomSequentialSampler,alignCollate\n",
"from glob import glob\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"roots = glob('./train/data/ocr/*/*.jpg')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 训练字符集"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"alphabetChinese = '1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainP,testP = train_test_split(roots,test_size=0.1)##此处未考虑字符平衡划分\n",
"traindataset = PathDataset(trainP,alphabetChinese)\n",
"testdataset = PathDataset(testP,alphabetChinese)\n",
"\n",
"batchSize = 32\n",
"workers = 1\n",
"imgH = 32\n",
"imgW = 280\n",
"keep_ratio = True\n",
"cuda = True\n",
"ngpu = 1\n",
"nh =256\n",
"sampler = randomSequentialSampler(traindataset, batchSize)\n",
"train_loader = torch.utils.data.DataLoader(\n",
" traindataset, batch_size=batchSize,\n",
" shuffle=False, sampler=None,\n",
" num_workers=int(workers),\n",
" collate_fn=alignCollate(imgH=imgH, imgW=imgW, keep_ratio=keep_ratio))\n",
"\n",
"train_iter = iter(train_loader)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 加载预训练模型权重"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def weights_init(m):\n",
" classname = m.__class__.__name__\n",
" if classname.find('Conv') != -1:\n",
" m.weight.data.normal_(0.0, 0.02)\n",
" elif classname.find('BatchNorm') != -1:\n",
" m.weight.data.normal_(1.0, 0.02)\n",
" m.bias.data.fill_(0)\n",
" \n",
"from crnn.models.crnn import CRNN\n",
"from config import ocrModel,LSTMFLAG,GPU\n",
"model = CRNN(32, 1, len(alphabetChinese)+1, 256, 1,lstmFlag=LSTMFLAG)\n",
"model.apply(weights_init)\n",
"preWeightDict = torch.load(ocrModel,map_location=lambda storage, loc: storage)##加入项目训练的权重\n",
"\n",
"modelWeightDict = model.state_dict()\n",
"\n",
"for k, v in preWeightDict.items():\n",
" name = k.replace('module.','') # remove `module.`\n",
" if 'rnn.1.embedding' not in name:##不加载最后一层权重\n",
" modelWeightDict[name] = v\n",
" \n",
"model.load_state_dict(modelWeightDict)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"##优化器\n",
"from crnn.util import strLabelConverter\n",
"lr = 0.1\n",
"optimizer = optim.Adadelta(model.parameters(), lr=lr)\n",
"converter = strLabelConverter(''.join(alphabetChinese))\n",
"criterion = CTCLoss()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from train.ocr.dataset import resizeNormalize\n",
"from crnn.util import loadData\n",
"image = torch.FloatTensor(batchSize, 3, imgH, imgH)\n",
"text = torch.IntTensor(batchSize * 5)\n",
"length = torch.IntTensor(batchSize)\n",
"\n",
"if torch.cuda.is_available():\n",
" model.cuda()\n",
" model = torch.nn.DataParallel(model, device_ids=[0])##转换为多GPU训练模型\n",
" image = image.cuda()\n",
" criterion = criterion.cuda()\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def trainBatch(net, criterion, optimizer,cpu_images, cpu_texts):\n",
" #data = train_iter.next()\n",
" #cpu_images, cpu_texts = data\n",
" batch_size = cpu_images.size(0)\n",
" loadData(image, cpu_images)\n",
" t, l = converter.encode(cpu_texts)\n",
" \n",
" loadData(text, t)\n",
" loadData(length, l)\n",
" preds = net(image)\n",
" preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))\n",
" cost = criterion(preds, text, preds_size, length) / batch_size\n",
" net.zero_grad()\n",
" cost.backward()\n",
" optimizer.step()\n",
" return cost\n",
"\n",
"\n",
"\n",
"\n",
"def predict(im):\n",
" \"\"\"\n",
" 预测\n",
" \"\"\"\n",
" image = im.convert('L')\n",
" scale = image.size[1]*1.0 / 32\n",
" w = image.size[0] / scale\n",
" w = int(w)\n",
" transformer = resizeNormalize((w, 32))\n",
" \n",
" image = transformer(image)\n",
" if torch.cuda.is_available():\n",
" image = image.cuda()\n",
" image = image.view(1, *image.size())\n",
" image = Variable(image)\n",
" preds = model(image)\n",
" _, preds = preds.max(2)\n",
" preds = preds.transpose(1, 0).contiguous().view(-1)\n",
" preds_size = Variable(torch.IntTensor([preds.size(0)]))\n",
" sim_pred = converter.decode(preds.data, preds_size.data, raw=False)\n",
" return sim_pred\n",
" \n",
" \n",
"def val(net, dataset, max_iter=100):\n",
"\n",
" for p in net.parameters():\n",
" p.requires_grad = False\n",
" net.eval()\n",
" i = 0\n",
" n_correct = 0\n",
" N = len(dataset)\n",
" \n",
" max_iter = min(max_iter, N)\n",
" for i in range(max_iter):\n",
" im,label = dataset[np.random.randint(0,N)]\n",
" if im.size[0]>1024:\n",
" continue\n",
" \n",
" pred = predict(im)\n",
" if pred.strip() ==label:\n",
" n_correct += 1\n",
"\n",
" accuracy = n_correct / float(max_iter )\n",
" return accuracy\n",
" \n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from train.ocr.generic_utils import Progbar\n",
"##进度条参考 https://github.com/keras-team/keras/blob/master/keras/utils/generic_utils.py\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 模型训练"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 冻结预训练模型层参数"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nepochs = 10\n",
"acc = 0\n",
"\n",
"interval = len(train_loader)//2##评估模型\n",
"\n",
" \n",
"for i in range(nepochs):\n",
" print('epoch:{}/{}'.format(i,nepochs))\n",
" n = len(train_loader)\n",
" pbar = Progbar(target=n)\n",
" train_iter = iter(train_loader)\n",
" loss = 0\n",
" for j in range(n):\n",
" for p in model.named_parameters():\n",
" p[1].requires_grad = True\n",
" if 'rnn.1.embedding' in p[0]:\n",
" p[1].requires_grad = True\n",
" else:\n",
" p[1].requires_grad = False##冻结模型层\n",
"\n",
" model.train()\n",
" cpu_images, cpu_texts = train_iter.next()\n",
" cost = trainBatch(model, criterion, optimizer,cpu_images, cpu_texts)\n",
"\n",
" loss += cost.data.numpy()\n",
" \n",
" if (j+1)%interval==0:\n",
" curAcc = val(model, testdataset, max_iter=1024)\n",
" if curAcc>acc:\n",
" acc = curAcc\n",
" torch.save(model.state_dict(), 'train/ocr/modellstm.pth')\n",
" \n",
" \n",
" pbar.update(j+1,values=[('loss',loss/((j+1)*batchSize)),('acc',acc)])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 释放模型层参数"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nepochs = 10\n",
"#acc = 0\n",
"\n",
"interval = len(train_loader)//2##评估模型\n",
"\n",
" \n",
"for i in range(10,10+nepochs):\n",
" print('epoch:{}/{}'.format(i,nepochs))\n",
" n = len(train_loader)\n",
" pbar = Progbar(target=n)\n",
" train_iter = iter(train_loader)\n",
" loss = 0\n",
" for j in range(n):\n",
" for p in model.named_parameters():\n",
" p[1].requires_grad = True\n",
"\n",
"\n",
" model.train()\n",
" cpu_images, cpu_texts = train_iter.next()\n",
" cost = trainBatch(model, criterion, optimizer,cpu_images, cpu_texts)\n",
"\n",
" loss += cost.data.numpy()\n",
" \n",
" if (j+1)%interval==0:\n",
" curAcc = val(model, testdataset, max_iter=1024)\n",
" if curAcc>acc:\n",
" acc = curAcc\n",
" torch.save(model.state_dict(), 'train/ocr/modellstm.pth')\n",
" \n",
" \n",
" pbar.update(j+1,values=[('loss',loss/((j+1)*batchSize)),('acc',acc)])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 预测demo"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"N = len(testdataset)\n",
"im,label = testdataset[np.random.randint(0,N)]\n",
"pred = predict(im)\n",
"print('true:{},pred:{}'.format(label,pred))\n",
"im\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chineseocr",
"language": "python",
"name": "chineseocr"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}