Initial Commit
lzfxxx
6 years ago
# SRGAN
A PyTorch implementation of SRGAN based on the CVPR 2017 paper
[Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802).

## Requirements
- [Anaconda](https://www.anaconda.com/download/)
- PyTorch
```
conda install pytorch torchvision -c pytorch
```
- opencv
```
conda install opencv
```

## Datasets

### Train and Val Datasets
The train and val datasets are sampled from [VOC2012](http://cvlab.postech.ac.kr/~mooyeol/pascal_voc_2012/).
The train dataset contains 16,700 images and the val dataset contains 425 images.
Download the datasets from [here](https://pan.baidu.com/s/1xuFperu2WiYc5-_QXBemlA) (access code: 5tzp), and then extract them into the `data` directory.

### Test Image Dataset
The test image datasets are sampled from:

| Dataset | Source |
| --- | --- |
| **Set 5** | [Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html) |
| **Set 14** | [Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests) |
| **BSD 100** | [Martin et al. ICCV 2001](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/) |
| **Sun-Hays 80** | [Sun and Hays ICCP 2012](http://cs.brown.edu/~lbsun/SRproj2012/SR_iccp2012.html) |
| **Urban 100** | [Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr) |

Download the image dataset from [here](https://pan.baidu.com/s/1vGosnyal21wGgVffriL1VQ) (access code: xwhy), and then extract it into the `data` directory.

### Test Video Dataset
The test video dataset consists of three trailers. Download the video dataset from
[here](https://pan.baidu.com/s/1NUZKm5xCHRj1O0JlCZIu8Q) (access code: zabi).

## Usage

### Train
```
python train.py

optional arguments:
--crop_size                   training images crop size [default value is 88]
--upscale_factor              super resolution upscale factor [default value is 4] (choices: [2, 4, 8])
--num_epochs                  train epoch number [default value is 100]
```
The output val super resolution images are saved in the `training_results` directory.

### Test Benchmark Datasets
```
python test_benchmark.py

optional arguments:
--upscale_factor              super resolution upscale factor [default value is 4]
--model_name                  generator model epoch name [default value is netG_epoch_4_100.pth]
```
The output super resolution images are saved in the `benchmark_results` directory.

### Test Single Image
```
python test_image.py

optional arguments:
--upscale_factor              super resolution upscale factor [default value is 4]
--test_mode                   using GPU or CPU [default value is 'GPU'] (choices: ['GPU', 'CPU'])
--image_name                  test low resolution image name
--model_name                  generator model epoch name [default value is netG_epoch_4_100.pth]
```
The output super resolution image is saved in the same directory as the input.

### Test Single Video
```
python test_video.py

optional arguments:
--upscale_factor              super resolution upscale factor [default value is 4]
--video_name                  test low resolution video name
--model_name                  generator model epoch name [default value is netG_epoch_4_100.pth]
```
The output super resolution video and the comparison video are saved in the same directory as the input.

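Before training, the crop size is trimmed so that it divides evenly by the upscale factor; a minimal sketch of that adjustment, mirroring the `calculate_valid_crop_size` helper in `data_utils.py`:

```python
def calculate_valid_crop_size(crop_size, upscale_factor):
    # Largest size <= crop_size that is an exact multiple of the upscale factor,
    # so the low-resolution patch (crop_size // upscale_factor) loses no pixels.
    return crop_size - (crop_size % upscale_factor)

print(calculate_valid_crop_size(88, 4))   # 88 (already a multiple of 4)
print(calculate_valid_crop_size(100, 8))  # 96
```

With the default `--crop_size 88`, every choice of `--upscale_factor` in [2, 4, 8] divides 88 evenly except 8 does too (88 = 11 × 8), so the default crop is used unchanged.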
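The PSNR values embedded in the output filenames are computed from the mean squared error between the super-resolved and ground-truth images, with pixel values in [0, 1], as in `test_benchmark.py`; a minimal sketch:

```python
from math import log10

def psnr(mse, max_value=1.0):
    # Peak signal-to-noise ratio in dB for images with pixels in [0, max_value].
    return 10 * log10(max_value ** 2 / mse)

print(psnr(0.01))  # 20.0 dB
```

Lower MSE means higher PSNR; an MSE of 1.0 (the worst case for [0, 1] images) gives 0 dB.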
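The script prefixes the input filename with the upscale factor, which is why the result lands next to the input; a sketch of the naming scheme used in `test_image.py` (the image name below is a hypothetical example):

```python
def output_name(image_name, upscale_factor=4):
    # test_image.py saves the result as 'out_srf_<factor>_<input name>'.
    return 'out_srf_' + str(upscale_factor) + '_' + image_name

print(output_name('lena.jpg'))  # out_srf_4_lena.jpg
```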
## Benchmarks
**Upscale Factor = 2**

One epoch with a batch size of 64 takes about 2 minutes 30 seconds on an NVIDIA GTX 1080Ti GPU.

> Image Results

The left image is the bicubic interpolation, the middle is the high resolution original, and
the right is the super resolution image (output of SRGAN).

- BSD100_070 (PSNR: 32.4517; SSIM: 0.9191)



- Set14_005 (PSNR: 26.9171; SSIM: 0.9119)



- Set14_013 (PSNR: 30.8040; SSIM: 0.9651)



- Urban100_098 (PSNR: 24.3765; SSIM: 0.7855)



> Video Results

The left is the bicubic interpolation video, the right is the super resolution video (output of SRGAN).

[](https://youtu.be/05vx-vOJOZs)

**Upscale Factor = 4**

One epoch with a batch size of 64 takes about 4 minutes 30 seconds on an NVIDIA GTX 1080Ti GPU.

> Image Results

The left image is the bicubic interpolation, the middle is the high resolution original, and
the right is the super resolution image (output of SRGAN).

- BSD100_035 (PSNR: 32.3980; SSIM: 0.8512)



- Set14_011 (PSNR: 29.5944; SSIM: 0.9044)



- Set14_014 (PSNR: 25.1299; SSIM: 0.7406)



- Urban100_060 (PSNR: 20.7129; SSIM: 0.5263)



> Video Results

The left is the bicubic interpolation video, the right is the super resolution video (output of SRGAN).

[](https://youtu.be/tNR2eiMeoQs)

**Upscale Factor = 8**

One epoch with a batch size of 64 takes about 3 minutes 30 seconds on an NVIDIA GTX 1080Ti GPU.

> Image Results

The left image is the bicubic interpolation, the middle is the high resolution original, and
the right is the super resolution image (output of SRGAN).

- SunHays80_027 (PSNR: 29.4941; SSIM: 0.8082)



- SunHays80_035 (PSNR: 32.1546; SSIM: 0.8449)



- SunHays80_043 (PSNR: 30.9716; SSIM: 0.8789)



- SunHays80_078 (PSNR: 31.9351; SSIM: 0.8381)



> Video Results

The left is the bicubic interpolation video, the right is the super resolution video (output of SRGAN).

[](https://youtu.be/EuvXTKCRr8I)

The complete test results can be downloaded from [here](https://pan.baidu.com/s/1pAaEgAQ4aRZbtKo8hNrb4Q) (access code: wnvn).
## Introduction

Add information here about the project's features, usage scenarios, and input/output parameters.

You can describe the function, usage and parameters of the project.
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Project Introduction\n",
    "\n",
    " - A project is a program built from modules that serves a specific function. It can meet users' direct needs, for example a [classical poetry generator](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1) or [style transfer](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1).\n",
    " - While developing a project you can import datasets, and you can insert [modules](https://momodel.cn/modules) and code snippets directly via the `<+>` in the toolbar above each cell.\n",
    " - You can [deploy](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89) a finished project; once deployment succeeds and an official version is released, the project appears on the \"Projects\" page, where users can run it online or call it through the API.\n",
    "\n",
    " - Project directory structure:\n",
    "\n",
    "   - ```results```*----- where result files are stored (if you run a job, make sure its output goes to this directory)*\n",
    "   - ```_OVERVIEW.md``` *----- overview of the project*\n",
    "   - ```_README.md```*----- documentation*\n",
    "   - ```app_spec.yml```*----- defines the project's inputs and outputs for deployment*\n",
    "   - ```coding_here.ipynb```*----- write and run your code here*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## 2. Development Environment\n",
    "\n",
    "The Notebook page you are currently on is an online IDE-like environment with JupyterLab embedded; during development you can use the API docs on the right side of the page for quick lookups. The Notebook provides the following main features:\n",
    "\n",
    "- [Using dataset, module and code-snippet resources](https://momodel.cn/docs/#/zh-cn/%E5%A6%82%E4%BD%95%E5%AF%BC%E5%85%A5%E5%B9%B6%E4%BD%BF%E7%94%A8%E6%A8%A1%E5%9D%97%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86)\n",
    "- [Collaborative coding](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8Mo%E8%BF%90%E8%A1%8C%E4%BD%A0%E7%9A%84%E7%AC%AC%E4%B8%80%E6%AE%B5%E4%BB%A3%E7%A0%81?id=_7-%e4%bd%a0%e5%8f%af%e4%bb%a5%e9%82%80%e8%af%b7%e5%a5%bd%e5%8f%8b%e8%bf%9b%e8%a1%8c%e5%8d%8f%e4%bd%9c)\n",
    "- [Training machine learning models on GPU resources](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8GPU%E6%88%96CPU%E8%B5%84%E6%BA%90%E4%B8%8A%E8%AE%AD%E7%BB%83%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E6%A8%A1%E5%9E%8B)\n",
    "- [Simple deployment](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89)\n",
    "\n",
    "Give it a try! Click the new-file icon in the left toolbar to choose the file type you need.\n",
    "\n",
    "<img src='https://imgbed.momodel.cn/006tNc79gy1g61agfcv23j31c30u0789.jpg' width=100% height=100%>\n",
    "\n",
    "\n",
    "\n",
    "Both the left and right toolbars can be collapsed as needed.\n",
    "<img src='https://imgbed.momodel.cn/collapse_tab.2019-09-06 11_07_44.gif' width=100% height=100%>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Shortcuts and Code Completion\n",
    "Mo Notebook fully adopts Jupyter Notebook's native keyboard shortcuts and supports `tab` code completion.\n",
    "\n",
    "Run a cell: `shift` + `enter` or `shift` + `return`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Common Commands\n",
    "\n",
    "- Unzip an uploaded file\n",
    "\n",
    "Enter and run the following command in a cell:\n",
    "```!unzip -o file_name.zip```\n",
    "\n",
    "- List all installed packages\n",
    "\n",
    "`!pip list --format=columns`\n",
    "\n",
    "- Check whether a package is installed\n",
    "\n",
    "`!pip show package_name`\n",
    "\n",
    "- Install a missing package\n",
    "\n",
    "`!pip install package_name`\n",
    "\n",
    "- Upgrade an installed package\n",
    "\n",
    "`!pip install package_name --upgrade`\n",
    "\n",
    "\n",
    "- Use a package\n",
    "\n",
    "`import package_name`\n",
    "\n",
    "- List the files and directories in the current directory\n",
    "\n",
    "`ls`\n",
    "\n",
    "- Use an imported dataset\n",
    "\n",
    "Imported datasets are stored under the datasets folder. Note that this folder is read-only. If you need to modify its contents, run\n",
    "\n",
    "`!cp -R ./datasets/<imported_dataset_dir> ./<your_folder>`\n",
    "\n",
    "in the Notebook to copy them to another folder first, then edit the copy. For zip files in an imported dataset, extract them to another folder with\n",
    "\n",
    "`!unzip ./datasets/<imported_dataset_dir>/<XXX.zip> -d ./<your_folder>`\n",
    "\n",
    "before using them."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Other Resources\n",
    "- [Documentation](https://momodel.cn/docs/#/): basic page introductions and answers to common questions\n",
    "- [Platform tutorials](https://momodel.cn/classroom/class?id=5c5696cd1afd9458d456bf54&type=doc): illustrated Notebooks that walk through the environment's basic features and operations\n",
    "- [Andrew Ng's Machine Learning](https://momodel.cn/classroom/class?id=5c5696191afd94720cc94533&type=video): a classic machine learning course\n",
    "- [Hung-yi Lee's Machine Learning](https://s.momodel.cn/classroom/class?id=5d40fdafb5113408a8dbb4a1&type=video): a leading Chinese-language machine learning course\n",
    "- [Machine Learning in Practice](https://momodel.cn/classroom/class?id=5c680b311afd943a9f70901b&type=practice): hands-on guides to building standalone models and mastering the corresponding machine learning concepts\n",
    "- [Python Tutorial](https://momodel.cn/classroom/class?id=5d1f3ab81afd940ab7d298bf&type=notebook): a simple, accessible Python tutorial for beginners\n",
    "- [Module development](https://momodel.cn/modules): advanced tutorials on model training, development and deployment"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Hello Mo!')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])


def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)


def train_hr_transform(crop_size):
    return Compose([
        RandomCrop(crop_size),
        ToTensor(),
    ])


def train_lr_transform(crop_size, upscale_factor):
    return Compose([
        ToPILImage(),
        Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),
        ToTensor()
    ])


def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor()
    ])


class TrainDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)


class ValDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

    def __getitem__(self, index):
        hr_image = Image.open(self.image_filenames[index])
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)


class TestDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]

    def __getitem__(self, index):
        image_name = self.lr_filenames[index].split('/')[-1]
        lr_image = Image.open(self.lr_filenames[index])
        w, h = lr_image.size
        hr_image = Image.open(self.hr_filenames[index])
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.lr_filenames)
import torch
from torch import nn
from torchvision.models.vgg import vgg16


class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial Loss
        adversarial_loss = torch.mean(1 - out_labels)
        # Perception Loss
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Image Loss
        image_loss = self.mse_loss(out_images, target_images)
        # TV Loss
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss


class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]


if __name__ == "__main__":
    g_loss = GeneratorLoss()
    print(g_loss)
import math
import torch
from torch import nn


class Generator(nn.Module):
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        block8 = self.block8(block1 + block7)

        return (torch.tanh(block8) + 1) / 2


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        batch_size = x.size(0)
        return torch.sigmoid(self.net(x).view(batch_size))


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        return x + residual


class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x
from math import exp

import torch
import torch.nn.functional as F
from torch.autograd import Variable


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
import argparse
import os
from math import log10

import numpy as np
import pandas as pd
import torch
import torchvision.utils as utils
from torch.utils.data import DataLoader
from tqdm import tqdm

import pytorch_ssim
from data_utils import TestDatasetFromFolder, display_transform
from model import Generator

parser = argparse.ArgumentParser(description='Test Benchmark Datasets')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
opt = parser.parse_args()

UPSCALE_FACTOR = opt.upscale_factor
MODEL_NAME = opt.model_name

results = {'Set5': {'psnr': [], 'ssim': []}, 'Set14': {'psnr': [], 'ssim': []}, 'BSD100': {'psnr': [], 'ssim': []},
           'Urban100': {'psnr': [], 'ssim': []}, 'SunHays80': {'psnr': [], 'ssim': []}}

model = Generator(UPSCALE_FACTOR).eval()
if torch.cuda.is_available():
    model = model.cuda()
model.load_state_dict(torch.load('epochs/' + MODEL_NAME))

test_set = TestDatasetFromFolder('data/test', upscale_factor=UPSCALE_FACTOR)
test_loader = DataLoader(dataset=test_set, num_workers=4, batch_size=1, shuffle=False)
test_bar = tqdm(test_loader, desc='[testing benchmark datasets]')

out_path = 'benchmark_results/SRF_' + str(UPSCALE_FACTOR) + '/'
if not os.path.exists(out_path):
    os.makedirs(out_path)

for image_name, lr_image, hr_restore_img, hr_image in test_bar:
    image_name = image_name[0]
    # torch.no_grad() replaces the deprecated Variable(..., volatile=True) API
    with torch.no_grad():
        if torch.cuda.is_available():
            lr_image = lr_image.cuda()
            hr_image = hr_image.cuda()

        sr_image = model(lr_image)
        mse = ((hr_image - sr_image) ** 2).mean().item()
        psnr = 10 * log10(1 / mse)
        # .item() replaces the deprecated .data[0] indexing of 0-dim tensors
        ssim = pytorch_ssim.ssim(sr_image, hr_image).item()

    test_images = torch.stack(
        [display_transform()(hr_restore_img.squeeze(0)), display_transform()(hr_image.cpu().squeeze(0)),
         display_transform()(sr_image.cpu().squeeze(0))])
    image = utils.make_grid(test_images, nrow=3, padding=5)
    utils.save_image(image, out_path + image_name.split('.')[0] + '_psnr_%.4f_ssim_%.4f.' % (psnr, ssim) +
                     image_name.split('.')[-1], padding=5)

    # save psnr\ssim
    results[image_name.split('_')[0]]['psnr'].append(psnr)
    results[image_name.split('_')[0]]['ssim'].append(ssim)

out_path = 'statistics/'
saved_results = {'psnr': [], 'ssim': []}
for item in results.values():
    psnr = np.array(item['psnr'])
    ssim = np.array(item['ssim'])
    if (len(psnr) == 0) or (len(ssim) == 0):
        psnr = 'No data'
        ssim = 'No data'
    else:
        psnr = psnr.mean()
        ssim = ssim.mean()
    saved_results['psnr'].append(psnr)
    saved_results['ssim'].append(ssim)

data_frame = pd.DataFrame(saved_results, results.keys())
data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_test_results.csv', index_label='DataSet')
| 0 | import argparse | |
| 1 | import time | |
| 2 | ||
| 3 | import torch | |
| 4 | from PIL import Image | |
| 5 | from torch.autograd import Variable | |
| 6 | from torchvision.transforms import ToTensor, ToPILImage | |
| 7 | ||
| 8 | from model import Generator | |
| 9 | ||
| 10 | parser = argparse.ArgumentParser(description='Test Single Image') | |
| 11 | parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor') | |
| 12 | parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU') | |
| 13 | parser.add_argument('--image_name', type=str, help='test low resolution image name') | |
| 14 | parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name') | |
| 15 | opt = parser.parse_args() | |
| 16 | ||
| 17 | UPSCALE_FACTOR = opt.upscale_factor | |
| 18 | TEST_MODE = opt.test_mode == 'GPU' | |
| 19 | IMAGE_NAME = opt.image_name | |
| 20 | MODEL_NAME = opt.model_name | |
| 21 | ||
| 22 | model = Generator(UPSCALE_FACTOR).eval() | |
| 23 | if TEST_MODE: | |
| 24 | model.cuda() | |
| 25 | model.load_state_dict(torch.load('epochs/' + MODEL_NAME)) | |
| 26 | else: | |
| 27 | model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage)) | |
| 28 | ||
| 29 | image = Image.open(IMAGE_NAME) | |
| 30 | image = ToTensor()(image).unsqueeze(0) | |
| 31 | if TEST_MODE: | |
| 32 | image = image.cuda() | |
| 33 | ||
| 34 | start = time.perf_counter() | |
| 35 | with torch.no_grad(): | |
| 36 | out = model(image) | |
| 37 | print('cost %.4fs' % (time.perf_counter() - start)) | |
| 38 | out_img = ToPILImage()(out[0].data.cpu()) | |
| 39 | out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME) |
| 0 | import argparse | |
| 1 | ||
| 2 | import cv2 | |
| 3 | import numpy as np | |
| 4 | import torch | |
| 5 | import torchvision.transforms as transforms | |
| 6 | from PIL import Image | |
| 7 | from torch.autograd import Variable | |
| 8 | from torchvision.transforms import ToTensor, ToPILImage | |
| 9 | from tqdm import tqdm | |
| 10 | ||
| 11 | from model import Generator | |
| 12 | ||
| 13 | if __name__ == "__main__": | |
| 14 | parser = argparse.ArgumentParser(description='Test Single Video') | |
| 15 | parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor') | |
| 16 | parser.add_argument('--video_name', type=str, help='test low resolution video name') | |
| 17 | parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name') | |
| 18 | opt = parser.parse_args() | |
| 19 | ||
| 20 | UPSCALE_FACTOR = opt.upscale_factor | |
| 21 | VIDEO_NAME = opt.video_name | |
| 22 | MODEL_NAME = opt.model_name | |
| 23 | ||
| 24 | model = Generator(UPSCALE_FACTOR).eval() | |
| 25 | if torch.cuda.is_available(): | |
| 26 | model = model.cuda() | |
| 27 | # map_location makes a GPU-trained checkpoint loadable on a CPU-only machine | |
| 28 | model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage)) | |
| 30 | ||
| 31 | videoCapture = cv2.VideoCapture(VIDEO_NAME) | |
| 32 | fps = videoCapture.get(cv2.CAP_PROP_FPS) | |
| 33 | frame_numbers = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) | |
| 34 | sr_width = int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) | |
| 35 | sr_height = int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR | |
| 36 | sr_video_size = (sr_width, sr_height) | |
| 37 | # compared video: SR and bicubic frames side by side (10 px gap) plus a strip of five crop shots | |
| 38 | crop_size = sr_width // 5 - 9 | |
| 39 | compared_width = sr_width * 2 + 10 | |
| 40 | compared_height = sr_height + 10 + int(compared_width / (10 * (sr_width // 5 + 1)) * crop_size) | |
| 41 | compared_video_size = (compared_width, compared_height) | |
| 42 | output_sr_name = 'out_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi' | |
| 43 | output_compared_name = 'compare_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi' | |
| 44 | sr_video_writer = cv2.VideoWriter(output_sr_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps, sr_video_size) | |
| 45 | compared_video_writer = cv2.VideoWriter(output_compared_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps, | |
| 46 | compared_video_size) | |
| 47 | # read frame | |
| 48 | success, frame = videoCapture.read() | |
| 49 | test_bar = tqdm(range(int(frame_numbers)), desc='[processing video and saving result videos]') | |
| 50 | for index in test_bar: | |
| 51 | if success: | |
| 52 | image = ToTensor()(frame).unsqueeze(0) | |
| 53 | if torch.cuda.is_available(): | |
| 54 | image = image.cuda() | |
| 55 | ||
| 56 | with torch.no_grad(): | |
| 57 | out = model(image).cpu() | |
| 58 | out_img = out.data[0].numpy() | |
| 59 | # clip before casting so out-of-range generator outputs do not wrap around | |
| 60 | out_img = np.uint8(out_img.clip(0, 1) * 255.0).transpose((1, 2, 0)) | |
| 61 | # save sr video | |
| 62 | sr_video_writer.write(out_img) | |
| 63 | ||
| 64 | # make compared video and crop shots of the top-left, top-right, center, bottom-left and bottom-right regions | |
| 65 | out_img = ToPILImage()(out_img) | |
| 66 | crop_out_imgs = transforms.FiveCrop(size=out_img.width // 5 - 9)(out_img) | |
| 67 | crop_out_imgs = [np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(img)) for img in crop_out_imgs] | |
| 68 | out_img = transforms.Pad(padding=(5, 0, 0, 5))(out_img) | |
| 69 | compared_img = transforms.Resize(size=(sr_video_size[1], sr_video_size[0]), interpolation=Image.BICUBIC)( | |
| 70 | ToPILImage()(frame)) | |
| 71 | crop_compared_imgs = transforms.FiveCrop(size=compared_img.width // 5 - 9)(compared_img) | |
| 72 | crop_compared_imgs = [np.asarray(transforms.Pad(padding=(0, 5, 10, 0))(img)) for img in crop_compared_imgs] | |
| 73 | compared_img = transforms.Pad(padding=(0, 0, 5, 5))(compared_img) | |
| 74 | # concatenate all the pictures to one single picture | |
| 75 | top_image = np.concatenate((np.asarray(compared_img), np.asarray(out_img)), axis=1) | |
| 76 | bottom_image = np.concatenate(crop_compared_imgs + crop_out_imgs, axis=1) | |
| 77 | bottom_image = np.asarray(transforms.Resize( | |
| 78 | size=(int(top_image.shape[1] / bottom_image.shape[1] * bottom_image.shape[0]), top_image.shape[1]))( | |
| 79 | ToPILImage()(bottom_image))) | |
| 80 | final_image = np.concatenate((top_image, bottom_image)) | |
| 81 | # save compared video | |
| 82 | compared_video_writer.write(final_image) | |
| 83 | # next frame | |
| 84 | success, frame = videoCapture.read() |
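The frame conversion in the loop above scales the network output by 255 and casts to `uint8` for `cv2.VideoWriter`. A hedged standalone sketch of that step with explicit clipping (the helper `tensor_to_frame` is illustrative, not part of this repo); clipping matters because a raw cast wraps around for values outside [0, 1]:

```python
import numpy as np

def tensor_to_frame(chw):
    """Convert a CHW float image in [0, 1] to an HWC uint8 frame for cv2."""
    # Clip first: casting 1.2 * 255 = 306.0 straight to uint8 would wrap to 50
    hwc = np.clip(chw, 0.0, 1.0) * 255.0
    return hwc.astype(np.uint8).transpose((1, 2, 0))
```

Note that OpenCV frames are BGR; the pipeline above stays in BGR end to end, so no channel swap is needed here.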
| 0 | import argparse | |
| 1 | import os | |
| 2 | from math import log10 | |
| 3 | ||
| 4 | import pandas as pd | |
| 5 | import torch.optim as optim | |
| 6 | import torch.utils.data | |
| 7 | import torchvision.utils as utils | |
| 8 | from torch.autograd import Variable | |
| 9 | from torch.utils.data import DataLoader | |
| 10 | from tqdm import tqdm | |
| 11 | ||
| 12 | import pytorch_ssim | |
| 13 | from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform | |
| 14 | from loss import GeneratorLoss | |
| 15 | from model import Generator, Discriminator | |
| 16 | ||
| 17 | parser = argparse.ArgumentParser(description='Train Super Resolution Models') | |
| 18 | parser.add_argument('--crop_size', default=88, type=int, help='training images crop size') | |
| 19 | parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8], | |
| 20 | help='super resolution upscale factor') | |
| 21 | parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number') | |
| 22 | ||
| 23 | ||
| 24 | if __name__ == '__main__': | |
| 25 | opt = parser.parse_args() | |
| 26 | ||
| 27 | CROP_SIZE = opt.crop_size | |
| 28 | UPSCALE_FACTOR = opt.upscale_factor | |
| 29 | NUM_EPOCHS = opt.num_epochs | |
| 30 | ||
| 31 | train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR) | |
| 32 | val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR) | |
| 33 | train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True) | |
| 34 | val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False) | |
| 35 | ||
| 36 | netG = Generator(UPSCALE_FACTOR) | |
| 37 | print('# generator parameters:', sum(param.numel() for param in netG.parameters())) | |
| 38 | netD = Discriminator() | |
| 39 | print('# discriminator parameters:', sum(param.numel() for param in netD.parameters())) | |
| 40 | ||
| 41 | generator_criterion = GeneratorLoss() | |
| 42 | ||
| 43 | if torch.cuda.is_available(): | |
| 44 | netG.cuda() | |
| 45 | netD.cuda() | |
| 46 | generator_criterion.cuda() | |
| 47 | ||
| 48 | optimizerG = optim.Adam(netG.parameters()) | |
| 49 | optimizerD = optim.Adam(netD.parameters()) | |
| 50 | ||
| 51 | results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []} | |
| 52 | ||
| 53 | for epoch in range(1, NUM_EPOCHS + 1): | |
| 54 | train_bar = tqdm(train_loader) | |
| 55 | running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0} | |
| 56 | ||
| 57 | netG.train() | |
| 58 | netD.train() | |
| 59 | for data, target in train_bar: | |
| 61 | batch_size = data.size(0) | |
| 62 | running_results['batch_sizes'] += batch_size | |
| 63 | ||
| 64 | ############################ | |
| 65 | # (1) Update D network: maximize D(x) - D(G(z)), i.e. minimize 1 - D(x) + D(G(z)) | |
| 66 | ########################### | |
| 67 | real_img = Variable(target) | |
| 68 | if torch.cuda.is_available(): | |
| 69 | real_img = real_img.cuda() | |
| 70 | z = Variable(data) | |
| 71 | if torch.cuda.is_available(): | |
| 72 | z = z.cuda() | |
| 73 | fake_img = netG(z) | |
| 74 | ||
| 75 | netD.zero_grad() | |
| 76 | real_out = netD(real_img).mean() | |
| 77 | fake_out = netD(fake_img).mean() | |
| 78 | d_loss = 1 - real_out + fake_out | |
| 79 | d_loss.backward(retain_graph=True) | |
| 80 | optimizerD.step() | |
| 81 | ||
| 82 | ############################ | |
| 83 | # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss | |
| 84 | ########################### | |
| 85 | netG.zero_grad() | |
| 86 | g_loss = generator_criterion(fake_out, fake_img, real_img) | |
| 87 | g_loss.backward() | |
| 88 | ||
| 89 | fake_img = netG(z) | |
| 90 | fake_out = netD(fake_img).mean() | |
| 91 | ||
| 93 | optimizerG.step() | |
| 94 | ||
| 95 | # loss for current batch before optimization | |
| 96 | running_results['g_loss'] += g_loss.item() * batch_size | |
| 97 | running_results['d_loss'] += d_loss.item() * batch_size | |
| 98 | running_results['d_score'] += real_out.item() * batch_size | |
| 99 | running_results['g_score'] += fake_out.item() * batch_size | |
| 100 | ||
| 101 | train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % ( | |
| 102 | epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'], | |
| 103 | running_results['g_loss'] / running_results['batch_sizes'], | |
| 104 | running_results['d_score'] / running_results['batch_sizes'], | |
| 105 | running_results['g_score'] / running_results['batch_sizes'])) | |
| 106 | ||
| 107 | netG.eval() | |
| 108 | out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/' | |
| 109 | if not os.path.exists(out_path): | |
| 110 | os.makedirs(out_path) | |
| 111 | ||
| 112 | with torch.no_grad(): | |
| 113 | val_bar = tqdm(val_loader) | |
| 114 | valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0} | |
| 115 | val_images = [] | |
| 116 | for val_lr, val_hr_restore, val_hr in val_bar: | |
| 117 | batch_size = val_lr.size(0) | |
| 118 | valing_results['batch_sizes'] += batch_size | |
| 119 | lr = val_lr | |
| 120 | hr = val_hr | |
| 121 | if torch.cuda.is_available(): | |
| 122 | lr = lr.cuda() | |
| 123 | hr = hr.cuda() | |
| 124 | sr = netG(lr) | |
| 125 | ||
| 126 | batch_mse = ((sr - hr) ** 2).data.mean() | |
| 127 | valing_results['mse'] += batch_mse * batch_size | |
| 128 | batch_ssim = pytorch_ssim.ssim(sr, hr).item() | |
| 129 | valing_results['ssims'] += batch_ssim * batch_size | |
| 130 | valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes'])) | |
| 131 | valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes'] | |
| 132 | val_bar.set_description( | |
| 133 | desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % ( | |
| 134 | valing_results['psnr'], valing_results['ssim'])) | |
| 135 | ||
| 136 | val_images.extend( | |
| 137 | [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)), | |
| 138 | display_transform()(sr.data.cpu().squeeze(0))]) | |
| 139 | val_images = torch.stack(val_images) | |
| 140 | val_images = torch.chunk(val_images, max(val_images.size(0) // 15, 1)) | |
| 141 | val_save_bar = tqdm(val_images, desc='[saving training results]') | |
| 142 | index = 1 | |
| 143 | for image in val_save_bar: | |
| 144 | image = utils.make_grid(image, nrow=3, padding=5) | |
| 145 | utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5) | |
| 146 | index += 1 | |
| 147 | ||
| 148 | # save model parameters | |
| 149 | if not os.path.exists('epochs'): | |
| 150 | os.makedirs('epochs') | |
| 151 | torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch)) | |
| 152 | torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch)) | |
| 151 | # save loss, scores, psnr and ssim | |
| 152 | results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes']) | |
| 153 | results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes']) | |
| 154 | results['d_score'].append(running_results['d_score'] / running_results['batch_sizes']) | |
| 155 | results['g_score'].append(running_results['g_score'] / running_results['batch_sizes']) | |
| 156 | results['psnr'].append(valing_results['psnr']) | |
| 157 | results['ssim'].append(valing_results['ssim']) | |
| 158 | ||
| 159 | if epoch % 10 == 0: | |
| 160 | out_path = 'statistics/' | |
| 161 | if not os.path.exists(out_path): | |
| 162 | os.makedirs(out_path) | |
| 161 | data_frame = pd.DataFrame( | |
| 162 | data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'], | |
| 163 | 'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']}, | |
| 164 | index=range(1, epoch + 1)) | |
| 165 | data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch') |
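The progress-bar statistics in the training loop accumulate `value * batch_size` in `running_results` and divide by the running sample count, so a partially filled final batch is weighted correctly instead of counting as a full batch. That bookkeeping can be distilled into a small helper (a sketch; the name `update_running_mean` is ours, not part of this repo):

```python
def update_running_mean(totals, value, batch_size):
    """Accumulate a batch mean weighted by batch size; return the global mean."""
    # totals['sum'] holds the sample-weighted sum, totals['count'] the
    # number of samples seen, mirroring running_results['batch_sizes'] above.
    totals['sum'] += value * batch_size
    totals['count'] += batch_size
    return totals['sum'] / totals['count']
```

For instance, batch means of 2.0 (3 samples) and 4.0 (1 sample) average to 2.5, not the naive 3.0.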