{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import time\n",
"import re\n",
"import base64\n",
"from io import BytesIO\n",
"\n",
"import torch\n",
"#import matplotlib.pyplot as plt # plt 用于显示图片\n",
"from PIL import Image\n",
"from torch.autograd import Variable\n",
"from torchvision.transforms import ToTensor, ToPILImage\n",
"from model import Generator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def image_to_base64(img):\n",
" global scale\n",
" buffered = BytesIO()\n",
" imgSize = img.size\n",
" if scale > 1:\n",
" resize_img = img.resize((int(imgSize[0]*scale), int(imgSize[1]*scale)))\n",
" else:\n",
" resize_img = img\n",
" resize_img.save(buffered, format=\"JPEG\")\n",
" img_str = base64.b64encode(buffered.getvalue())\n",
" return img_str.decode(\"utf-8\")\n",
"\n",
"\n",
"def base64_to_image(base64_str, grayscale):\n",
" # 字符串替换\n",
" # print(base64_str)\n",
" global scale\n",
" base64_data = re.sub('^data:image/.+;base64,', '', base64_str)\n",
" byte_data = base64.b64decode(base64_data)\n",
" image_data = BytesIO(byte_data)\n",
" img = image.load_img(image_data, grayscale)\n",
" imgSize = img.size\n",
" scale = max(imgSize)/500\n",
" if scale > 1:\n",
" resize_img = img.resize((int(imgSize[0]/scale), int(imgSize[1]/scale)))\n",
" return image.img_to_array(resize_img)\n",
" else:\n",
" return image.img_to_array(img)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def trans(img,conf):\n",
" UPSCALE_FACTOR=4\n",
" MODEL_NAME = 'netG_epoch_4_100.pth'\n",
" model = Generator(UPSCALE_FACTOR).eval()\n",
" model.load_state_dict(torch.load('results/epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))\n",
" image = Variable(ToTensor()(img), volatile=True).unsqueeze(0)\n",
" out = model(image)\n",
" out_img = ToPILImage()(out[0].data.cpu())\n",
" #plt.imshow(out_img)\n",
" #plt.axis('off') # 不显示坐标轴\n",
" #plt.show()\n",
" return out_img\n",
" #out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def handle(conf):\n",
" # paste your code here\n",
" start_time = time.time()\n",
" img = conf[\"输入图片\"]\n",
" image_narry = base64_to_image(img, False)\n",
" time1 = time.time()\n",
" image_result = trans(image_narry, conf)\n",
" time2 = time.time()\n",
" image_str = image_to_base64(image_result)\n",
" time3 = time.time()\n",
" return {\"增强后的图片\": image_str}\n",
" "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}