master
/ handler.py

handler.py @64b73ac

1aaa8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import sys
import time
import re
import base64
from io import BytesIO

import torch
#import matplotlib.pyplot as plt # plt 用于显示图片
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from model import Generator

def image_to_base64(img, factor=None):
    """Encode a PIL image as a base64 JPEG string, optionally enlarging it first.

    Args:
        img: PIL image to encode.
        factor: Enlargement factor applied to both sides when it is > 1.
            When omitted, the module-level ``scale`` written by
            ``base64_to_image`` is used, or 1 (no resize) if that global has
            never been set.

    Returns:
        The JPEG bytes, base64-encoded and decoded to a UTF-8 ``str``.
    """
    if factor is None:
        # Reading through globals() avoids a NameError when base64_to_image
        # has not run yet in this process.
        factor = globals().get("scale", 1)
    if factor > 1:
        width, height = img.size
        img = img.resize((int(width * factor), int(height * factor)))
    # JPEG cannot store alpha or palette images (e.g. RGBA, LA, P); flatten
    # anything outside the JPEG-writable modes to RGB instead of failing.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def base64_to_image(base64_str, grayscale, max_side=500):
    """Decode a base64 image string into a PIL image capped at ``max_side`` px.

    Args:
        base64_str: Base64-encoded image bytes, optionally prefixed with a
            ``data:image/...;base64,`` data-URI header (the header is stripped).
        grayscale: When true the image is converted to single-channel "L",
            otherwise to "RGB".
        max_side: Longest allowed side in pixels; larger images are shrunk
            proportionally. Defaults to 500, the original hard-coded limit.

    Returns:
        A PIL image. A PIL image (rather than a float array) is what
        ``torchvision`` ``ToTensor`` rescales to [0, 1] for the model.

    Side effects:
        Sets the module-level ``scale`` to ``max(width, height) / max_side``.
        ``image_to_base64`` reads it to enlarge the result back up again.
    """
    global scale
    base64_data = re.sub(r'^data:image/.+;base64,', '', base64_str)
    byte_data = base64.b64decode(base64_data)
    # convert() forces PIL's lazy decoder to read the whole buffer now.
    img = Image.open(BytesIO(byte_data)).convert('L' if grayscale else 'RGB')
    width, height = img.size
    scale = max(width, height) / max_side
    if scale > 1:
        # max(1, ...) keeps extreme aspect ratios from producing a 0-px side.
        img = img.resize((max(1, int(width / scale)),
                          max(1, int(height / scale))))
    return img

# Loaded generators keyed by (upscale_factor, model_path), so the weights are
# read from disk once per process instead of on every request. Entries are
# never refreshed: replacing the checkpoint file requires a restart.
_GENERATOR_CACHE = {}


def _load_generator(upscale_factor, model_path):
    """Return an eval-mode Generator with weights from model_path, cached."""
    key = (upscale_factor, model_path)
    model = _GENERATOR_CACHE.get(key)
    if model is None:
        model = Generator(upscale_factor).eval()
        # Map every stored tensor onto CPU memory regardless of save device.
        state = torch.load(model_path,
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(state)
        _GENERATOR_CACHE[key] = model
    return model


def trans(img, conf):
    """Run the super-resolution Generator on one image.

    Args:
        img: Input image accepted by ``torchvision`` ``ToTensor`` (a PIL
            image or an HWC ndarray).
        conf: Request payload; currently unused, kept for interface
            compatibility.

    Returns:
        The Generator output converted to a PIL image by ``ToPILImage``.
        Presumably enlarged by UPSCALE_FACTOR; confirm against ``model.py``.
    """
    UPSCALE_FACTOR = 4
    MODEL_NAME = 'netG_epoch_4_100.pth'
    model = _load_generator(UPSCALE_FACTOR, 'results/epochs/' + MODEL_NAME)
    # no_grad replaces the deprecated Variable(volatile=True): no autograd
    # graph is built during inference.
    with torch.no_grad():
        batch = ToTensor()(img).unsqueeze(0)
        out = model(batch)
    return ToPILImage()(out[0].cpu())


def handle(conf):
    """Super-resolve the request image and return it base64-encoded.

    Args:
        conf: Request mapping. It must contain the key "输入图片" ("input
            image") holding a base64 image string. The string may carry a
            ``data:image/...;base64,`` prefix.

    Returns:
        ``{"增强后的图片": <base64 JPEG str>}``; the key means "enhanced image".

    Raises:
        KeyError: If "输入图片" is missing from ``conf``.
    """
    source = base64_to_image(conf["输入图片"], False)
    enhanced = trans(source, conf)
    return {"增强后的图片": image_to_base64(enhanced)}