master
/ main.ipynb

main.ipynb @master

1aaa8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "import re\n",
    "import base64\n",
    "from io import BytesIO\n",
    "\n",
    "import torch\n",
    "#import matplotlib.pyplot as plt # plt 用于显示图片\n",
    "from PIL import Image\n",
    "from torch.autograd import Variable\n",
    "from torchvision.transforms import ToTensor, ToPILImage\n",
    "from model import Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_to_base64(img):\n",
    "    global scale\n",
    "    buffered = BytesIO()\n",
    "    imgSize = img.size\n",
    "    if scale > 1:\n",
    "        resize_img = img.resize((int(imgSize[0]*scale), int(imgSize[1]*scale)))\n",
    "    else:\n",
    "        resize_img = img\n",
    "    resize_img.save(buffered, format=\"JPEG\")\n",
    "    img_str = base64.b64encode(buffered.getvalue())\n",
    "    return img_str.decode(\"utf-8\")\n",
    "\n",
    "\n",
    "def base64_to_image(base64_str, grayscale):\n",
    "    # 字符串替换\n",
    "    # print(base64_str)\n",
    "    global scale\n",
    "    base64_data = re.sub('^data:image/.+;base64,', '', base64_str)\n",
    "    byte_data = base64.b64decode(base64_data)\n",
    "    image_data = BytesIO(byte_data)\n",
    "    img = image.load_img(image_data, grayscale)\n",
    "    imgSize = img.size\n",
    "    scale = max(imgSize)/500\n",
    "    if scale > 1:\n",
    "        resize_img = img.resize((int(imgSize[0]/scale), int(imgSize[1]/scale)))\n",
    "        return image.img_to_array(resize_img)\n",
    "    else:\n",
    "        return image.img_to_array(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def trans(img,conf):\n",
    "    UPSCALE_FACTOR=4\n",
    "    MODEL_NAME = 'netG_epoch_4_100.pth'\n",
    "    model = Generator(UPSCALE_FACTOR).eval()\n",
    "    model.load_state_dict(torch.load('results/epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))\n",
    "    image = Variable(ToTensor()(img), volatile=True).unsqueeze(0)\n",
    "    out = model(image)\n",
    "    out_img = ToPILImage()(out[0].data.cpu())\n",
    "    #plt.imshow(out_img)\n",
    "    #plt.axis('off') # 不显示坐标轴\n",
    "    #plt.show()\n",
    "    return out_img\n",
    "    #out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle(conf):\n",
    "    # paste your code here\n",
    "    start_time = time.time()\n",
    "    img = conf[\"输入图片\"]\n",
    "    image_narry = base64_to_image(img, False)\n",
    "    time1 = time.time()\n",
    "    image_result = trans(image_narry, conf)\n",
    "    time2 = time.time()\n",
    "    image_str = image_to_base64(image_result)\n",
    "    time3 = time.time()\n",
    "    return {\"增强后的图片\": image_str}\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}