import os
import sys
import time
import re
import base64
import uuid
from io import BytesIO
import numpy as np
import torch
#import matplotlib.pyplot as plt # plt 用于显示图片
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from model import Generator
def image_to_base64(img):
    """Encode a PIL image as a base64 JPEG string.

    If ``base64_to_image`` previously shrank its input (module-level ``scale``
    greater than 1), the image is first enlarged by that same factor.

    Args:
        img: A PIL ``Image``.

    Returns:
        The JPEG bytes, base64-encoded, as a ``str`` (no ``data:`` prefix).
    """
    # Only base64_to_image assigns ``scale``. When it has not run, treat the
    # factor as 1 (no resize) instead of raising NameError.
    factor = globals().get("scale", 1)
    # Modes the PIL JPEG encoder can write. Anything else (e.g. RGBA, LA, P)
    # makes ``save(format="JPEG")`` raise, so convert it to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if img.mode not in jpeg_modes:
        img = img.convert("RGB")
    width, height = img.size
    if factor > 1:
        img = img.resize((int(width * factor), int(height * factor)))
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def base64_to_image(base64_str, grayscale):
    """Decode base64 image data into a float32 numpy array.

    Side effect: sets the module-level ``scale`` to ``max(width, height) / 500``
    of the decoded image. ``image_to_base64`` reads it to undo the downscale.

    Args:
        base64_str: Base64-encoded image bytes, optionally prefixed with a
            ``data:image/<type>;base64,`` header.
        grayscale: If true, decode to single-channel luminance ("L");
            otherwise decode to RGB.

    Returns:
        ``np.ndarray`` of dtype float32 and shape (H, W, C), where C is 1 for
        grayscale and 3 for RGB. If the longer side exceeds 500 px, the image
        is first shrunk proportionally so that side becomes 500 px.
    """
    global scale
    # Strip an optional data-URI header such as "data:image/png;base64,".
    base64_data = re.sub('^data:image/.+;base64,', '', base64_str)
    byte_data = base64.b64decode(base64_data)
    # PIL.Image has no load_img/img_to_array (those are Keras helpers), so
    # decode with PIL directly. convert() yields a fully loaded copy, which
    # stays valid after the source is closed.
    with Image.open(BytesIO(byte_data)) as src:
        img = src.convert("L" if grayscale else "RGB")
    width, height = img.size
    scale = max(width, height) / 500
    if scale > 1:
        # max(1, ...) keeps very thin images from collapsing to a 0-px side.
        img = img.resize((max(1, int(width / scale)), max(1, int(height / scale))))
    arr = np.asarray(img, dtype=np.float32)
    if arr.ndim == 2:
        # Grayscale decodes to (H, W); add a channel axis for a uniform (H, W, C).
        arr = arr[:, :, np.newaxis]
    return arr
# Generators already loaded, keyed by (upscale_factor, model_path), so each
# checkpoint is read from disk once per process rather than once per call.
_GENERATOR_CACHE = {}


def _load_generator(upscale_factor, model_path):
    """Return an eval-mode Generator with weights from model_path (cached)."""
    key = (upscale_factor, model_path)
    model = _GENERATOR_CACHE.get(key)
    if model is None:
        model = Generator(upscale_factor).eval()
        # Map all tensors to CPU, whatever device the checkpoint was saved on.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        _GENERATOR_CACHE[key] = model
    return model


def trans(img, conf):
    """Run the pretrained generator (upscale factor 4) on img.

    Args:
        img: Input accepted by torchvision ``ToTensor`` (a PIL image or an
            HxWxC numpy array; ``handle`` passes a uint8 array).
        conf: Request configuration; currently unused.

    Returns:
        The generator's output for the image, as a PIL image.
    """
    upscale_factor = 4
    model_name = 'netG_epoch_4_100.pth'
    model = _load_generator(upscale_factor, 'results/epochs/' + model_name)
    # Variable(..., volatile=True) has no effect since PyTorch 0.4, so autograd
    # was recording the forward pass. no_grad() actually disables that.
    with torch.no_grad():
        image = ToTensor()(img).unsqueeze(0)
        out = model(image)
    # ToPILImage multiplies float tensors by 255 and casts to uint8. Clamping
    # makes any output outside [0, 1] saturate instead of overflowing the cast.
    return ToPILImage()(out[0].clamp(0, 1).cpu())
def handle(conf):
    """Entry point: enhance the image file named by conf and save the result.

    Args:
        conf: Request dict. ``conf["输入图片"]`` ("input image") holds the
            path of the image file to process.

    Returns:
        ``{"增强后的图片": path}`` ("enhanced image"), where path is a newly
        written JPEG under /tmp holding the output of ``trans``.

    Raises:
        KeyError: If conf has no "输入图片" entry.
    """
    img_path = conf["输入图片"]
    # Close the file handle deterministically. Also normalise to 3-channel RGB
    # so RGBA/palette/grayscale files give the same (H, W, 3) layout as RGB
    # ones; the Generator (model.py, not shown) presumably expects 3 channels.
    with Image.open(img_path) as src:
        img = np.asarray(src.convert("RGB"))
    image_result = trans(img, conf)
    # A random uuid4 file name avoids collisions between calls.
    tmp_img_path = '/tmp/' + str(uuid.uuid4()) + '.jpeg'
    image_result.save(tmp_img_path)
    return {"增强后的图片": tmp_img_path}