- benchmark_results
- data
- images
- job_logs
- pytorch_ssim
- results
- training_results
- .gitignore
- _overview.md
- _readme.ipynb
- app_spec.yml
- BSD100_009.png
- coding_here.ipynb
- data_utils.py
- handler.py
- loss.py
- main.ipynb
- model.py
- out_srf_4_BSD100_009.png
- project_requirements.txt
- README.md
- test_benchmark.py
- test_image.py
- test_video.py
- train.py
- vgg16.pth
main.ipynb @master — view markup · raw · history · blame
In [9]:
import os
import sys
import time
import re
import base64
from io import BytesIO
import torch
#import matplotlib.pyplot as plt  # plt is used to display images (disabled for service use)
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from model import Generator
In [ ]:
def image_to_base64(img):
    """Encode an image as a base64 JPEG string.

    If the module-level ``scale`` global (set as a side effect of
    ``base64_to_image``) is greater than 1, the image is first resized back
    up by that factor, undoing the downscaling applied on input.

    Args:
        img: a PIL.Image-like object exposing ``.size``, ``.resize`` and
            ``.save``.

    Returns:
        str: base64-encoded JPEG bytes (no ``data:`` URI prefix).
    """
    global scale
    # Fix: default to 1 when base64_to_image has not run yet, instead of
    # raising NameError on the still-undefined global.
    factor = scale if 'scale' in globals() else 1
    buffered = BytesIO()
    width, height = img.size
    if factor > 1:
        img = img.resize((int(width * factor), int(height * factor)))
    # NOTE(review): JPEG cannot encode RGBA/P-mode images — presumably the
    # upstream pipeline always produces RGB here; confirm other modes
    # cannot reach this point.
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def base64_to_image(base64_str, grayscale):
    """Decode a (possibly data-URI-prefixed) base64 string into a pixel array.

    Side effect: stores the computed downscale factor in the module-level
    ``scale`` global so that ``image_to_base64`` can restore the original
    size on the way out.

    Args:
        base64_str: base64 image data, optionally prefixed with
            ``data:image/...;base64,``.
        grayscale: if True, convert to single-channel ("L"); otherwise RGB.

    Returns:
        numpy.ndarray: float32 (H, W, C) pixel array whose longest side is
        at most 500 pixels.
    """
    # Local import: the notebook's top import cell does not include numpy.
    import numpy as np

    global scale
    # Strip an optional data-URI header before decoding the payload.
    payload = re.sub('^data:image/.+;base64,', '', base64_str)
    raw = base64.b64decode(payload)
    # Fix: the original called `image.load_img` / `image.img_to_array`
    # (keras helpers) but `image` is never imported anywhere in this file,
    # so the call always raised NameError. PIL (imported at the top of the
    # notebook) + numpy provide the same behavior.
    img = Image.open(BytesIO(raw))
    img = img.convert('L' if grayscale else 'RGB')
    width, height = img.size
    scale = max(img.size) / 500
    if scale > 1:
        # Downscale so the longest side fits in 500 px before inference.
        img = img.resize((int(width / scale), int(height / scale)))
    return np.asarray(img, dtype=np.float32)
In [7]:
def trans(img, conf):
    """Run 4x SRGAN super-resolution on a single image.

    Args:
        img: PIL image or (H, W, C) array accepted by torchvision's
            ``ToTensor``.
        conf: request configuration dict (currently unused here).

    Returns:
        PIL.Image: the upscaled output image.
    """
    UPSCALE_FACTOR = 4
    MODEL_NAME = 'netG_epoch_4_100.pth'
    # NOTE(review): the generator is reloaded from disk on every request;
    # hoist the model to module level if throughput matters.
    model = Generator(UPSCALE_FACTOR).eval()
    model.load_state_dict(torch.load('results/epochs/' + MODEL_NAME,
                                     map_location=lambda storage, loc: storage))
    # Fix: `Variable(..., volatile=True)` has been a no-op since PyTorch
    # 0.4, so the original tracked gradients during inference, wasting
    # memory. `torch.no_grad()` is the modern equivalent.
    with torch.no_grad():
        batch = ToTensor()(img).unsqueeze(0)
        out = model(batch)
    return ToPILImage()(out[0].cpu())
In [ ]:
def handle(conf):
    """Service entry point: decode, super-resolve, and re-encode an image.

    Args:
        conf: request dict; key "输入图片" holds the base64-encoded input
            image.

    Returns:
        dict: key "增强后的图片" maps to the base64-encoded enhanced image.
    """
    # Fix: removed `start_time` / `time1` / `time2` / `time3` — they were
    # assigned from time.time() but never read (dead code).
    img_b64 = conf["输入图片"]
    image_array = base64_to_image(img_b64, False)
    enhanced = trans(image_array, conf)
    return {"增强后的图片": image_to_base64(enhanced)}