master
/ handler.py

handler.py @master

1aaa8ec
 
 
 
 
f9ec0ac
1aaa8ec
3086e8e
1aaa8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48246e6
1aaa8ec
 
 
 
48246e6
1aaa8ec
48246e6
1aaa8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b26d493
3086e8e
 
1aaa8ec
3086e8e
1aaa8ec
5759f77
f9ec0ac
 
1aaa8ec
f9ec0ac
1aaa8ec
import os
import sys
import time
import re
import base64
import uuid
from io import BytesIO
import numpy as np
import torch
#import matplotlib.pyplot as plt # plt 用于显示图片
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from model import Generator

def image_to_base64(img):
    """Encode a PIL image as a base64 JPEG string.

    When the module-level ``scale`` (written by ``base64_to_image`` when it
    measures an input) is greater than 1, the image is first enlarged by that
    factor, approximately undoing the earlier downscale.

    Args:
        img: PIL image; the function uses its ``size``, ``mode``, ``resize``,
            ``convert`` and ``save`` members.

    Returns:
        The JPEG-encoded image as a base64 ``str``.
    """
    # ``scale`` only exists once ``base64_to_image`` has run; treat a missing
    # value as "no rescaling" instead of failing with NameError.
    factor = globals().get("scale", 1)
    if factor > 1:
        width, height = img.size
        img = img.resize((int(width * factor), int(height * factor)))
    # JPEG cannot store alpha or palette data; Pillow raises OSError for
    # these modes, so normalise them to RGB before saving.
    if img.mode in ("RGBA", "LA", "P"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def base64_to_image(base64_str, grayscale):
    """Decode a base64 (optionally data-URI) image into a NumPy array.

    Images whose longest side exceeds 500 pixels are shrunk so that side is
    about 500 pixels. The shrink factor is stored in the module-level
    ``scale`` so that ``image_to_base64`` can enlarge the result again.

    Args:
        base64_str: Base64 image data, optionally prefixed with a
            ``data:image/<type>;base64,`` header.
        grayscale: If truthy, decode to single-channel mode "L"; otherwise
            decode to "RGB".

    Returns:
        ``np.ndarray`` view of the (possibly resized) image, built with
        ``np.asarray`` exactly as ``handle`` prepares its input for ``trans``.
    """
    global scale
    # Strip the data-URI header, if any, leaving only the base64 payload.
    base64_data = re.sub(r'^data:image/.+;base64,', '', base64_str)
    image_data = BytesIO(base64.b64decode(base64_data))
    # PIL.Image has no ``load_img`` / ``img_to_array`` (those are Keras
    # helpers), so the original calls raised AttributeError. Use the Pillow
    # equivalents: open, then convert to the requested colour mode.
    with Image.open(image_data) as opened:
        img = opened.convert("L" if grayscale else "RGB")
    width, height = img.size
    scale = max(width, height) / 500
    if scale > 1:
        # Clamp to 1 px so extreme aspect ratios never yield a zero-sized side.
        img = img.resize((max(1, int(width / scale)), max(1, int(height / scale))))
    return np.asarray(img)

def trans(img,conf):
    """Run the pretrained ``Generator`` (upscale factor 4) on one image.

    The weights are loaded from ``results/epochs/netG_epoch_4_100.pth`` on
    every call and mapped onto the CPU.

    Args:
        img: Input image in a form torchvision's ``ToTensor`` accepts
            (``handle`` passes a NumPy array).
        conf: Request configuration; not used here, kept so the call
            signature stays unchanged.

    Returns:
        The generator output for the single input, as a PIL image.
    """
    UPSCALE_FACTOR = 4
    MODEL_NAME = 'netG_epoch_4_100.pth'
    model = Generator(UPSCALE_FACTOR).eval()
    # Map the checkpoint onto the CPU regardless of the device it was saved from.
    model.load_state_dict(torch.load('results/epochs/' + MODEL_NAME, map_location='cpu'))
    # ``Variable(..., volatile=True)`` is ignored (with a warning) since
    # PyTorch 0.4, which left autograd recording enabled during inference;
    # ``torch.no_grad()`` is the supported way to disable it.
    with torch.no_grad():
        image = ToTensor()(img).unsqueeze(0)  # add the batch dimension
        out = model(image)
    return ToPILImage()(out[0].cpu())


def handle(conf):
    """Enhance the image named in ``conf`` and return where the result was saved.

    Args:
        conf: Dict whose "输入图片" entry is the input image path (anything
            ``PIL.Image.open`` accepts). It is also forwarded to ``trans``.

    Returns:
        Dict mapping "增强后的图片" to the path of the image written under
        ``/tmp`` with a random UUID file name and a ``.jpeg`` extension.

    Raises:
        KeyError: If ``conf`` has no "输入图片" entry.
    """
    src_path = conf["输入图片"]  # dict entry holding the input image path
    with Image.open(src_path) as src:
        # Palette images expose colour indices rather than pixel values, and
        # an alpha channel cannot be written to the .jpeg result, so such
        # inputs are normalised to RGB. "RGB" and "L" pass through unchanged.
        # NOTE(review): assumes the generator expects RGB input - confirm
        # against model.py.
        normalised = src if src.mode in ("RGB", "L") else src.convert("RGB")
        pixels = np.asarray(normalised)
    image_result = trans(pixels, conf)
    tmp_img_path = '/tmp/' + str(uuid.uuid4()) + '.jpeg'
    image_result.save(tmp_img_path)
    return {"增强后的图片": tmp_img_path}