a1556b8
Initial Commit lzfxxx 6 years ago
35 changed file(s) with 1093 addition(s) and 0 deletion(s).
0 .idea/
1 *.pyc
2 *.swp
3 .DS_Store
4 /.localenv/
5 /datasets/
6 /.ipynb_checkpoints/
0 # SRGAN
1 A PyTorch implementation of SRGAN based on CVPR 2017 paper
2 [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802).
3
4 ## Requirements
5 - [Anaconda](https://www.anaconda.com/download/)
6 - PyTorch
7 ```
8 conda install pytorch torchvision -c pytorch
9 ```
10 - opencv
11 ```
12 conda install opencv
13 ```
14
15 ## Datasets
16
17 ### Train & Val Datasets
18 The train and val datasets are sampled from [VOC2012](http://cvlab.postech.ac.kr/~mooyeol/pascal_voc_2012/).
19 The train dataset contains 16,700 images and the val dataset contains 425 images.
20 Download the datasets from [here](https://pan.baidu.com/s/1xuFperu2WiYc5-_QXBemlA) (access code: 5tzp), then extract them into the `data` directory.
21
22 ### Test Image Dataset
23 The test image datasets are sampled from:
24 | **Set 5** | [Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html) |
25 | **Set 14** | [Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests) |
26 | **BSD 100** | [Martin et al. ICCV 2001](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/) |
27 | **Sun-Hays 80** | [Sun and Hays ICCP 2012](http://cs.brown.edu/~lbsun/SRproj2012/SR_iccp2012.html) |
28 | **Urban 100** | [Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr) |
29 Download the image datasets from [here](https://pan.baidu.com/s/1vGosnyal21wGgVffriL1VQ) (access code: xwhy), then extract them into the `data` directory.
30
31 ### Test Video Dataset
32 The test video dataset consists of three trailers. Download them from
33 [here](https://pan.baidu.com/s/1NUZKm5xCHRj1O0JlCZIu8Q) (access code: zabi).
34
35 ## Usage
36
37 ### Train
38 ```
39 python train.py
40
41 optional arguments:
42 --crop_size training images crop size [default value is 88]
43 --upscale_factor super resolution upscale factor [default value is 4](choices:[2, 4, 8])
44 --num_epochs train epoch number [default value is 100]
45 ```
46 The output val super resolution images are saved in the `training_results` directory.
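Note that `--crop_size` is trimmed internally so that it divides evenly by `--upscale_factor`; `data_utils.py` does this with `calculate_valid_crop_size`. A minimal sketch of that adjustment:

```python
def calculate_valid_crop_size(crop_size, upscale_factor):
    # Trim the crop size down to the nearest multiple of the upscale
    # factor, mirroring calculate_valid_crop_size in data_utils.py.
    return crop_size - (crop_size % upscale_factor)

print(calculate_valid_crop_size(88, 4))  # 88 (already a multiple of 4)
print(calculate_valid_crop_size(90, 8))  # 88 (trimmed from 90)
```

This guarantees that a low-resolution patch of `crop_size // upscale_factor` pixels upscales back to exactly `crop_size` pixels.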
47
48 ### Test Benchmark Datasets
49 ```
50 python test_benchmark.py
51
52 optional arguments:
53 --upscale_factor super resolution upscale factor [default value is 4]
54 --model_name generator model epoch name [default value is netG_epoch_4_100.pth]
55 ```
56 The output super resolution images are saved in the `benchmark_results` directory.
57
58 ### Test Single Image
59 ```
60 python test_image.py
61
62 optional arguments:
63 --upscale_factor super resolution upscale factor [default value is 4]
64 --test_mode using GPU or CPU [default value is 'GPU'](choices:['GPU', 'CPU'])
65 --image_name test low resolution image name
66 --model_name generator model epoch name [default value is netG_epoch_4_100.pth]
67 ```
68 The output super resolution image is saved in the same directory as the input image.
69
70 ### Test Single Video
71 ```
72 python test_video.py
73
74 optional arguments:
75 --upscale_factor super resolution upscale factor [default value is 4]
76 --video_name test low resolution video name
77 --model_name generator model epoch name [default value is netG_epoch_4_100.pth]
78 ```
79 The output super resolution video and the comparison video are saved in the same directory as the input video.
80
81 ## Benchmarks
82 **Upscale Factor = 2**
83
84 Each epoch with a batch size of 64 takes about 2 minutes 30 seconds on an NVIDIA GTX 1080Ti GPU.
85
86 > Image Results
87
88 The left is the bicubic interpolation image, the middle is the high resolution image, and
89 the right is the super resolution image (output of the SRGAN).
90
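The PSNR values reported below follow `test_benchmark.py`: pixel values are normalized to [0, 1], so PSNR = 10 * log10(1 / MSE). A minimal sketch of the metric:

```python
from math import log10

def psnr(mse):
    # Peak signal-to-noise ratio in dB for images whose pixel values
    # are normalized to [0, 1] (peak value 1), as in test_benchmark.py.
    return 10 * log10(1 / mse)

print(round(psnr(0.001), 6))  # 30.0 dB: lower MSE means higher PSNR
```

SSIM, computed by the bundled `pytorch_ssim` module, complements PSNR: it ranges up to 1, with 1 meaning the two images are structurally identical.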
91 - BSD100_070(PSNR:32.4517; SSIM:0.9191)
92
93 ![BSD100_070](images/1.png)
94
95 - Set14_005(PSNR:26.9171; SSIM:0.9119)
96
97 ![Set14_005](images/2.png)
98
99 - Set14_013(PSNR:30.8040; SSIM:0.9651)
100
101 ![Set14_013](images/3.png)
102
103 - Urban100_098(PSNR:24.3765; SSIM:0.7855)
104
105 ![Urban100_098](images/4.png)
106
107 > Video Results
108
109 The left is the bicubic interpolation video and the right is the super resolution video (output of the SRGAN).
110
111 [![Watch the video](images/video_SRF_2.png)](https://youtu.be/05vx-vOJOZs)
112
113 **Upscale Factor = 4**
114
115 Each epoch with a batch size of 64 takes about 4 minutes 30 seconds on an NVIDIA GTX 1080Ti GPU.
116
117 > Image Results
118
119 The left is the bicubic interpolation image, the middle is the high resolution image, and
120 the right is the super resolution image (output of the SRGAN).
121
122 - BSD100_035(PSNR:32.3980; SSIM:0.8512)
123
124 ![BSD100_035](images/5.png)
125
126 - Set14_011(PSNR:29.5944; SSIM:0.9044)
127
128 ![Set14_011](images/6.png)
129
130 - Set14_014(PSNR:25.1299; SSIM:0.7406)
131
132 ![Set14_014](images/7.png)
133
134 - Urban100_060(PSNR:20.7129; SSIM:0.5263)
135
136 ![Urban100_060](images/8.png)
137
138 > Video Results
139
140 The left is the bicubic interpolation video and the right is the super resolution video (output of the SRGAN).
141
142 [![Watch the video](images/video_SRF_4.png)](https://youtu.be/tNR2eiMeoQs)
143
144 **Upscale Factor = 8**
145
146 Each epoch with a batch size of 64 takes about 3 minutes 30 seconds on an NVIDIA GTX 1080Ti GPU.
147
148 > Image Results
149
150 The left is the bicubic interpolation image, the middle is the high resolution image, and
151 the right is the super resolution image (output of the SRGAN).
152
153 - SunHays80_027(PSNR:29.4941; SSIM:0.8082)
154
155 ![SunHays80_027](images/9.png)
156
157 - SunHays80_035(PSNR:32.1546; SSIM:0.8449)
158
159 ![SunHays80_035](images/10.png)
160
161 - SunHays80_043(PSNR:30.9716; SSIM:0.8789)
162
163 ![SunHays80_043](images/11.png)
164
165 - SunHays80_078(PSNR:31.9351; SSIM:0.8381)
166
167 ![SunHays80_078](images/12.png)
168
169 > Video Results
170
171 The left is the bicubic interpolation video and the right is the super resolution video (output of the SRGAN).
172
173 [![Watch the video](images/video_SRF_8.png)](https://youtu.be/EuvXTKCRr8I)
174
175 The complete test results can be downloaded from [here](https://pan.baidu.com/s/1pAaEgAQ4aRZbtKo8hNrb4Q) (access code: wnvn).
176
0 ## Introduction
1
2 Describe the project's function, usage scenarios, and input/output parameters here.
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "## 1. Project Introduction\n",
7 "\n",
8 " - A project is a program made up of modules that provides specific functionality. It can be used directly by end users, for example the [classical poetry generator](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1) or [style transfer](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1).\n",
9 " - While developing a project you can import datasets, and you can insert [modules](https://momodel.cn/modules) and code blocks directly via the `<+>` in the toolbar above each cell.\n",
10 " - You can [deploy](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89) a finished project; once it is deployed and released as an official version it appears on the 'Projects' page, where users can run it online or call it via an API.\n",
11 "\n",
12 " - Project directory structure:\n",
13 "\n",
14 " - ```results```*-----where result files are stored (if you run a job, make sure its results are written to this directory)*\n",
15 " - ```_OVERVIEW.md``` *-----project overview*\n",
16 " - ```_README.md```*-----documentation*\n",
17 " - ```app_spec.yml```*-----defines the project's inputs and outputs, used for deployment*\n",
18 " - ```coding_here.ipynb```*-----write and run your code here*"
19 ]
20 },
21 {
22 "cell_type": "markdown",
23 "metadata": {},
24 "source": [
25 "\n",
26 "## 2. Development Environment\n",
27 "\n",
28 "This page, Notebook, is an online IDE-like environment with JupyterLab embedded; while developing you can use the API docs on the right side of the page for quick lookups. Notebook's main features:\n",
29 "\n",
30 "- [Use dataset, module and code-block resources](https://momodel.cn/docs/#/zh-cn/%E5%A6%82%E4%BD%95%E5%AF%BC%E5%85%A5%E5%B9%B6%E4%BD%BF%E7%94%A8%E6%A8%A1%E5%9D%97%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86)\n",
31 "- [Collaborate on code with others](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8Mo%E8%BF%90%E8%A1%8C%E4%BD%A0%E7%9A%84%E7%AC%AC%E4%B8%80%E6%AE%B5%E4%BB%A3%E7%A0%81?id=_7-%e4%bd%a0%e5%8f%af%e4%bb%a5%e9%82%80%e8%af%b7%e5%a5%bd%e5%8f%8b%e8%bf%9b%e8%a1%8c%e5%8d%8f%e4%bd%9c)\n",
32 "- [Train machine learning models on GPU resources](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8GPU%E6%88%96CPU%E8%B5%84%E6%BA%90%E4%B8%8A%E8%AE%AD%E7%BB%83%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E6%A8%A1%E5%9E%8B)\n",
33 "- [Simple deployment](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89)\n",
34 "\n",
35 "Give it a try! Click the new-file icon in the left toolbar to choose the file type you need.\n",
36 "\n",
37 "<img src='https://imgbed.momodel.cn/006tNc79gy1g61agfcv23j31c30u0789.jpg' width=100% height=100%>\n",
38 "\n",
39 "\n",
40 "\n",
41 "Both the left and right toolbars can be collapsed as needed.\n",
42 "<img src='https://imgbed.momodel.cn/collapse_tab.2019-09-06 11_07_44.gif' width=100% height=100%>"
43 ]
44 },
45 {
46 "cell_type": "markdown",
47 "metadata": {},
48 "source": [
49 "## 3. Shortcuts and Code Completion\n",
50 "Mo Notebook fully adopts Jupyter Notebook's native keyboard shortcuts and supports `tab` code completion.\n",
51 "\n",
52 "Run a cell: `shift` + `enter` or `shift` + `return`"
53 ]
54 },
55 {
56 "cell_type": "markdown",
57 "metadata": {},
58 "source": [
59 "## 4. Common Commands\n",
60 "\n",
61 "- Unzip an uploaded file\n",
62 "\n",
63 "Enter and run the following command in a cell:\n",
64 "```!unzip -o file_name.zip```\n",
65 "\n",
66 "- List all installed packages\n",
67 "\n",
68 "`!pip list --format=columns`\n",
69 "\n",
70 "- Check whether a package is installed\n",
71 "\n",
72 "`!pip show package_name`\n",
73 "\n",
74 "- Install a missing package\n",
75 "\n",
76 "`!pip install package_name`\n",
77 "\n",
78 "- Upgrade an installed package\n",
79 "\n",
80 "`!pip install package_name --upgrade`\n",
81 "\n",
82 "\n",
83 "- Import a package\n",
84 "\n",
85 "`import package_name`\n",
86 "\n",
87 "- List files and directories in the current directory\n",
88 "\n",
89 "`ls`\n",
90 "\n",
91 "- Use an imported dataset\n",
92 "\n",
93 "Imported datasets are stored under the datasets folder. Note that this folder is read-only. If you need to modify its contents, run\n",
94 "\n",
95 "`!cp -R ./datasets/<imported_dataset_dir> ./<your_folder>`\n",
96 "\n",
97 "in the Notebook to copy them to another folder before editing. For zip files in an imported dataset, run\n",
98 "\n",
99 "`!unzip ./datasets/<imported_dataset_dir>/<XXX.zip> -d ./<your_folder>`\n",
100 "\n",
101 "to extract them to another folder before use."
102 ]
103 },
104 {
105 "cell_type": "markdown",
106 "metadata": {},
107 "source": [
108 "## 5. Other Resources\n",
109 "- [Documentation](https://momodel.cn/docs/#/): basic page introductions and answers to common questions\n",
110 "- [Platform tutorials](https://momodel.cn/classroom/class?id=5c5696cd1afd9458d456bf54&type=doc): illustrated Notebooks covering the development environment's basic features and operations\n",
111 "- [Andrew Ng's Machine Learning](https://momodel.cn/classroom/class?id=5c5696191afd94720cc94533&type=video): a classic machine learning course\n",
112 "- [Hung-yi Lee's Machine Learning](https://s.momodel.cn/classroom/class?id=5d40fdafb5113408a8dbb4a1&type=video): the best machine learning course in the Chinese-speaking world\n",
113 "- [Machine Learning in Action](https://momodel.cn/classroom/class?id=5c680b311afd943a9f70901b&type=practice): hands-on walkthroughs for building standalone models while learning the corresponding machine learning concepts\n",
114 "- [Python tutorial](https://momodel.cn/classroom/class?id=5d1f3ab81afd940ab7d298bf&type=notebook): an easy-to-follow Python tutorial for beginners\n",
115 "- [Module development](https://momodel.cn/modules): advanced tutorials on model training, development and deployment"
116 ]
117 }
118 ],
119 "metadata": {
120 "kernelspec": {
121 "display_name": "Python 3",
122 "language": "python",
123 "name": "python3"
124 },
125 "language_info": {
126 "codemirror_mode": {
127 "name": "ipython",
128 "version": 3
129 },
130 "file_extension": ".py",
131 "mimetype": "text/x-python",
132 "name": "python",
133 "nbconvert_exporter": "python",
134 "pygments_lexer": "ipython3",
135 "version": "3.5.2"
136 },
137 "pycharm": {
138 "stem_cell": {
139 "cell_type": "raw",
140 "source": [],
141 "metadata": {
142 "collapsed": false
143 }
144 }
145 }
146 },
147 "nbformat": 4,
148 "nbformat_minor": 2
149 }
(New empty file)
0
1 {
2 "cells": [
3 {
4 "cell_type": "code",
5 "execution_count": null,
6 "metadata": {},
7 "outputs": [],
8 "source": [
9 "print('Hello Mo!')"
10 ]
11 }
12 ],
13 "metadata": {
14 "kernelspec": {
15 "display_name": "Python 3",
16 "language": "python",
17 "name": "python3"
18 },
19 "language_info": {
20 "codemirror_mode": {
21 "name": "ipython",
22 "version": 3
23 },
24 "file_extension": ".py",
25 "mimetype": "text/x-python",
26 "name": "python",
27 "nbconvert_exporter": "python",
28 "pygments_lexer": "ipython3",
29 "version": "3.5.2"
30 }
31 },
32 "nbformat": 4,
33 "nbformat_minor": 2
34 }
35
(New empty file)
0 from os import listdir
1 from os.path import join
2
3 from PIL import Image
4 from torch.utils.data.dataset import Dataset
5 from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
6
7
8 def is_image_file(filename):
9 return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])
10
11
12 def calculate_valid_crop_size(crop_size, upscale_factor):
13 return crop_size - (crop_size % upscale_factor)
14
15
16 def train_hr_transform(crop_size):
17 return Compose([
18 RandomCrop(crop_size),
19 ToTensor(),
20 ])
21
22
23 def train_lr_transform(crop_size, upscale_factor):
24 return Compose([
25 ToPILImage(),
26 Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),
27 ToTensor()
28 ])
29
30
31 def display_transform():
32 return Compose([
33 ToPILImage(),
34 Resize(400),
35 CenterCrop(400),
36 ToTensor()
37 ])
38
39
40 class TrainDatasetFromFolder(Dataset):
41 def __init__(self, dataset_dir, crop_size, upscale_factor):
42 super(TrainDatasetFromFolder, self).__init__()
43 self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
44 crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
45 self.hr_transform = train_hr_transform(crop_size)
46 self.lr_transform = train_lr_transform(crop_size, upscale_factor)
47
48 def __getitem__(self, index):
49 hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
50 lr_image = self.lr_transform(hr_image)
51 return lr_image, hr_image
52
53 def __len__(self):
54 return len(self.image_filenames)
55
56
57 class ValDatasetFromFolder(Dataset):
58 def __init__(self, dataset_dir, upscale_factor):
59 super(ValDatasetFromFolder, self).__init__()
60 self.upscale_factor = upscale_factor
61 self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
62
63 def __getitem__(self, index):
64 hr_image = Image.open(self.image_filenames[index])
65 w, h = hr_image.size
66 crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
67 lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
68 hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
69 hr_image = CenterCrop(crop_size)(hr_image)
70 lr_image = lr_scale(hr_image)
71 hr_restore_img = hr_scale(lr_image)
72 return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)
73
74 def __len__(self):
75 return len(self.image_filenames)
76
77
78 class TestDatasetFromFolder(Dataset):
79 def __init__(self, dataset_dir, upscale_factor):
80 super(TestDatasetFromFolder, self).__init__()
81 self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
82 self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
83 self.upscale_factor = upscale_factor
84 self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
85 self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]
86
87 def __getitem__(self, index):
88 image_name = self.lr_filenames[index].split('/')[-1]
89 lr_image = Image.open(self.lr_filenames[index])
90 w, h = lr_image.size
91 hr_image = Image.open(self.hr_filenames[index])
92 hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
93 hr_restore_img = hr_scale(lr_image)
94 return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)
95
96 def __len__(self):
97 return len(self.lr_filenames)
(New empty file)
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
Binary diff not shown
0 import torch
1 from torch import nn
2 from torchvision.models.vgg import vgg16
3
4
5 class GeneratorLoss(nn.Module):
6 def __init__(self):
7 super(GeneratorLoss, self).__init__()
8 vgg = vgg16(pretrained=True)
9 loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
10 for param in loss_network.parameters():
11 param.requires_grad = False
12 self.loss_network = loss_network
13 self.mse_loss = nn.MSELoss()
14 self.tv_loss = TVLoss()
15
16 def forward(self, out_labels, out_images, target_images):
17 # Adversarial Loss
18 adversarial_loss = torch.mean(1 - out_labels)
19 # Perception Loss
20 perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
21 # Image Loss
22 image_loss = self.mse_loss(out_images, target_images)
23 # TV Loss
24 tv_loss = self.tv_loss(out_images)
25 return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
26
27
28 class TVLoss(nn.Module):
29 def __init__(self, tv_loss_weight=1):
30 super(TVLoss, self).__init__()
31 self.tv_loss_weight = tv_loss_weight
32
33 def forward(self, x):
34 batch_size = x.size()[0]
35 h_x = x.size()[2]
36 w_x = x.size()[3]
37 count_h = self.tensor_size(x[:, :, 1:, :])
38 count_w = self.tensor_size(x[:, :, :, 1:])
39 h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
40 w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
41 return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
42
43 @staticmethod
44 def tensor_size(t):
45 return t.size()[1] * t.size()[2] * t.size()[3]
46
47
48 if __name__ == "__main__":
49 g_loss = GeneratorLoss()
50 print(g_loss)
0 import math
1 import torch
2 from torch import nn
3
4
5 class Generator(nn.Module):
6 def __init__(self, scale_factor):
7 upsample_block_num = int(math.log(scale_factor, 2))
8
9 super(Generator, self).__init__()
10 self.block1 = nn.Sequential(
11 nn.Conv2d(3, 64, kernel_size=9, padding=4),
12 nn.PReLU()
13 )
14 self.block2 = ResidualBlock(64)
15 self.block3 = ResidualBlock(64)
16 self.block4 = ResidualBlock(64)
17 self.block5 = ResidualBlock(64)
18 self.block6 = ResidualBlock(64)
19 self.block7 = nn.Sequential(
20 nn.Conv2d(64, 64, kernel_size=3, padding=1),
21 nn.BatchNorm2d(64)
22 )
23 block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
24 block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
25 self.block8 = nn.Sequential(*block8)
26
27 def forward(self, x):
28 block1 = self.block1(x)
29 block2 = self.block2(block1)
30 block3 = self.block3(block2)
31 block4 = self.block4(block3)
32 block5 = self.block5(block4)
33 block6 = self.block6(block5)
34 block7 = self.block7(block6)
35 block8 = self.block8(block1 + block7)
36
37 return (torch.tanh(block8) + 1) / 2
38
39
40 class Discriminator(nn.Module):
41 def __init__(self):
42 super(Discriminator, self).__init__()
43 self.net = nn.Sequential(
44 nn.Conv2d(3, 64, kernel_size=3, padding=1),
45 nn.LeakyReLU(0.2),
46
47 nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
48 nn.BatchNorm2d(64),
49 nn.LeakyReLU(0.2),
50
51 nn.Conv2d(64, 128, kernel_size=3, padding=1),
52 nn.BatchNorm2d(128),
53 nn.LeakyReLU(0.2),
54
55 nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
56 nn.BatchNorm2d(128),
57 nn.LeakyReLU(0.2),
58
59 nn.Conv2d(128, 256, kernel_size=3, padding=1),
60 nn.BatchNorm2d(256),
61 nn.LeakyReLU(0.2),
62
63 nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
64 nn.BatchNorm2d(256),
65 nn.LeakyReLU(0.2),
66
67 nn.Conv2d(256, 512, kernel_size=3, padding=1),
68 nn.BatchNorm2d(512),
69 nn.LeakyReLU(0.2),
70
71 nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
72 nn.BatchNorm2d(512),
73 nn.LeakyReLU(0.2),
74
75 nn.AdaptiveAvgPool2d(1),
76 nn.Conv2d(512, 1024, kernel_size=1),
77 nn.LeakyReLU(0.2),
78 nn.Conv2d(1024, 1, kernel_size=1)
79 )
80
81 def forward(self, x):
82 batch_size = x.size(0)
83 return torch.sigmoid(self.net(x).view(batch_size))
84
85
86 class ResidualBlock(nn.Module):
87 def __init__(self, channels):
88 super(ResidualBlock, self).__init__()
89 self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
90 self.bn1 = nn.BatchNorm2d(channels)
91 self.prelu = nn.PReLU()
92 self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
93 self.bn2 = nn.BatchNorm2d(channels)
94
95 def forward(self, x):
96 residual = self.conv1(x)
97 residual = self.bn1(residual)
98 residual = self.prelu(residual)
99 residual = self.conv2(residual)
100 residual = self.bn2(residual)
101
102 return x + residual
103
104
105 class UpsampleBLock(nn.Module):
106 def __init__(self, in_channels, up_scale):
107 super(UpsampleBLock, self).__init__()
108 self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
109 self.pixel_shuffle = nn.PixelShuffle(up_scale)
110 self.prelu = nn.PReLU()
111
112 def forward(self, x):
113 x = self.conv(x)
114 x = self.pixel_shuffle(x)
115 x = self.prelu(x)
116 return x
0 from math import exp
1
2 import torch
3 import torch.nn.functional as F
4 from torch.autograd import Variable
5
6
7 def gaussian(window_size, sigma):
8 gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
9 return gauss / gauss.sum()
10
11
12 def create_window(window_size, channel):
13 _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
16 return window
17
18
19 def _ssim(img1, img2, window, window_size, channel, size_average=True):
20 mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
21 mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
22
23 mu1_sq = mu1.pow(2)
24 mu2_sq = mu2.pow(2)
25 mu1_mu2 = mu1 * mu2
26
27 sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
28 sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
29 sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
30
31 C1 = 0.01 ** 2
32 C2 = 0.03 ** 2
33
34 ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
35
36 if size_average:
37 return ssim_map.mean()
38 else:
39 return ssim_map.mean(1).mean(1).mean(1)
40
41
42 class SSIM(torch.nn.Module):
43 def __init__(self, window_size=11, size_average=True):
44 super(SSIM, self).__init__()
45 self.window_size = window_size
46 self.size_average = size_average
47 self.channel = 1
48 self.window = create_window(window_size, self.channel)
49
50 def forward(self, img1, img2):
51 (_, channel, _, _) = img1.size()
52
53 if channel == self.channel and self.window.data.type() == img1.data.type():
54 window = self.window
55 else:
56 window = create_window(self.window_size, channel)
57
58 if img1.is_cuda:
59 window = window.cuda(img1.get_device())
60 window = window.type_as(img1)
61
62 self.window = window
63 self.channel = channel
64
65 return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
66
67
68 def ssim(img1, img2, window_size=11, size_average=True):
69 (_, channel, _, _) = img1.size()
70 window = create_window(window_size, channel)
71
72 if img1.is_cuda:
73 window = window.cuda(img1.get_device())
74 window = window.type_as(img1)
75
76 return _ssim(img1, img2, window, window_size, channel, size_average)
0 Please store your training checkpoints or results here
0 Please store your tensorboard results here
(New empty file)
0 import argparse
1 import os
2 from math import log10
3
4 import numpy as np
5 import pandas as pd
6 import torch
7 import torchvision.utils as utils
8 from torch.autograd import Variable
9 from torch.utils.data import DataLoader
10 from tqdm import tqdm
11
12 import pytorch_ssim
13 from data_utils import TestDatasetFromFolder, display_transform
14 from model import Generator
15
16 parser = argparse.ArgumentParser(description='Test Benchmark Datasets')
17 parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
18 parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
19 opt = parser.parse_args()
20
21 UPSCALE_FACTOR = opt.upscale_factor
22 MODEL_NAME = opt.model_name
23
24 results = {'Set5': {'psnr': [], 'ssim': []}, 'Set14': {'psnr': [], 'ssim': []}, 'BSD100': {'psnr': [], 'ssim': []},
25 'Urban100': {'psnr': [], 'ssim': []}, 'SunHays80': {'psnr': [], 'ssim': []}}
26
27 model = Generator(UPSCALE_FACTOR).eval()
28 if torch.cuda.is_available():
29 model = model.cuda()
30 model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
31
32 test_set = TestDatasetFromFolder('data/test', upscale_factor=UPSCALE_FACTOR)
33 test_loader = DataLoader(dataset=test_set, num_workers=4, batch_size=1, shuffle=False)
34 test_bar = tqdm(test_loader, desc='[testing benchmark datasets]')
35
36 out_path = 'benchmark_results/SRF_' + str(UPSCALE_FACTOR) + '/'
37 if not os.path.exists(out_path):
38 os.makedirs(out_path)
39
40 for image_name, lr_image, hr_restore_img, hr_image in test_bar:
41 image_name = image_name[0]
42 lr_image = Variable(lr_image, volatile=True)
43 hr_image = Variable(hr_image, volatile=True)
44 if torch.cuda.is_available():
45 lr_image = lr_image.cuda()
46 hr_image = hr_image.cuda()
47
48 sr_image = model(lr_image)
49 mse = ((hr_image - sr_image) ** 2).data.mean()
50 psnr = 10 * log10(1 / mse)
51 ssim = pytorch_ssim.ssim(sr_image, hr_image).data[0]
52
53 test_images = torch.stack(
54 [display_transform()(hr_restore_img.squeeze(0)), display_transform()(hr_image.data.cpu().squeeze(0)),
55 display_transform()(sr_image.data.cpu().squeeze(0))])
56 image = utils.make_grid(test_images, nrow=3, padding=5)
57 utils.save_image(image, out_path + image_name.split('.')[0] + '_psnr_%.4f_ssim_%.4f.' % (psnr, ssim) +
58 image_name.split('.')[-1], padding=5)
59
60 # save psnr / ssim
61 results[image_name.split('_')[0]]['psnr'].append(psnr)
62 results[image_name.split('_')[0]]['ssim'].append(ssim)
63
64 out_path = 'statistics/'
65 saved_results = {'psnr': [], 'ssim': []}
66 for item in results.values():
67 psnr = np.array(item['psnr'])
68 ssim = np.array(item['ssim'])
69 if (len(psnr) == 0) or (len(ssim) == 0):
70 psnr = 'No data'
71 ssim = 'No data'
72 else:
73 psnr = psnr.mean()
74 ssim = ssim.mean()
75 saved_results['psnr'].append(psnr)
76 saved_results['ssim'].append(ssim)
77
78 data_frame = pd.DataFrame(saved_results, results.keys())
79 data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_test_results.csv', index_label='DataSet')
0 import argparse
1 import time
2
3 import torch
4 from PIL import Image
5 from torch.autograd import Variable
6 from torchvision.transforms import ToTensor, ToPILImage
7
8 from model import Generator
9
10 parser = argparse.ArgumentParser(description='Test Single Image')
11 parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
12 parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
13 parser.add_argument('--image_name', type=str, help='test low resolution image name')
14 parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
15 opt = parser.parse_args()
16
17 UPSCALE_FACTOR = opt.upscale_factor
18 TEST_MODE = True if opt.test_mode == 'GPU' else False
19 IMAGE_NAME = opt.image_name
20 MODEL_NAME = opt.model_name
21
22 model = Generator(UPSCALE_FACTOR).eval()
23 if TEST_MODE:
24 model.cuda()
25 model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
26 else:
27 model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))
28
29 image = Image.open(IMAGE_NAME)
30 image = Variable(ToTensor()(image), volatile=True).unsqueeze(0)
31 if TEST_MODE:
32 image = image.cuda()
33
34 start = time.perf_counter()  # time.clock() was removed in Python 3.8
35 out = model(image)
36 elapsed = (time.perf_counter() - start)
37 print('cost ' + str(elapsed) + 's')
38 out_img = ToPILImage()(out[0].data.cpu())
39 out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME)
0 import argparse
1
2 import cv2
3 import numpy as np
4 import torch
5 import torchvision.transforms as transforms
6 from PIL import Image
7 from torch.autograd import Variable
8 from torchvision.transforms import ToTensor, ToPILImage
9 from tqdm import tqdm
10
11 from model import Generator
12
13 if __name__ == "__main__":
14 parser = argparse.ArgumentParser(description='Test Single Video')
15 parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
16 parser.add_argument('--video_name', type=str, help='test low resolution video name')
17 parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
18 opt = parser.parse_args()
19
20 UPSCALE_FACTOR = opt.upscale_factor
21 VIDEO_NAME = opt.video_name
22 MODEL_NAME = opt.model_name
23
24 model = Generator(UPSCALE_FACTOR).eval()
25 if torch.cuda.is_available():
26 model = model.cuda()
27 # for cpu
28 # model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))
29 model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
30
31 videoCapture = cv2.VideoCapture(VIDEO_NAME)
32 fps = videoCapture.get(cv2.CAP_PROP_FPS)
33 frame_numbers = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
34 sr_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR),
35 int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR)
36 compared_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10),
37 int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR + 10 + int(
38 int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10) / int(
39 10 * int(int(
40 videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 + 1)) * int(
41 int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 - 9)))
42 output_sr_name = 'out_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi'
43 output_compared_name = 'compare_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi'
44 sr_video_writer = cv2.VideoWriter(output_sr_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps, sr_video_size)
45 compared_video_writer = cv2.VideoWriter(output_compared_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps,
46 compared_video_size)
47 # read frame
48 success, frame = videoCapture.read()
49 test_bar = tqdm(range(int(frame_numbers)), desc='[processing video and saving result videos]')
50 for index in test_bar:
51 if success:
52 image = Variable(ToTensor()(frame), volatile=True).unsqueeze(0)
53 if torch.cuda.is_available():
54 image = image.cuda()
55
56 out = model(image)
57 out = out.cpu()
58 out_img = out.data[0].numpy()
59 out_img *= 255.0
60 out_img = (np.uint8(out_img)).transpose((1, 2, 0))
61 # save sr video
62 sr_video_writer.write(out_img)
63
64 # make compared video and crop shot of left top\right top\center\left bottom\right bottom
65 out_img = ToPILImage()(out_img)
66 crop_out_imgs = transforms.FiveCrop(size=out_img.width // 5 - 9)(out_img)
67 crop_out_imgs = [np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(img)) for img in crop_out_imgs]
68 out_img = transforms.Pad(padding=(5, 0, 0, 5))(out_img)
69 compared_img = transforms.Resize(size=(sr_video_size[1], sr_video_size[0]), interpolation=Image.BICUBIC)(
70 ToPILImage()(frame))
71 crop_compared_imgs = transforms.FiveCrop(size=compared_img.width // 5 - 9)(compared_img)
72 crop_compared_imgs = [np.asarray(transforms.Pad(padding=(0, 5, 10, 0))(img)) for img in crop_compared_imgs]
73 compared_img = transforms.Pad(padding=(0, 0, 5, 5))(compared_img)
74 # concatenate all the pictures to one single picture
75 top_image = np.concatenate((np.asarray(compared_img), np.asarray(out_img)), axis=1)
76 bottom_image = np.concatenate(crop_compared_imgs + crop_out_imgs, axis=1)
77 bottom_image = np.asarray(transforms.Resize(
78 size=(int(top_image.shape[1] / bottom_image.shape[1] * bottom_image.shape[0]), top_image.shape[1]))(
79 ToPILImage()(bottom_image)))
80 final_image = np.concatenate((top_image, bottom_image))
81 # save compared video
82 compared_video_writer.write(final_image)
83 # next frame
84 success, frame = videoCapture.read()
85
86 # release the reader and writers so the output video files are finalized
87 videoCapture.release()
88 sr_video_writer.release()
89 compared_video_writer.release()
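The frame post-processing above (clamp, rescale to [0, 255], uint8 cast, CHW→HWC transpose for `cv2.VideoWriter`) can be sketched in plain NumPy. `tensor_to_frame` is a hypothetical helper for illustration, not part of this repository:

```python
import numpy as np

def tensor_to_frame(chw):
    """Convert a CHW float array in [0, 1] to an HWC uint8 frame.

    Mirrors the post-processing in the loop above: clamp first so
    out-of-range model outputs cannot wrap around in the uint8 cast,
    then scale to [0, 255] and move channels last for cv2.VideoWriter.
    """
    chw = np.clip(chw, 0.0, 1.0)           # avoid uint8 wrap-around
    frame = (chw * 255.0).astype(np.uint8)  # [0, 1] -> [0, 255]
    return frame.transpose(1, 2, 0)         # CHW -> HWC

# dummy 3x4x6 "model output", deliberately including out-of-range values
out = np.random.uniform(-0.2, 1.2, size=(3, 4, 6)).astype(np.float32)
frame = tensor_to_frame(out)
print(frame.shape, frame.dtype)  # (4, 6, 3) uint8
```

Without the clamp, a value of e.g. 1.004 would scale to 256.02 and wrap to 0 under the uint8 cast, producing black speckles in bright regions.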
0 import argparse
1 import os
2 from math import log10
3
4 import pandas as pd
5 import torch.optim as optim
6 import torch.utils.data
7 import torchvision.utils as utils
8 from torch.autograd import Variable
9 from torch.utils.data import DataLoader
10 from tqdm import tqdm
11
12 import pytorch_ssim
13 from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
14 from loss import GeneratorLoss
15 from model import Generator, Discriminator
16
17 parser = argparse.ArgumentParser(description='Train Super Resolution Models')
18 parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
19 parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],
20 help='super resolution upscale factor')
21 parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')
22
23
24 if __name__ == '__main__':
25 opt = parser.parse_args()
26
27 CROP_SIZE = opt.crop_size
28 UPSCALE_FACTOR = opt.upscale_factor
29 NUM_EPOCHS = opt.num_epochs
30
31 train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
32 val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
33 train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
34 val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)
35
36 netG = Generator(UPSCALE_FACTOR)
37 print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
38 netD = Discriminator()
39 print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
40
41 generator_criterion = GeneratorLoss()
42
43 if torch.cuda.is_available():
44 netG.cuda()
45 netD.cuda()
46 generator_criterion.cuda()
47
48 optimizerG = optim.Adam(netG.parameters())
49 optimizerD = optim.Adam(netD.parameters())
50
51 results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}
52
53 for epoch in range(1, NUM_EPOCHS + 1):
54 train_bar = tqdm(train_loader)
55 running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}
56
57 netG.train()
58 netD.train()
59 for data, target in train_bar:
61 batch_size = data.size(0)
62 running_results['batch_sizes'] += batch_size
63
64 ############################
65 # (1) Update D network: maximize D(x)-1-D(G(z))
66 ###########################
67 real_img = target  # torch.autograd.Variable is deprecated; tensors work directly
68 if torch.cuda.is_available():
69 real_img = real_img.cuda()
70 z = data
71 if torch.cuda.is_available():
72 z = z.cuda()
73 fake_img = netG(z)
74
75 netD.zero_grad()
76 real_out = netD(real_img).mean()
77 fake_out = netD(fake_img).mean()
78 d_loss = 1 - real_out + fake_out
79 d_loss.backward(retain_graph=True)
80 optimizerD.step()
81
82 ############################
83 # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
84 ###########################
85 netG.zero_grad()
86 g_loss = generator_criterion(fake_out, fake_img, real_img)
87 g_loss.backward()
88
89 # re-evaluate D(G(z)) on a fresh forward pass for the progress logging below
90 fake_img = netG(z)
91 fake_out = netD(fake_img).mean()
92
93 optimizerG.step()
94
95 # loss for current batch before optimization
96 running_results['g_loss'] += g_loss.item() * batch_size
97 running_results['d_loss'] += d_loss.item() * batch_size
98 running_results['d_score'] += real_out.item() * batch_size
99 running_results['g_score'] += fake_out.item() * batch_size
100
101 train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
102 epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
103 running_results['g_loss'] / running_results['batch_sizes'],
104 running_results['d_score'] / running_results['batch_sizes'],
105 running_results['g_score'] / running_results['batch_sizes']))
106
107 netG.eval()
108 out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
109 if not os.path.exists(out_path):
110 os.makedirs(out_path)
111
112 with torch.no_grad():
113 val_bar = tqdm(val_loader)
114 valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
115 val_images = []
116 for val_lr, val_hr_restore, val_hr in val_bar:
117 batch_size = val_lr.size(0)
118 valing_results['batch_sizes'] += batch_size
119 lr = val_lr
120 hr = val_hr
121 if torch.cuda.is_available():
122 lr = lr.cuda()
123 hr = hr.cuda()
124 sr = netG(lr)
125
126 batch_mse = ((sr - hr) ** 2).mean().item()
127 valing_results['mse'] += batch_mse * batch_size
128 batch_ssim = pytorch_ssim.ssim(sr, hr).item()
129 valing_results['ssims'] += batch_ssim * batch_size
130 valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes']))
131 valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
132 val_bar.set_description(
133 desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
134 valing_results['psnr'], valing_results['ssim']))
135
136 val_images.extend(
137 [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
138 display_transform()(sr.data.cpu().squeeze(0))])
139 val_images = torch.stack(val_images)
140 val_images = torch.chunk(val_images, val_images.size(0) // 15)
141 val_save_bar = tqdm(val_images, desc='[saving training results]')
142 index = 1
143 for image in val_save_bar:
144 image = utils.make_grid(image, nrow=3, padding=5)
145 utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
146 index += 1
147
148 # save model parameters
149 torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
150 torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
151 # save loss\scores\psnr\ssim
152 results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
153 results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
154 results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
155 results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
156 results['psnr'].append(valing_results['psnr'])
157 results['ssim'].append(valing_results['ssim'])
158
159 if epoch % 10 == 0:  # epochs start at 1, so no extra zero check is needed
160 out_path = 'statistics/'
161 data_frame = pd.DataFrame(
162 data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
163 'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
164 index=range(1, epoch + 1))
165 data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
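The PSNR reported during validation follows the standard definition, 10·log10(1/MSE), applied to the running per-sample MSE sum. A minimal standalone sketch, assuming pixel values normalized to [0, 1] (so the peak signal value is 1); `batch_psnr` is a hypothetical helper, not part of this repo:

```python
from math import log10

def batch_psnr(mse_sum, n_samples):
    """PSNR in dB from an accumulated per-sample MSE sum.

    Matches the validation loop: average the running MSE over the
    samples seen so far, then take 10 * log10(1 / mse). Pixel values
    are assumed to be scaled to [0, 1], so the signal peak is 1.
    """
    mse = mse_sum / n_samples
    return 10 * log10(1 / mse)

# an average MSE of 0.001 over 8 samples corresponds to 30 dB
print(round(batch_psnr(0.001 * 8, 8), 4))  # 30.0
```

Because the loop recomputes this from the running sums on every batch, the progress-bar value is the PSNR over all validation samples processed so far, not just the current batch.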
(New empty file)