diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2679d4f --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.idea/ +*.pyc +*.swp +.DS_Store +/.localenv/ +/datasets/ +/.ipynb_checkpoints/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..f90fde3 --- /dev/null +++ b/README.md @@ -0,0 +1,177 @@ +# SRGAN +A PyTorch implementation of SRGAN based on CVPR 2017 paper +[Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802). + +## Requirements +- [Anaconda](https://www.anaconda.com/download/) +- PyTorch +``` +conda install pytorch torchvision -c pytorch +``` +- opencv +``` +conda install opencv +``` + +## Datasets + +### Train、Val Dataset +The train and val datasets are sampled from [VOC2012](http://cvlab.postech.ac.kr/~mooyeol/pascal_voc_2012/). +Train dataset has 16700 images and Val dataset has 425 images. +Download the datasets from [here](https://pan.baidu.com/s/1xuFperu2WiYc5-_QXBemlA)(access code:5tzp), and then extract it into `data` directory. + +### Test Image Dataset +The test image dataset are sampled from +| **Set 5** | [Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html) +| **Set 14** | [Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests) +| **BSD 100** | [Martin et al. ICCV 2001](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/) +| **Sun-Hays 80** | [Sun and Hays ICCP 2012](http://cs.brown.edu/~lbsun/SRproj2012/SR_iccp2012.html) +| **Urban 100** | [Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr). +Download the image dataset from [here](https://pan.baidu.com/s/1vGosnyal21wGgVffriL1VQ)(access code:xwhy), and then extract it into `data` directory. + +### Test Video Dataset +The test video dataset are three trailers. Download the video dataset from +[here](https://pan.baidu.com/s/1NUZKm5xCHRj1O0JlCZIu8Q)(access code:zabi). + +## Usage + +### Train +``` +python train.py + +optional arguments: +--crop_size training images crop size [default value is 88] +--upscale_factor super resolution upscale factor [default value is 4](choices:[2, 4, 8]) +--num_epochs train epoch number [default value is 100] +``` +The output val super resolution images are on `training_results` directory. + +### Test Benchmark Datasets +``` +python test_benchmark.py + +optional arguments: +--upscale_factor super resolution upscale factor [default value is 4] +--model_name generator model epoch name [default value is netG_epoch_4_100.pth] +``` +The output super resolution images are on `benchmark_results` directory. + +### Test Single Image +``` +python test_image.py + +optional arguments: +--upscale_factor super resolution upscale factor [default value is 4] +--test_mode using GPU or CPU [default value is 'GPU'](choices:['GPU', 'CPU']) +--image_name test low resolution image name +--model_name generator model epoch name [default value is netG_epoch_4_100.pth] +``` +The output super resolution image are on the same directory. + +### Test Single Video +``` +python test_video.py + +optional arguments: +--upscale_factor super resolution upscale factor [default value is 4] +--video_name test low resolution video name +--model_name generator model epoch name [default value is netG_epoch_4_100.pth] +``` +The output super resolution video and compared video are on the same directory. + +## Benchmarks +**Upscale Factor = 2** + +Epochs with batch size of 64 takes ~2 minute 30 seconds on a NVIDIA GTX 1080Ti GPU. + +> Image Results + +The left is bicubic interpolation image, the middle is high resolution image, and +the right is super resolution image(output of the SRGAN). + +- BSD100_070(PSNR:32.4517; SSIM:0.9191) + +![BSD100_070](images/1.png) + +- Set14_005(PSNR:26.9171; SSIM:0.9119) + +![Set14_005](images/2.png) + +- Set14_013(PSNR:30.8040; SSIM:0.9651) + +![Set14_013](images/3.png) + +- Urban100_098(PSNR:24.3765; SSIM:0.7855) + +![Urban100_098](images/4.png) + +> Video Results + +The left is bicubic interpolation video, the right is super resolution video(output of the SRGAN). + +[![Watch the video](images/video_SRF_2.png)](https://youtu.be/05vx-vOJOZs) + +**Upscale Factor = 4** + +Epochs with batch size of 64 takes ~4 minute 30 seconds on a NVIDIA GTX 1080Ti GPU. + +> Image Results + +The left is bicubic interpolation image, the middle is high resolution image, and +the right is super resolution image(output of the SRGAN). + +- BSD100_035(PSNR:32.3980; SSIM:0.8512) + +![BSD100_035](images/5.png) + +- Set14_011(PSNR:29.5944; SSIM:0.9044) + +![Set14_011](images/6.png) + +- Set14_014(PSNR:25.1299; SSIM:0.7406) + +![Set14_014](images/7.png) + +- Urban100_060(PSNR:20.7129; SSIM:0.5263) + +![Urban100_060](images/8.png) + +> Video Results + +The left is bicubic interpolation video, the right is super resolution video(output of the SRGAN). + +[![Watch the video](images/video_SRF_4.png)](https://youtu.be/tNR2eiMeoQs) + +**Upscale Factor = 8** + +Epochs with batch size of 64 takes ~3 minute 30 seconds on a NVIDIA GTX 1080Ti GPU. + +> Image Results + +The left is bicubic interpolation image, the middle is high resolution image, and +the right is super resolution image(output of the SRGAN). + +- SunHays80_027(PSNR:29.4941; SSIM:0.8082) + +![SunHays80_027](images/9.png) + +- SunHays80_035(PSNR:32.1546; SSIM:0.8449) + +![SunHays80_035](images/10.png) + +- SunHays80_043(PSNR:30.9716; SSIM:0.8789) + +![SunHays80_043](images/11.png) + +- SunHays80_078(PSNR:31.9351; SSIM:0.8381) + +![SunHays80_078](images/12.png) + +> Video Results + +The left is bicubic interpolation video, the right is super resolution video(output of the SRGAN). + +[![Watch the video](images/video_SRF_8.png)](https://youtu.be/EuvXTKCRr8I) + +The complete test results could be downloaded from [here](https://pan.baidu.com/s/1pAaEgAQ4aRZbtKo8hNrb4Q)(access code:wnvn). + diff --git a/_OVERVIEW.md b/_OVERVIEW.md new file mode 100644 index 0000000..d2a86d2 --- /dev/null +++ b/_OVERVIEW.md @@ -0,0 +1,5 @@ +## 介绍 (Introduction) + +添加该项目的功能、使用场景和输入输出参数等相关信息。 + +You can describe the function, usage and parameters of the project. \ No newline at end of file diff --git a/_README.ipynb b/_README.ipynb new file mode 100644 index 0000000..2abc9d6 --- /dev/null +++ b/_README.ipynb @@ -0,0 +1,150 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 项目介绍\n", + "\n", + " - 项目是由模块组成、有特定功能的程序。它能够满足用户的直接使用需求,例如[古诗词生成器](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1)、[风格迁移](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1)等。\n", + " - 开发项目过程中你可以导入数据集,也可以通过每个 cell 上方工具栏的`<+>`直接插入[模块](https://momodel.cn/modules)和代码块。\n", + " - 你可以将开发好的项目进行[部署](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89),项目部署成功并选择正式版本发布后会展示在“项目”页面,用户可以在线使用,也可以通过 API 调用。\n", + "\n", + " - 项目目录结构:\n", + "\n", + " - ```results```*-----结果的文件存放地(如果你运行 job,务必将运行结果指定在此目录)*\n", + " - ```_OVERVIEW.md``` *-----项目的相关介绍*\n", + " - ```_README.md```*-----说明文档*\n", + " - ```app_spec.yml```*-----定义项目的输入输出,为部署服务*\n", + " - ```coding_here.ipynb```*-----输入并运行代码*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## 2. 开发环境简介\n", + "\n", + "你当前所在的页面 Notebook 是一个内嵌 JupyterLab 的在线类 IDE 编程环境,开发过程中可以使用页面右侧的 API 文档进行快速查询。Notebook 有以下主要功能:\n", + "\n", + "- [调用数据集、模块和代码块资源](https://momodel.cn/docs/#/zh-cn/%E5%A6%82%E4%BD%95%E5%AF%BC%E5%85%A5%E5%B9%B6%E4%BD%BF%E7%94%A8%E6%A8%A1%E5%9D%97%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86)\n", + "- [多人代码协作](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8Mo%E8%BF%90%E8%A1%8C%E4%BD%A0%E7%9A%84%E7%AC%AC%E4%B8%80%E6%AE%B5%E4%BB%A3%E7%A0%81?id=_7-%e4%bd%a0%e5%8f%af%e4%bb%a5%e9%82%80%e8%af%b7%e5%a5%bd%e5%8f%8b%e8%bf%9b%e8%a1%8c%e5%8d%8f%e4%bd%9c)\n", + "- [在 GPU 资源上训练机器学习模型](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8GPU%E6%88%96CPU%E8%B5%84%E6%BA%90%E4%B8%8A%E8%AE%AD%E7%BB%83%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E6%A8%A1%E5%9E%8B)\n", + "- [简单部署](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89)\n", + "\n", + "快来动手试试吧!点击左侧工具栏的新建文件图标即可选择你需要的文件类型。\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "左侧和右侧工具栏都可根据使用需要进行收合。\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 快捷键与代码补全\n", + "Mo Notebook 已完全采用 Jupyter Notebook 的原生快捷键,并且支持 `tab` 代码补全。\n", + "\n", + "运行代码:`shift` + `enter` 或者 `shift` + `return`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 常用指令介绍\n", + "\n", + "- 解压上传后的文件\n", + "\n", + "在 cell 中输入并运行以下命令:\n", + "```!unzip -o file_name.zip```\n", + "\n", + "- 查看所有包(package)\n", + "\n", + "`!pip list --format=columns`\n", + "\n", + "- 检查是否已有某个包\n", + "\n", + "`!pip show package_name`\n", + "\n", + "- 安装缺失的包\n", + "\n", + "`!pip install package_name`\n", + "\n", + "- 更新已有的包\n", + "\n", + "`!pip install package_name --upgrade`\n", + "\n", + "\n", + "- 使用包\n", + "\n", + "`import package_name`\n", + "\n", + "- 显示当前目录下的档案及目录\n", + "\n", + "`ls`\n", + "\n", + "- 使用引入的数据集\n", + "\n", + "数据集被引入后存放在 datasets 文件夹下,注意,这个文件夹是只读的,不可修改。如果需要修改,可在 Notebook 中使用\n", + "\n", + "`!cp -R ./datasets/ ./`\n", + "\n", + "指令将其复制到其他文件夹后再编辑,对于引入的数据集中的 zip 文件,可使用\n", + "\n", + "`!unzip ./datasets// -d ./`\n", + "\n", + "指令解压缩到其他文件夹后使用" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. 其他可参考资源\n", + "- [帮助文档](https://momodel.cn/docs/#/):基本页面介绍和常见问题都可以在里面找到\n", + "- [平台功能教程](https://momodel.cn/classroom/class?id=5c5696cd1afd9458d456bf54&type=doc):通过图文结合的 Notebook 详细介绍开发环境基本功能和操作\n", + "- [吴恩达机器学习](https://momodel.cn/classroom/class?id=5c5696191afd94720cc94533&type=video):机器学习经典课程\n", + "- [李宏毅机器学习](https://s.momodel.cn/classroom/class?id=5d40fdafb5113408a8dbb4a1&type=video):中文世界最好的机器学习课程\n", + "- [机器学习实战](https://momodel.cn/classroom/class?id=5c680b311afd943a9f70901b&type=practice):通过实操指引完成独立的模型,掌握相应的机器学习知识\n", + "- [Python 教程](https://momodel.cn/classroom/class?id=5d1f3ab81afd940ab7d298bf&type=notebook):简单易懂的 Python 新手教程\n", + "- [模块开发](https://momodel.cn/modules):关于模型训练、开发与部署的高阶教程" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "source": [], + "metadata": { + "collapsed": false + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/benchmark_results/.gitkeep b/benchmark_results/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/coding_here.ipynb b/coding_here.ipynb new file mode 100644 index 0000000..90d9432 --- /dev/null +++ b/coding_here.ipynb @@ -0,0 +1,36 @@ + + { + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print('Hello Mo!')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 + } + \ No newline at end of file diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data_utils.py b/data_utils.py new file mode 100755 index 0000000..ff41d51 --- /dev/null +++ b/data_utils.py @@ -0,0 +1,98 @@ +from os import listdir +from os.path import join + +from PIL import Image +from torch.utils.data.dataset import Dataset +from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']) + + +def calculate_valid_crop_size(crop_size, upscale_factor): + return crop_size - (crop_size % upscale_factor) + + +def train_hr_transform(crop_size): + return Compose([ + RandomCrop(crop_size), + ToTensor(), + ]) + + +def train_lr_transform(crop_size, upscale_factor): + return Compose([ + ToPILImage(), + Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC), + ToTensor() + ]) + + +def display_transform(): + return Compose([ + ToPILImage(), + Resize(400), + CenterCrop(400), + ToTensor() + ]) + + +class TrainDatasetFromFolder(Dataset): + def __init__(self, dataset_dir, crop_size, upscale_factor): + super(TrainDatasetFromFolder, self).__init__() + self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)] + crop_size = calculate_valid_crop_size(crop_size, upscale_factor) + self.hr_transform = train_hr_transform(crop_size) + self.lr_transform = train_lr_transform(crop_size, upscale_factor) + + def __getitem__(self, index): + hr_image = self.hr_transform(Image.open(self.image_filenames[index])) + lr_image = self.lr_transform(hr_image) + return lr_image, hr_image + + def __len__(self): + return len(self.image_filenames) + + +class ValDatasetFromFolder(Dataset): + def __init__(self, dataset_dir, upscale_factor): + super(ValDatasetFromFolder, self).__init__() + self.upscale_factor = upscale_factor + self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)] + + def __getitem__(self, index): + hr_image = Image.open(self.image_filenames[index]) + w, h = hr_image.size + crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor) + lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC) + hr_scale = Resize(crop_size, interpolation=Image.BICUBIC) + hr_image = CenterCrop(crop_size)(hr_image) + lr_image = lr_scale(hr_image) + hr_restore_img = hr_scale(lr_image) + return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image) + + def __len__(self): + return len(self.image_filenames) + + +class TestDatasetFromFolder(Dataset): + def __init__(self, dataset_dir, upscale_factor): + super(TestDatasetFromFolder, self).__init__() + self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/' + self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/' + self.upscale_factor = upscale_factor + self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)] + self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)] + + def __getitem__(self, index): + image_name = self.lr_filenames[index].split('/')[-1] + lr_image = Image.open(self.lr_filenames[index]) + w, h = lr_image.size + hr_image = Image.open(self.hr_filenames[index]) + hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC) + hr_restore_img = hr_scale(lr_image) + return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image) + + def __len__(self): + return len(self.lr_filenames) diff --git a/epochs/.gitkeep b/epochs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/images/1.png b/images/1.png new file mode 100644 index 0000000..5ee502f Binary files /dev/null and b/images/1.png differ diff --git a/images/10.png b/images/10.png new file mode 100644 index 0000000..337f02a Binary files /dev/null and b/images/10.png differ diff --git a/images/11.png b/images/11.png new file mode 100644 index 0000000..8f12162 Binary files /dev/null and b/images/11.png differ diff --git a/images/12.png b/images/12.png new file mode 100644 index 0000000..c4900c2 Binary files /dev/null and b/images/12.png differ diff --git a/images/2.png b/images/2.png new file mode 100644 index 0000000..8b795f2 Binary files /dev/null and b/images/2.png differ diff --git a/images/3.png b/images/3.png new file mode 100644 index 0000000..8144308 Binary files /dev/null and b/images/3.png differ diff --git a/images/4.png b/images/4.png new file mode 100644 index 0000000..54ca285 Binary files /dev/null and b/images/4.png differ diff --git a/images/5.png b/images/5.png new file mode 100644 index 0000000..a6b9f1e Binary files /dev/null and b/images/5.png differ diff --git a/images/6.png b/images/6.png new file mode 100644 index 0000000..41fa8a9 Binary files /dev/null and b/images/6.png differ diff --git a/images/7.png b/images/7.png new file mode 100644 index 0000000..8ba13c8 Binary files /dev/null and b/images/7.png differ diff --git a/images/8.png b/images/8.png new file mode 100644 index 0000000..0b1ce2d Binary files /dev/null and b/images/8.png differ diff --git a/images/9.png b/images/9.png new file mode 100644 index 0000000..7cd1237 Binary files /dev/null and b/images/9.png differ diff --git a/images/video_SRF_2.png b/images/video_SRF_2.png new file mode 100644 index 0000000..2877591 Binary files /dev/null and b/images/video_SRF_2.png differ diff --git a/images/video_SRF_4.png b/images/video_SRF_4.png new file mode 100644 index 0000000..e5a5fda Binary files /dev/null and b/images/video_SRF_4.png differ diff --git a/images/video_SRF_8.png b/images/video_SRF_8.png new file mode 100644 index 0000000..f7a8fe2 Binary files /dev/null and b/images/video_SRF_8.png differ diff --git a/loss.py b/loss.py new file mode 100644 index 0000000..ab30472 --- /dev/null +++ b/loss.py @@ -0,0 +1,51 @@ +import torch +from torch import nn +from torchvision.models.vgg import vgg16 + + +class GeneratorLoss(nn.Module): + def __init__(self): + super(GeneratorLoss, self).__init__() + vgg = vgg16(pretrained=True) + loss_network = nn.Sequential(*list(vgg.features)[:31]).eval() + for param in loss_network.parameters(): + param.requires_grad = False + self.loss_network = loss_network + self.mse_loss = nn.MSELoss() + self.tv_loss = TVLoss() + + def forward(self, out_labels, out_images, target_images): + # Adversarial Loss + adversarial_loss = torch.mean(1 - out_labels) + # Perception Loss + perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images)) + # Image Loss + image_loss = self.mse_loss(out_images, target_images) + # TV Loss + tv_loss = self.tv_loss(out_images) + return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss + + +class TVLoss(nn.Module): + def __init__(self, tv_loss_weight=1): + super(TVLoss, self).__init__() + self.tv_loss_weight = tv_loss_weight + + def forward(self, x): + batch_size = x.size()[0] + h_x = x.size()[2] + w_x = x.size()[3] + count_h = self.tensor_size(x[:, :, 1:, :]) + count_w = self.tensor_size(x[:, :, :, 1:]) + h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() + w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() + return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size + + @staticmethod + def tensor_size(t): + return t.size()[1] * t.size()[2] * t.size()[3] + + +if __name__ == "__main__": + g_loss = GeneratorLoss() + print(g_loss) diff --git a/model.py b/model.py new file mode 100755 index 0000000..261f0d0 --- /dev/null +++ b/model.py @@ -0,0 +1,117 @@ +import math +import torch +from torch import nn + + +class Generator(nn.Module): + def __init__(self, scale_factor): + upsample_block_num = int(math.log(scale_factor, 2)) + + super(Generator, self).__init__() + self.block1 = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=9, padding=4), + nn.PReLU() + ) + self.block2 = ResidualBlock(64) + self.block3 = ResidualBlock(64) + self.block4 = ResidualBlock(64) + self.block5 = ResidualBlock(64) + self.block6 = ResidualBlock(64) + self.block7 = nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, padding=1), + nn.BatchNorm2d(64) + ) + block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)] + block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4)) + self.block8 = nn.Sequential(*block8) + + def forward(self, x): + block1 = self.block1(x) + block2 = self.block2(block1) + block3 = self.block3(block2) + block4 = self.block4(block3) + block5 = self.block5(block4) + block6 = self.block6(block5) + block7 = self.block7(block6) + block8 = self.block8(block1 + block7) + + return (torch.tanh(block8) + 1) / 2 + + +class Discriminator(nn.Module): + def __init__(self): + super(Discriminator, self).__init__() + self.net = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2), + + nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.2), + + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + + nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + + nn.Conv2d(128, 256, kernel_size=3, padding=1), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2), + + nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2), + + nn.Conv2d(256, 512, kernel_size=3, padding=1), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2), + + nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2), + + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(512, 1024, kernel_size=1), + nn.LeakyReLU(0.2), + nn.Conv2d(1024, 1, kernel_size=1) + ) + + def forward(self, x): + batch_size = x.size(0) + return torch.sigmoid(self.net(x).view(batch_size)) + + +class ResidualBlock(nn.Module): + def __init__(self, channels): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(channels) + self.prelu = nn.PReLU() + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(channels) + + def forward(self, x): + residual = self.conv1(x) + residual = self.bn1(residual) + residual = self.prelu(residual) + residual = self.conv2(residual) + residual = self.bn2(residual) + + return x + residual + + +class UpsampleBLock(nn.Module): + def __init__(self, in_channels, up_scale): + super(UpsampleBLock, self).__init__() + self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1) + self.pixel_shuffle = nn.PixelShuffle(up_scale) + self.prelu = nn.PReLU() + + def forward(self, x): + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.prelu(x) + return x diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py new file mode 100755 index 0000000..1ab54e9 --- /dev/null +++ b/pytorch_ssim/__init__.py @@ -0,0 +1,77 @@ +from math import exp + +import torch +import torch.nn.functional as F +from torch.autograd import Variable + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + + +def ssim(img1, img2, window_size=11, size_average=True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/results/README.md b/results/README.md new file mode 100644 index 0000000..81902f1 --- /dev/null +++ b/results/README.md @@ -0,0 +1,2 @@ +Please store your training checkpoints or results here +请在此处存储 checkpoints 和结果文件 \ No newline at end of file diff --git a/results/tb_results/README.md b/results/tb_results/README.md new file mode 100644 index 0000000..23659ee --- /dev/null +++ b/results/tb_results/README.md @@ -0,0 +1,2 @@ +Please store your tensorboard results here +请在此处存储 tensorboard 结果 diff --git a/statistics/.gitkeep b/statistics/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/test_benchmark.py b/test_benchmark.py new file mode 100755 index 0000000..0bd9870 --- /dev/null +++ b/test_benchmark.py @@ -0,0 +1,80 @@ +import argparse +import os +from math import log10 + +import numpy as np +import pandas as pd +import torch +import torchvision.utils as utils +from torch.autograd import Variable +from torch.utils.data import DataLoader +from tqdm import tqdm + +import pytorch_ssim +from data_utils import TestDatasetFromFolder, display_transform +from model import Generator + +parser = argparse.ArgumentParser(description='Test Benchmark Datasets') +parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor') +parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name') +opt = parser.parse_args() + +UPSCALE_FACTOR = opt.upscale_factor +MODEL_NAME = opt.model_name + +results = {'Set5': {'psnr': [], 'ssim': []}, 'Set14': {'psnr': [], 'ssim': []}, 'BSD100': {'psnr': [], 'ssim': []}, + 'Urban100': {'psnr': [], 'ssim': []}, 'SunHays80': {'psnr': [], 'ssim': []}} + +model = Generator(UPSCALE_FACTOR).eval() +if torch.cuda.is_available(): + model = model.cuda() +model.load_state_dict(torch.load('epochs/' + MODEL_NAME)) + +test_set = TestDatasetFromFolder('data/test', upscale_factor=UPSCALE_FACTOR) +test_loader = DataLoader(dataset=test_set, num_workers=4, batch_size=1, shuffle=False) +test_bar = tqdm(test_loader, desc='[testing benchmark datasets]') + +out_path = 'benchmark_results/SRF_' + str(UPSCALE_FACTOR) + '/' +if not os.path.exists(out_path): + os.makedirs(out_path) + +for image_name, lr_image, hr_restore_img, hr_image in test_bar: + image_name = image_name[0] + lr_image = Variable(lr_image, volatile=True) + hr_image = Variable(hr_image, volatile=True) + if torch.cuda.is_available(): + lr_image = lr_image.cuda() + hr_image = hr_image.cuda() + + sr_image = model(lr_image) + mse = ((hr_image - sr_image) ** 2).data.mean() + psnr = 10 * log10(1 / mse) + ssim = pytorch_ssim.ssim(sr_image, hr_image).data[0] + + test_images = torch.stack( + [display_transform()(hr_restore_img.squeeze(0)), display_transform()(hr_image.data.cpu().squeeze(0)), + display_transform()(sr_image.data.cpu().squeeze(0))]) + image = utils.make_grid(test_images, nrow=3, padding=5) + utils.save_image(image, out_path + image_name.split('.')[0] + '_psnr_%.4f_ssim_%.4f.' % (psnr, ssim) + + image_name.split('.')[-1], padding=5) + + # save psnr\ssim + results[image_name.split('_')[0]]['psnr'].append(psnr) + results[image_name.split('_')[0]]['ssim'].append(ssim) + +out_path = 'statistics/' +saved_results = {'psnr': [], 'ssim': []} +for item in results.values(): + psnr = np.array(item['psnr']) + ssim = np.array(item['ssim']) + if (len(psnr) == 0) or (len(ssim) == 0): + psnr = 'No data' + ssim = 'No data' + else: + psnr = psnr.mean() + ssim = ssim.mean() + saved_results['psnr'].append(psnr) + saved_results['ssim'].append(ssim) + +data_frame = pd.DataFrame(saved_results, results.keys()) +data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_test_results.csv', index_label='DataSet') diff --git a/test_image.py b/test_image.py new file mode 100755 index 0000000..ef7df58 --- /dev/null +++ b/test_image.py @@ -0,0 +1,40 @@ +import argparse +import time + +import torch +from PIL import Image +from torch.autograd import Variable +from torchvision.transforms import ToTensor, ToPILImage + +from model import Generator + +parser = argparse.ArgumentParser(description='Test Single Image') +parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor') +parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU') +parser.add_argument('--image_name', type=str, help='test low resolution image name') +parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name') +opt = parser.parse_args() + +UPSCALE_FACTOR = opt.upscale_factor +TEST_MODE = True if opt.test_mode == 'GPU' else False +IMAGE_NAME = opt.image_name +MODEL_NAME = opt.model_name + +model = Generator(UPSCALE_FACTOR).eval() +if TEST_MODE: + model.cuda() + model.load_state_dict(torch.load('epochs/' + MODEL_NAME)) +else: + model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage)) + +image = Image.open(IMAGE_NAME) +image = Variable(ToTensor()(image), volatile=True).unsqueeze(0) +if TEST_MODE: + image = image.cuda() + +start = time.clock() +out = model(image) +elapsed = (time.clock() - start) +print('cost' + str(elapsed) + 's') +out_img = ToPILImage()(out[0].data.cpu()) +out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME) diff --git a/test_video.py b/test_video.py new file mode 100755 index 0000000..38a3d90 --- /dev/null +++ b/test_video.py @@ -0,0 +1,85 @@ +import argparse + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image +from torch.autograd import Variable +from torchvision.transforms import ToTensor, ToPILImage +from tqdm import tqdm + +from model import Generator + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Test Single Video') + parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor') + parser.add_argument('--video_name', type=str, help='test low resolution video name') + parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name') + opt = parser.parse_args() + + UPSCALE_FACTOR = opt.upscale_factor + VIDEO_NAME = opt.video_name + MODEL_NAME = opt.model_name + + model = Generator(UPSCALE_FACTOR).eval() + if torch.cuda.is_available(): + model = model.cuda() + # for cpu + # model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage)) + model.load_state_dict(torch.load('epochs/' + MODEL_NAME)) + + videoCapture = cv2.VideoCapture(VIDEO_NAME) + fps = videoCapture.get(cv2.CAP_PROP_FPS) + frame_numbers = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) + sr_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR), + int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR) + compared_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10), + int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR + 10 + int( + int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10) / int( + 10 * int(int( + videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 + 1)) * int( + int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 - 9))) + output_sr_name = 'out_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi' + output_compared_name = 'compare_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi' + sr_video_writer = cv2.VideoWriter(output_sr_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps, sr_video_size) + compared_video_writer = cv2.VideoWriter(output_compared_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps, + compared_video_size) + # read frame + success, frame = videoCapture.read() + test_bar = tqdm(range(int(frame_numbers)), desc='[processing video and saving result videos]') + for index in test_bar: + if success: + image = Variable(ToTensor()(frame), volatile=True).unsqueeze(0) + if torch.cuda.is_available(): + image = image.cuda() + + out = model(image) + out = out.cpu() + out_img = out.data[0].numpy() + out_img *= 255.0 + out_img = (np.uint8(out_img)).transpose((1, 2, 0)) + # save sr video + sr_video_writer.write(out_img) + + # make compared video and crop shot of left top\right top\center\left bottom\right bottom + out_img = ToPILImage()(out_img) + crop_out_imgs = transforms.FiveCrop(size=out_img.width // 5 - 9)(out_img) + crop_out_imgs = [np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(img)) for img in crop_out_imgs] + out_img = transforms.Pad(padding=(5, 0, 0, 5))(out_img) + compared_img = transforms.Resize(size=(sr_video_size[1], sr_video_size[0]), interpolation=Image.BICUBIC)( + ToPILImage()(frame)) + crop_compared_imgs = transforms.FiveCrop(size=compared_img.width // 5 - 9)(compared_img) + crop_compared_imgs = [np.asarray(transforms.Pad(padding=(0, 5, 10, 0))(img)) for img in crop_compared_imgs] + compared_img = transforms.Pad(padding=(0, 0, 5, 5))(compared_img) + # concatenate all the pictures to one single picture + top_image = np.concatenate((np.asarray(compared_img), np.asarray(out_img)), axis=1) + bottom_image = np.concatenate(crop_compared_imgs + crop_out_imgs, axis=1) + bottom_image = np.asarray(transforms.Resize( + size=(int(top_image.shape[1] / bottom_image.shape[1] * bottom_image.shape[0]), top_image.shape[1]))( + ToPILImage()(bottom_image))) + final_image = np.concatenate((top_image, bottom_image)) + # save compared video + compared_video_writer.write(final_image) + # next frame + success, frame = videoCapture.read() diff --git a/train.py b/train.py new file mode 100755 index 0000000..a2e21b7 --- /dev/null +++ b/train.py @@ -0,0 +1,166 @@ +import argparse +import os +from math import log10 + +import pandas as pd +import torch.optim as optim +import torch.utils.data +import torchvision.utils as utils +from torch.autograd import Variable +from torch.utils.data import DataLoader +from tqdm import tqdm + +import pytorch_ssim +from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform +from loss import GeneratorLoss +from model import Generator, Discriminator + +parser = argparse.ArgumentParser(description='Train Super Resolution Models') +parser.add_argument('--crop_size', default=88, type=int, help='training images crop size') +parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8], + help='super resolution upscale factor') +parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number') + + +if __name__ == '__main__': + opt = parser.parse_args() + + CROP_SIZE = opt.crop_size + UPSCALE_FACTOR = opt.upscale_factor + NUM_EPOCHS = opt.num_epochs + + train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR) + val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR) + train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True) + val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False) + + netG = Generator(UPSCALE_FACTOR) + print('# generator parameters:', sum(param.numel() for param in netG.parameters())) + netD = Discriminator() + print('# discriminator parameters:', sum(param.numel() for param in netD.parameters())) + + generator_criterion = GeneratorLoss() + + if torch.cuda.is_available(): + netG.cuda() + netD.cuda() + generator_criterion.cuda() + + optimizerG = optim.Adam(netG.parameters()) + optimizerD = optim.Adam(netD.parameters()) + + results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []} + + for epoch in range(1, NUM_EPOCHS + 1): + train_bar = tqdm(train_loader) + running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0} + + netG.train() + netD.train() + for data, target in train_bar: + g_update_first = True + batch_size = data.size(0) + running_results['batch_sizes'] += batch_size + + ############################ + # (1) Update D network: maximize D(x)-1-D(G(z)) + ########################### + real_img = Variable(target) + if torch.cuda.is_available(): + real_img = real_img.cuda() + z = Variable(data) + if torch.cuda.is_available(): + z = z.cuda() + fake_img = netG(z) + + netD.zero_grad() + real_out = netD(real_img).mean() + fake_out = netD(fake_img).mean() + d_loss = 1 - real_out + fake_out + d_loss.backward(retain_graph=True) + optimizerD.step() + + ############################ + # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss + ########################### + netG.zero_grad() + g_loss = generator_criterion(fake_out, fake_img, real_img) + g_loss.backward() + + fake_img = netG(z) + fake_out = netD(fake_img).mean() + + + optimizerG.step() + + # loss for current batch before optimization + running_results['g_loss'] += g_loss.item() * batch_size + running_results['d_loss'] += d_loss.item() * batch_size + running_results['d_score'] += real_out.item() * batch_size + running_results['g_score'] += fake_out.item() * batch_size + + train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % ( + epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'], + running_results['g_loss'] / running_results['batch_sizes'], + running_results['d_score'] / running_results['batch_sizes'], + running_results['g_score'] / running_results['batch_sizes'])) + + netG.eval() + out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/' + if not os.path.exists(out_path): + os.makedirs(out_path) + + with torch.no_grad(): + val_bar = tqdm(val_loader) + valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0} + val_images = [] + for val_lr, val_hr_restore, val_hr in val_bar: + batch_size = val_lr.size(0) + valing_results['batch_sizes'] += batch_size + lr = val_lr + hr = val_hr + if torch.cuda.is_available(): + lr = lr.cuda() + hr = hr.cuda() + sr = netG(lr) + + batch_mse = ((sr - hr) ** 2).data.mean() + valing_results['mse'] += batch_mse * batch_size + batch_ssim = pytorch_ssim.ssim(sr, hr).item() + valing_results['ssims'] += batch_ssim * batch_size + valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes'])) + valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes'] + val_bar.set_description( + desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % ( + valing_results['psnr'], valing_results['ssim'])) + + val_images.extend( + [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)), + display_transform()(sr.data.cpu().squeeze(0))]) + val_images = torch.stack(val_images) + val_images = torch.chunk(val_images, val_images.size(0) // 15) + val_save_bar = tqdm(val_images, desc='[saving training results]') + index = 1 + for image in val_save_bar: + image = utils.make_grid(image, nrow=3, padding=5) + utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5) + index += 1 + + # save model parameters + torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch)) + torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch)) + # save loss\scores\psnr\ssim + results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes']) + results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes']) + results['d_score'].append(running_results['d_score'] / running_results['batch_sizes']) + results['g_score'].append(running_results['g_score'] / running_results['batch_sizes']) + results['psnr'].append(valing_results['psnr']) + results['ssim'].append(valing_results['ssim']) + + if epoch % 10 == 0 and epoch != 0: + out_path = 'statistics/' + data_frame = pd.DataFrame( + data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'], + 'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']}, + index=range(1, epoch + 1)) + data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch') diff --git a/training_results/.gitkeep b/training_results/.gitkeep new file mode 100644 index 0000000..e69de29