diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2679d4f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+.idea/
+*.pyc
+*.swp
+.DS_Store
+/.localenv/
+/datasets/
+/.ipynb_checkpoints/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f90fde3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,177 @@
+# SRGAN
+A PyTorch implementation of SRGAN based on CVPR 2017 paper
+[Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802).
+
+## Requirements
+- [Anaconda](https://www.anaconda.com/download/)
+- PyTorch
+```
+conda install pytorch torchvision -c pytorch
+```
+- opencv
+```
+conda install opencv
+```
+
+## Datasets
+
+### Train、Val Dataset
+The train and val datasets are sampled from [VOC2012](http://cvlab.postech.ac.kr/~mooyeol/pascal_voc_2012/).
+Train dataset has 16700 images and Val dataset has 425 images.
+Download the datasets from [here](https://pan.baidu.com/s/1xuFperu2WiYc5-_QXBemlA)(access code:5tzp), and then extract it into `data` directory.
+
+### Test Image Dataset
+The test image dataset are sampled from
+| **Set 5** | [Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html)
+| **Set 14** | [Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests)
+| **BSD 100** | [Martin et al. ICCV 2001](https://www.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/)
+| **Sun-Hays 80** | [Sun and Hays ICCP 2012](http://cs.brown.edu/~lbsun/SRproj2012/SR_iccp2012.html)
+| **Urban 100** | [Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr).
+Download the image dataset from [here](https://pan.baidu.com/s/1vGosnyal21wGgVffriL1VQ)(access code:xwhy), and then extract it into `data` directory.
+
+### Test Video Dataset
+The test video dataset are three trailers. Download the video dataset from
+[here](https://pan.baidu.com/s/1NUZKm5xCHRj1O0JlCZIu8Q)(access code:zabi).
+
+## Usage
+
+### Train
+```
+python train.py
+
+optional arguments:
+--crop_size training images crop size [default value is 88]
+--upscale_factor super resolution upscale factor [default value is 4](choices:[2, 4, 8])
+--num_epochs train epoch number [default value is 100]
+```
+The output val super resolution images are on `training_results` directory.
+
+### Test Benchmark Datasets
+```
+python test_benchmark.py
+
+optional arguments:
+--upscale_factor super resolution upscale factor [default value is 4]
+--model_name generator model epoch name [default value is netG_epoch_4_100.pth]
+```
+The output super resolution images are on `benchmark_results` directory.
+
+### Test Single Image
+```
+python test_image.py
+
+optional arguments:
+--upscale_factor super resolution upscale factor [default value is 4]
+--test_mode using GPU or CPU [default value is 'GPU'](choices:['GPU', 'CPU'])
+--image_name test low resolution image name
+--model_name generator model epoch name [default value is netG_epoch_4_100.pth]
+```
+The output super resolution image are on the same directory.
+
+### Test Single Video
+```
+python test_video.py
+
+optional arguments:
+--upscale_factor super resolution upscale factor [default value is 4]
+--video_name test low resolution video name
+--model_name generator model epoch name [default value is netG_epoch_4_100.pth]
+```
+The output super resolution video and compared video are on the same directory.
+
+## Benchmarks
+**Upscale Factor = 2**
+
+Epochs with batch size of 64 takes ~2 minute 30 seconds on a NVIDIA GTX 1080Ti GPU.
+
+> Image Results
+
+The left is bicubic interpolation image, the middle is high resolution image, and
+the right is super resolution image(output of the SRGAN).
+
+- BSD100_070(PSNR:32.4517; SSIM:0.9191)
+
+
+
+- Set14_005(PSNR:26.9171; SSIM:0.9119)
+
+
+
+- Set14_013(PSNR:30.8040; SSIM:0.9651)
+
+
+
+- Urban100_098(PSNR:24.3765; SSIM:0.7855)
+
+
+
+> Video Results
+
+The left is bicubic interpolation video, the right is super resolution video(output of the SRGAN).
+
+[](https://youtu.be/05vx-vOJOZs)
+
+**Upscale Factor = 4**
+
+Epochs with batch size of 64 takes ~4 minute 30 seconds on a NVIDIA GTX 1080Ti GPU.
+
+> Image Results
+
+The left is bicubic interpolation image, the middle is high resolution image, and
+the right is super resolution image(output of the SRGAN).
+
+- BSD100_035(PSNR:32.3980; SSIM:0.8512)
+
+
+
+- Set14_011(PSNR:29.5944; SSIM:0.9044)
+
+
+
+- Set14_014(PSNR:25.1299; SSIM:0.7406)
+
+
+
+- Urban100_060(PSNR:20.7129; SSIM:0.5263)
+
+
+
+> Video Results
+
+The left is bicubic interpolation video, the right is super resolution video(output of the SRGAN).
+
+[](https://youtu.be/tNR2eiMeoQs)
+
+**Upscale Factor = 8**
+
+Epochs with batch size of 64 takes ~3 minute 30 seconds on a NVIDIA GTX 1080Ti GPU.
+
+> Image Results
+
+The left is bicubic interpolation image, the middle is high resolution image, and
+the right is super resolution image(output of the SRGAN).
+
+- SunHays80_027(PSNR:29.4941; SSIM:0.8082)
+
+
+
+- SunHays80_035(PSNR:32.1546; SSIM:0.8449)
+
+
+
+- SunHays80_043(PSNR:30.9716; SSIM:0.8789)
+
+
+
+- SunHays80_078(PSNR:31.9351; SSIM:0.8381)
+
+
+
+> Video Results
+
+The left is bicubic interpolation video, the right is super resolution video(output of the SRGAN).
+
+[](https://youtu.be/EuvXTKCRr8I)
+
+The complete test results could be downloaded from [here](https://pan.baidu.com/s/1pAaEgAQ4aRZbtKo8hNrb4Q)(access code:wnvn).
+
diff --git a/_OVERVIEW.md b/_OVERVIEW.md
new file mode 100644
index 0000000..d2a86d2
--- /dev/null
+++ b/_OVERVIEW.md
@@ -0,0 +1,5 @@
+## 介绍 (Introduction)
+
+添加该项目的功能、使用场景和输入输出参数等相关信息。
+
+You can describe the function, usage and parameters of the project.
\ No newline at end of file
diff --git a/_README.ipynb b/_README.ipynb
new file mode 100644
index 0000000..2abc9d6
--- /dev/null
+++ b/_README.ipynb
@@ -0,0 +1,150 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. 项目介绍\n",
+ "\n",
+ " - 项目是由模块组成、有特定功能的程序。它能够满足用户的直接使用需求,例如[古诗词生成器](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1)、[风格迁移](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1)等。\n",
+ " - 开发项目过程中你可以导入数据集,也可以通过每个 cell 上方工具栏的`<+>`直接插入[模块](https://momodel.cn/modules)和代码块。\n",
+ " - 你可以将开发好的项目进行[部署](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89),项目部署成功并选择正式版本发布后会展示在“项目”页面,用户可以在线使用,也可以通过 API 调用。\n",
+ "\n",
+ " - 项目目录结构:\n",
+ "\n",
+ " - ```results```*-----结果的文件存放地(如果你运行 job,务必将运行结果指定在此目录)*\n",
+ " - ```_OVERVIEW.md``` *-----项目的相关介绍*\n",
+ " - ```_README.md```*-----说明文档*\n",
+ " - ```app_spec.yml```*-----定义项目的输入输出,为部署服务*\n",
+ " - ```coding_here.ipynb```*-----输入并运行代码*"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "## 2. 开发环境简介\n",
+ "\n",
+ "你当前所在的页面 Notebook 是一个内嵌 JupyterLab 的在线类 IDE 编程环境,开发过程中可以使用页面右侧的 API 文档进行快速查询。Notebook 有以下主要功能:\n",
+ "\n",
+ "- [调用数据集、模块和代码块资源](https://momodel.cn/docs/#/zh-cn/%E5%A6%82%E4%BD%95%E5%AF%BC%E5%85%A5%E5%B9%B6%E4%BD%BF%E7%94%A8%E6%A8%A1%E5%9D%97%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86)\n",
+ "- [多人代码协作](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8Mo%E8%BF%90%E8%A1%8C%E4%BD%A0%E7%9A%84%E7%AC%AC%E4%B8%80%E6%AE%B5%E4%BB%A3%E7%A0%81?id=_7-%e4%bd%a0%e5%8f%af%e4%bb%a5%e9%82%80%e8%af%b7%e5%a5%bd%e5%8f%8b%e8%bf%9b%e8%a1%8c%e5%8d%8f%e4%bd%9c)\n",
+ "- [在 GPU 资源上训练机器学习模型](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8GPU%E6%88%96CPU%E8%B5%84%E6%BA%90%E4%B8%8A%E8%AE%AD%E7%BB%83%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E6%A8%A1%E5%9E%8B)\n",
+ "- [简单部署](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89)\n",
+ "\n",
+ "快来动手试试吧!点击左侧工具栏的新建文件图标即可选择你需要的文件类型。\n",
+ "\n",
+ "
\n",
+ "\n",
+ "\n",
+ "\n",
+ "左侧和右侧工具栏都可根据使用需要进行收合。\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. 快捷键与代码补全\n",
+ "Mo Notebook 已完全采用 Jupyter Notebook 的原生快捷键,并且支持 `tab` 代码补全。\n",
+ "\n",
+ "运行代码:`shift` + `enter` 或者 `shift` + `return`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. 常用指令介绍\n",
+ "\n",
+ "- 解压上传后的文件\n",
+ "\n",
+ "在 cell 中输入并运行以下命令:\n",
+ "```!unzip -o file_name.zip```\n",
+ "\n",
+ "- 查看所有包(package)\n",
+ "\n",
+ "`!pip list --format=columns`\n",
+ "\n",
+ "- 检查是否已有某个包\n",
+ "\n",
+ "`!pip show package_name`\n",
+ "\n",
+ "- 安装缺失的包\n",
+ "\n",
+ "`!pip install package_name`\n",
+ "\n",
+ "- 更新已有的包\n",
+ "\n",
+ "`!pip install package_name --upgrade`\n",
+ "\n",
+ "\n",
+ "- 使用包\n",
+ "\n",
+ "`import package_name`\n",
+ "\n",
+ "- 显示当前目录下的档案及目录\n",
+ "\n",
+ "`ls`\n",
+ "\n",
+ "- 使用引入的数据集\n",
+ "\n",
+ "数据集被引入后存放在 datasets 文件夹下,注意,这个文件夹是只读的,不可修改。如果需要修改,可在 Notebook 中使用\n",
+ "\n",
+ "`!cp -R ./datasets/ ./`\n",
+ "\n",
+ "指令将其复制到其他文件夹后再编辑,对于引入的数据集中的 zip 文件,可使用\n",
+ "\n",
+ "`!unzip ./datasets// -d ./`\n",
+ "\n",
+ "指令解压缩到其他文件夹后使用"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. 其他可参考资源\n",
+ "- [帮助文档](https://momodel.cn/docs/#/):基本页面介绍和常见问题都可以在里面找到\n",
+ "- [平台功能教程](https://momodel.cn/classroom/class?id=5c5696cd1afd9458d456bf54&type=doc):通过图文结合的 Notebook 详细介绍开发环境基本功能和操作\n",
+ "- [吴恩达机器学习](https://momodel.cn/classroom/class?id=5c5696191afd94720cc94533&type=video):机器学习经典课程\n",
+ "- [李宏毅机器学习](https://s.momodel.cn/classroom/class?id=5d40fdafb5113408a8dbb4a1&type=video):中文世界最好的机器学习课程\n",
+ "- [机器学习实战](https://momodel.cn/classroom/class?id=5c680b311afd943a9f70901b&type=practice):通过实操指引完成独立的模型,掌握相应的机器学习知识\n",
+ "- [Python 教程](https://momodel.cn/classroom/class?id=5d1f3ab81afd940ab7d298bf&type=notebook):简单易懂的 Python 新手教程\n",
+ "- [模块开发](https://momodel.cn/modules):关于模型训练、开发与部署的高阶教程"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.5.2"
+ },
+ "pycharm": {
+ "stem_cell": {
+ "cell_type": "raw",
+ "source": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
\ No newline at end of file
diff --git a/benchmark_results/.gitkeep b/benchmark_results/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/coding_here.ipynb b/coding_here.ipynb
new file mode 100644
index 0000000..90d9432
--- /dev/null
+++ b/coding_here.ipynb
@@ -0,0 +1,36 @@
+
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print('Hello Mo!')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.5.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
+
\ No newline at end of file
diff --git a/data/.gitkeep b/data/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/data_utils.py b/data_utils.py
new file mode 100755
index 0000000..ff41d51
--- /dev/null
+++ b/data_utils.py
@@ -0,0 +1,98 @@
+from os import listdir
+from os.path import join
+
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])
+
+
+def calculate_valid_crop_size(crop_size, upscale_factor):
+ return crop_size - (crop_size % upscale_factor)
+
+
+def train_hr_transform(crop_size):
+ return Compose([
+ RandomCrop(crop_size),
+ ToTensor(),
+ ])
+
+
+def train_lr_transform(crop_size, upscale_factor):
+ return Compose([
+ ToPILImage(),
+ Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),
+ ToTensor()
+ ])
+
+
+def display_transform():
+ return Compose([
+ ToPILImage(),
+ Resize(400),
+ CenterCrop(400),
+ ToTensor()
+ ])
+
+
+class TrainDatasetFromFolder(Dataset):
+ def __init__(self, dataset_dir, crop_size, upscale_factor):
+ super(TrainDatasetFromFolder, self).__init__()
+ self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
+ crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
+ self.hr_transform = train_hr_transform(crop_size)
+ self.lr_transform = train_lr_transform(crop_size, upscale_factor)
+
+ def __getitem__(self, index):
+ hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
+ lr_image = self.lr_transform(hr_image)
+ return lr_image, hr_image
+
+ def __len__(self):
+ return len(self.image_filenames)
+
+
+class ValDatasetFromFolder(Dataset):
+ def __init__(self, dataset_dir, upscale_factor):
+ super(ValDatasetFromFolder, self).__init__()
+ self.upscale_factor = upscale_factor
+ self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
+
+ def __getitem__(self, index):
+ hr_image = Image.open(self.image_filenames[index])
+ w, h = hr_image.size
+ crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
+ lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
+ hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
+ hr_image = CenterCrop(crop_size)(hr_image)
+ lr_image = lr_scale(hr_image)
+ hr_restore_img = hr_scale(lr_image)
+ return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)
+
+ def __len__(self):
+ return len(self.image_filenames)
+
+
+class TestDatasetFromFolder(Dataset):
+ def __init__(self, dataset_dir, upscale_factor):
+ super(TestDatasetFromFolder, self).__init__()
+ self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
+ self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
+ self.upscale_factor = upscale_factor
+ self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
+ self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]
+
+ def __getitem__(self, index):
+ image_name = self.lr_filenames[index].split('/')[-1]
+ lr_image = Image.open(self.lr_filenames[index])
+ w, h = lr_image.size
+ hr_image = Image.open(self.hr_filenames[index])
+ hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
+ hr_restore_img = hr_scale(lr_image)
+ return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)
+
+ def __len__(self):
+ return len(self.lr_filenames)
diff --git a/epochs/.gitkeep b/epochs/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/images/1.png b/images/1.png
new file mode 100644
index 0000000..5ee502f
Binary files /dev/null and b/images/1.png differ
diff --git a/images/10.png b/images/10.png
new file mode 100644
index 0000000..337f02a
Binary files /dev/null and b/images/10.png differ
diff --git a/images/11.png b/images/11.png
new file mode 100644
index 0000000..8f12162
Binary files /dev/null and b/images/11.png differ
diff --git a/images/12.png b/images/12.png
new file mode 100644
index 0000000..c4900c2
Binary files /dev/null and b/images/12.png differ
diff --git a/images/2.png b/images/2.png
new file mode 100644
index 0000000..8b795f2
Binary files /dev/null and b/images/2.png differ
diff --git a/images/3.png b/images/3.png
new file mode 100644
index 0000000..8144308
Binary files /dev/null and b/images/3.png differ
diff --git a/images/4.png b/images/4.png
new file mode 100644
index 0000000..54ca285
Binary files /dev/null and b/images/4.png differ
diff --git a/images/5.png b/images/5.png
new file mode 100644
index 0000000..a6b9f1e
Binary files /dev/null and b/images/5.png differ
diff --git a/images/6.png b/images/6.png
new file mode 100644
index 0000000..41fa8a9
Binary files /dev/null and b/images/6.png differ
diff --git a/images/7.png b/images/7.png
new file mode 100644
index 0000000..8ba13c8
Binary files /dev/null and b/images/7.png differ
diff --git a/images/8.png b/images/8.png
new file mode 100644
index 0000000..0b1ce2d
Binary files /dev/null and b/images/8.png differ
diff --git a/images/9.png b/images/9.png
new file mode 100644
index 0000000..7cd1237
Binary files /dev/null and b/images/9.png differ
diff --git a/images/video_SRF_2.png b/images/video_SRF_2.png
new file mode 100644
index 0000000..2877591
Binary files /dev/null and b/images/video_SRF_2.png differ
diff --git a/images/video_SRF_4.png b/images/video_SRF_4.png
new file mode 100644
index 0000000..e5a5fda
Binary files /dev/null and b/images/video_SRF_4.png differ
diff --git a/images/video_SRF_8.png b/images/video_SRF_8.png
new file mode 100644
index 0000000..f7a8fe2
Binary files /dev/null and b/images/video_SRF_8.png differ
diff --git a/loss.py b/loss.py
new file mode 100644
index 0000000..ab30472
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,51 @@
+import torch
+from torch import nn
+from torchvision.models.vgg import vgg16
+
+
+class GeneratorLoss(nn.Module):
+ def __init__(self):
+ super(GeneratorLoss, self).__init__()
+ vgg = vgg16(pretrained=True)
+ loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
+ for param in loss_network.parameters():
+ param.requires_grad = False
+ self.loss_network = loss_network
+ self.mse_loss = nn.MSELoss()
+ self.tv_loss = TVLoss()
+
+ def forward(self, out_labels, out_images, target_images):
+ # Adversarial Loss
+ adversarial_loss = torch.mean(1 - out_labels)
+ # Perception Loss
+ perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
+ # Image Loss
+ image_loss = self.mse_loss(out_images, target_images)
+ # TV Loss
+ tv_loss = self.tv_loss(out_images)
+ return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss
+
+
+class TVLoss(nn.Module):
+ def __init__(self, tv_loss_weight=1):
+ super(TVLoss, self).__init__()
+ self.tv_loss_weight = tv_loss_weight
+
+ def forward(self, x):
+ batch_size = x.size()[0]
+ h_x = x.size()[2]
+ w_x = x.size()[3]
+ count_h = self.tensor_size(x[:, :, 1:, :])
+ count_w = self.tensor_size(x[:, :, :, 1:])
+ h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
+ w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
+ return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
+
+ @staticmethod
+ def tensor_size(t):
+ return t.size()[1] * t.size()[2] * t.size()[3]
+
+
+if __name__ == "__main__":
+ g_loss = GeneratorLoss()
+ print(g_loss)
diff --git a/model.py b/model.py
new file mode 100755
index 0000000..261f0d0
--- /dev/null
+++ b/model.py
@@ -0,0 +1,117 @@
+import math
+import torch
+from torch import nn
+
+
+class Generator(nn.Module):
+ def __init__(self, scale_factor):
+ upsample_block_num = int(math.log(scale_factor, 2))
+
+ super(Generator, self).__init__()
+ self.block1 = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=9, padding=4),
+ nn.PReLU()
+ )
+ self.block2 = ResidualBlock(64)
+ self.block3 = ResidualBlock(64)
+ self.block4 = ResidualBlock(64)
+ self.block5 = ResidualBlock(64)
+ self.block6 = ResidualBlock(64)
+ self.block7 = nn.Sequential(
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
+ nn.BatchNorm2d(64)
+ )
+ block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
+ block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
+ self.block8 = nn.Sequential(*block8)
+
+ def forward(self, x):
+ block1 = self.block1(x)
+ block2 = self.block2(block1)
+ block3 = self.block3(block2)
+ block4 = self.block4(block3)
+ block5 = self.block5(block4)
+ block6 = self.block6(block5)
+ block7 = self.block7(block6)
+ block8 = self.block8(block1 + block7)
+
+ return (torch.tanh(block8) + 1) / 2
+
+
+class Discriminator(nn.Module):
+ def __init__(self):
+ super(Discriminator, self).__init__()
+ self.net = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2),
+
+ nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(64),
+ nn.LeakyReLU(0.2),
+
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
+ nn.BatchNorm2d(128),
+ nn.LeakyReLU(0.2),
+
+ nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(128),
+ nn.LeakyReLU(0.2),
+
+ nn.Conv2d(128, 256, kernel_size=3, padding=1),
+ nn.BatchNorm2d(256),
+ nn.LeakyReLU(0.2),
+
+ nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(256),
+ nn.LeakyReLU(0.2),
+
+ nn.Conv2d(256, 512, kernel_size=3, padding=1),
+ nn.BatchNorm2d(512),
+ nn.LeakyReLU(0.2),
+
+ nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(512),
+ nn.LeakyReLU(0.2),
+
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(512, 1024, kernel_size=1),
+ nn.LeakyReLU(0.2),
+ nn.Conv2d(1024, 1, kernel_size=1)
+ )
+
+ def forward(self, x):
+ batch_size = x.size(0)
+ return torch.sigmoid(self.net(x).view(batch_size))
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, channels):
+ super(ResidualBlock, self).__init__()
+ self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+ self.bn1 = nn.BatchNorm2d(channels)
+ self.prelu = nn.PReLU()
+ self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
+ self.bn2 = nn.BatchNorm2d(channels)
+
+ def forward(self, x):
+ residual = self.conv1(x)
+ residual = self.bn1(residual)
+ residual = self.prelu(residual)
+ residual = self.conv2(residual)
+ residual = self.bn2(residual)
+
+ return x + residual
+
+
+class UpsampleBLock(nn.Module):
+ def __init__(self, in_channels, up_scale):
+ super(UpsampleBLock, self).__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
+ self.pixel_shuffle = nn.PixelShuffle(up_scale)
+ self.prelu = nn.PReLU()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.pixel_shuffle(x)
+ x = self.prelu(x)
+ return x
diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py
new file mode 100755
index 0000000..1ab54e9
--- /dev/null
+++ b/pytorch_ssim/__init__.py
@@ -0,0 +1,77 @@
+from math import exp
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+ return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+ return window
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+ C1 = 0.01 ** 2
+ C2 = 0.03 ** 2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
+
+class SSIM(torch.nn.Module):
+ def __init__(self, window_size=11, size_average=True):
+ super(SSIM, self).__init__()
+ self.window_size = window_size
+ self.size_average = size_average
+ self.channel = 1
+ self.window = create_window(window_size, self.channel)
+
+ def forward(self, img1, img2):
+ (_, channel, _, _) = img1.size()
+
+ if channel == self.channel and self.window.data.type() == img1.data.type():
+ window = self.window
+ else:
+ window = create_window(self.window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ self.window = window
+ self.channel = channel
+
+ return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+ (_, channel, _, _) = img1.size()
+ window = create_window(window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ return _ssim(img1, img2, window, window_size, channel, size_average)
diff --git a/results/README.md b/results/README.md
new file mode 100644
index 0000000..81902f1
--- /dev/null
+++ b/results/README.md
@@ -0,0 +1,2 @@
+Please store your training checkpoints or results here
+请在此处存储 checkpoints 和结果文件
\ No newline at end of file
diff --git a/results/tb_results/README.md b/results/tb_results/README.md
new file mode 100644
index 0000000..23659ee
--- /dev/null
+++ b/results/tb_results/README.md
@@ -0,0 +1,2 @@
+Please store your tensorboard results here
+请在此处存储 tensorboard 结果
diff --git a/statistics/.gitkeep b/statistics/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/test_benchmark.py b/test_benchmark.py
new file mode 100755
index 0000000..0bd9870
--- /dev/null
+++ b/test_benchmark.py
@@ -0,0 +1,80 @@
+import argparse
+import os
+from math import log10
+
+import numpy as np
+import pandas as pd
+import torch
+import torchvision.utils as utils
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+import pytorch_ssim
+from data_utils import TestDatasetFromFolder, display_transform
+from model import Generator
+
+parser = argparse.ArgumentParser(description='Test Benchmark Datasets')
+parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
+parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
+opt = parser.parse_args()
+
+UPSCALE_FACTOR = opt.upscale_factor
+MODEL_NAME = opt.model_name
+
+results = {'Set5': {'psnr': [], 'ssim': []}, 'Set14': {'psnr': [], 'ssim': []}, 'BSD100': {'psnr': [], 'ssim': []},
+ 'Urban100': {'psnr': [], 'ssim': []}, 'SunHays80': {'psnr': [], 'ssim': []}}
+
+model = Generator(UPSCALE_FACTOR).eval()
+if torch.cuda.is_available():
+ model = model.cuda()
+model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
+
+test_set = TestDatasetFromFolder('data/test', upscale_factor=UPSCALE_FACTOR)
+test_loader = DataLoader(dataset=test_set, num_workers=4, batch_size=1, shuffle=False)
+test_bar = tqdm(test_loader, desc='[testing benchmark datasets]')
+
+out_path = 'benchmark_results/SRF_' + str(UPSCALE_FACTOR) + '/'
+if not os.path.exists(out_path):
+ os.makedirs(out_path)
+
+for image_name, lr_image, hr_restore_img, hr_image in test_bar:
+ image_name = image_name[0]
+ lr_image = Variable(lr_image, volatile=True)
+ hr_image = Variable(hr_image, volatile=True)
+ if torch.cuda.is_available():
+ lr_image = lr_image.cuda()
+ hr_image = hr_image.cuda()
+
+ sr_image = model(lr_image)
+ mse = ((hr_image - sr_image) ** 2).data.mean()
+ psnr = 10 * log10(1 / mse)
+ ssim = pytorch_ssim.ssim(sr_image, hr_image).data[0]
+
+ test_images = torch.stack(
+ [display_transform()(hr_restore_img.squeeze(0)), display_transform()(hr_image.data.cpu().squeeze(0)),
+ display_transform()(sr_image.data.cpu().squeeze(0))])
+ image = utils.make_grid(test_images, nrow=3, padding=5)
+ utils.save_image(image, out_path + image_name.split('.')[0] + '_psnr_%.4f_ssim_%.4f.' % (psnr, ssim) +
+ image_name.split('.')[-1], padding=5)
+
+ # save psnr\ssim
+ results[image_name.split('_')[0]]['psnr'].append(psnr)
+ results[image_name.split('_')[0]]['ssim'].append(ssim)
+
+out_path = 'statistics/'
+saved_results = {'psnr': [], 'ssim': []}
+for item in results.values():
+ psnr = np.array(item['psnr'])
+ ssim = np.array(item['ssim'])
+ if (len(psnr) == 0) or (len(ssim) == 0):
+ psnr = 'No data'
+ ssim = 'No data'
+ else:
+ psnr = psnr.mean()
+ ssim = ssim.mean()
+ saved_results['psnr'].append(psnr)
+ saved_results['ssim'].append(ssim)
+
+data_frame = pd.DataFrame(saved_results, results.keys())
+data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_test_results.csv', index_label='DataSet')
diff --git a/test_image.py b/test_image.py
new file mode 100755
index 0000000..ef7df58
--- /dev/null
+++ b/test_image.py
@@ -0,0 +1,40 @@
+import argparse
+import time
+
+import torch
+from PIL import Image
+from torch.autograd import Variable
+from torchvision.transforms import ToTensor, ToPILImage
+
+from model import Generator
+
+parser = argparse.ArgumentParser(description='Test Single Image')
+parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
+parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
+parser.add_argument('--image_name', type=str, help='test low resolution image name')
+parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
+opt = parser.parse_args()
+
+UPSCALE_FACTOR = opt.upscale_factor
+TEST_MODE = True if opt.test_mode == 'GPU' else False
+IMAGE_NAME = opt.image_name
+MODEL_NAME = opt.model_name
+
+model = Generator(UPSCALE_FACTOR).eval()
+if TEST_MODE:
+ model.cuda()
+ model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
+else:
+ model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))
+
+image = Image.open(IMAGE_NAME)
+image = Variable(ToTensor()(image), volatile=True).unsqueeze(0)
+if TEST_MODE:
+ image = image.cuda()
+
+start = time.clock()
+out = model(image)
+elapsed = (time.clock() - start)
+print('cost' + str(elapsed) + 's')
+out_img = ToPILImage()(out[0].data.cpu())
+out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME)
diff --git a/test_video.py b/test_video.py
new file mode 100755
index 0000000..38a3d90
--- /dev/null
+++ b/test_video.py
@@ -0,0 +1,85 @@
+import argparse
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+from torch.autograd import Variable
+from torchvision.transforms import ToTensor, ToPILImage
+from tqdm import tqdm
+
+from model import Generator
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Test Single Video')
+ parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
+ parser.add_argument('--video_name', type=str, help='test low resolution video name')
+ parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
+ opt = parser.parse_args()
+
+ UPSCALE_FACTOR = opt.upscale_factor
+ VIDEO_NAME = opt.video_name
+ MODEL_NAME = opt.model_name
+
+ model = Generator(UPSCALE_FACTOR).eval()
+ if torch.cuda.is_available():
+ model = model.cuda()
+ # for cpu
+ # model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))
+ model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
+
+ videoCapture = cv2.VideoCapture(VIDEO_NAME)
+ fps = videoCapture.get(cv2.CAP_PROP_FPS)
+ frame_numbers = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
+ sr_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR),
+ int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR)
+ compared_video_size = (int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10),
+ int(videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)) * UPSCALE_FACTOR + 10 + int(
+ int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR * 2 + 10) / int(
+ 10 * int(int(
+ videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 + 1)) * int(
+ int(videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) * UPSCALE_FACTOR) // 5 - 9)))
+ output_sr_name = 'out_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi'
+ output_compared_name = 'compare_srf_' + str(UPSCALE_FACTOR) + '_' + VIDEO_NAME.split('.')[0] + '.avi'
+ sr_video_writer = cv2.VideoWriter(output_sr_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps, sr_video_size)
+ compared_video_writer = cv2.VideoWriter(output_compared_name, cv2.VideoWriter_fourcc('M', 'P', 'E', 'G'), fps,
+ compared_video_size)
+ # read frame
+ success, frame = videoCapture.read()
+ test_bar = tqdm(range(int(frame_numbers)), desc='[processing video and saving result videos]')
+ for index in test_bar:
+ if success:
+ image = Variable(ToTensor()(frame), volatile=True).unsqueeze(0)
+ if torch.cuda.is_available():
+ image = image.cuda()
+
+ out = model(image)
+ out = out.cpu()
+ out_img = out.data[0].numpy()
+ out_img *= 255.0
+ out_img = (np.uint8(out_img)).transpose((1, 2, 0))
+ # save sr video
+ sr_video_writer.write(out_img)
+
+ # make compared video and crop shot of left top\right top\center\left bottom\right bottom
+ out_img = ToPILImage()(out_img)
+ crop_out_imgs = transforms.FiveCrop(size=out_img.width // 5 - 9)(out_img)
+ crop_out_imgs = [np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(img)) for img in crop_out_imgs]
+ out_img = transforms.Pad(padding=(5, 0, 0, 5))(out_img)
+ compared_img = transforms.Resize(size=(sr_video_size[1], sr_video_size[0]), interpolation=Image.BICUBIC)(
+ ToPILImage()(frame))
+ crop_compared_imgs = transforms.FiveCrop(size=compared_img.width // 5 - 9)(compared_img)
+ crop_compared_imgs = [np.asarray(transforms.Pad(padding=(0, 5, 10, 0))(img)) for img in crop_compared_imgs]
+ compared_img = transforms.Pad(padding=(0, 0, 5, 5))(compared_img)
+ # concatenate all the pictures to one single picture
+ top_image = np.concatenate((np.asarray(compared_img), np.asarray(out_img)), axis=1)
+ bottom_image = np.concatenate(crop_compared_imgs + crop_out_imgs, axis=1)
+ bottom_image = np.asarray(transforms.Resize(
+ size=(int(top_image.shape[1] / bottom_image.shape[1] * bottom_image.shape[0]), top_image.shape[1]))(
+ ToPILImage()(bottom_image)))
+ final_image = np.concatenate((top_image, bottom_image))
+ # save compared video
+ compared_video_writer.write(final_image)
+ # next frame
+ success, frame = videoCapture.read()
diff --git a/train.py b/train.py
new file mode 100755
index 0000000..a2e21b7
--- /dev/null
+++ b/train.py
@@ -0,0 +1,166 @@
+import argparse
+import os
+from math import log10
+
+import pandas as pd
+import torch.optim as optim
+import torch.utils.data
+import torchvision.utils as utils
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+import pytorch_ssim
+from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
+from loss import GeneratorLoss
+from model import Generator, Discriminator
+
+parser = argparse.ArgumentParser(description='Train Super Resolution Models')
+parser.add_argument('--crop_size', default=88, type=int, help='training images crop size')
+parser.add_argument('--upscale_factor', default=4, type=int, choices=[2, 4, 8],
+ help='super resolution upscale factor')
+parser.add_argument('--num_epochs', default=100, type=int, help='train epoch number')
+
+
+if __name__ == '__main__':
+ opt = parser.parse_args()
+
+ CROP_SIZE = opt.crop_size
+ UPSCALE_FACTOR = opt.upscale_factor
+ NUM_EPOCHS = opt.num_epochs
+
+ train_set = TrainDatasetFromFolder('data/DIV2K_train_HR', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
+ val_set = ValDatasetFromFolder('data/DIV2K_valid_HR', upscale_factor=UPSCALE_FACTOR)
+ train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
+ val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)
+
+ netG = Generator(UPSCALE_FACTOR)
+ print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
+ netD = Discriminator()
+ print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
+
+ generator_criterion = GeneratorLoss()
+
+ if torch.cuda.is_available():
+ netG.cuda()
+ netD.cuda()
+ generator_criterion.cuda()
+
+ optimizerG = optim.Adam(netG.parameters())
+ optimizerD = optim.Adam(netD.parameters())
+
+ results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}
+
+ for epoch in range(1, NUM_EPOCHS + 1):
+ train_bar = tqdm(train_loader)
+ running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}
+
+ netG.train()
+ netD.train()
+ for data, target in train_bar:
+ g_update_first = True
+ batch_size = data.size(0)
+ running_results['batch_sizes'] += batch_size
+
+ ############################
+ # (1) Update D network: maximize D(x)-1-D(G(z))
+ ###########################
+ real_img = Variable(target)
+ if torch.cuda.is_available():
+ real_img = real_img.cuda()
+ z = Variable(data)
+ if torch.cuda.is_available():
+ z = z.cuda()
+ fake_img = netG(z)
+
+ netD.zero_grad()
+ real_out = netD(real_img).mean()
+ fake_out = netD(fake_img).mean()
+ d_loss = 1 - real_out + fake_out
+ d_loss.backward(retain_graph=True)
+ optimizerD.step()
+
+ ############################
+ # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
+ ###########################
+ netG.zero_grad()
+ g_loss = generator_criterion(fake_out, fake_img, real_img)
+ g_loss.backward()
+
+ fake_img = netG(z)
+ fake_out = netD(fake_img).mean()
+
+
+ optimizerG.step()
+
+ # loss for current batch before optimization
+ running_results['g_loss'] += g_loss.item() * batch_size
+ running_results['d_loss'] += d_loss.item() * batch_size
+ running_results['d_score'] += real_out.item() * batch_size
+ running_results['g_score'] += fake_out.item() * batch_size
+
+ train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
+ epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
+ running_results['g_loss'] / running_results['batch_sizes'],
+ running_results['d_score'] / running_results['batch_sizes'],
+ running_results['g_score'] / running_results['batch_sizes']))
+
+ netG.eval()
+ out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
+ if not os.path.exists(out_path):
+ os.makedirs(out_path)
+
+ with torch.no_grad():
+ val_bar = tqdm(val_loader)
+ valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
+ val_images = []
+ for val_lr, val_hr_restore, val_hr in val_bar:
+ batch_size = val_lr.size(0)
+ valing_results['batch_sizes'] += batch_size
+ lr = val_lr
+ hr = val_hr
+ if torch.cuda.is_available():
+ lr = lr.cuda()
+ hr = hr.cuda()
+ sr = netG(lr)
+
+ batch_mse = ((sr - hr) ** 2).data.mean()
+ valing_results['mse'] += batch_mse * batch_size
+ batch_ssim = pytorch_ssim.ssim(sr, hr).item()
+ valing_results['ssims'] += batch_ssim * batch_size
+ valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes']))
+ valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
+ val_bar.set_description(
+ desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
+ valing_results['psnr'], valing_results['ssim']))
+
+ val_images.extend(
+ [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
+ display_transform()(sr.data.cpu().squeeze(0))])
+ val_images = torch.stack(val_images)
+ val_images = torch.chunk(val_images, val_images.size(0) // 15)
+ val_save_bar = tqdm(val_images, desc='[saving training results]')
+ index = 1
+ for image in val_save_bar:
+ image = utils.make_grid(image, nrow=3, padding=5)
+ utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
+ index += 1
+
+ # save model parameters
+ torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
+ torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
+ # save loss\scores\psnr\ssim
+ results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
+ results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
+ results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
+ results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
+ results['psnr'].append(valing_results['psnr'])
+ results['ssim'].append(valing_results['ssim'])
+
+ if epoch % 10 == 0 and epoch != 0:
+ out_path = 'statistics/'
+ data_frame = pd.DataFrame(
+ data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
+ 'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
+ index=range(1, epoch + 1))
+ data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
diff --git a/training_results/.gitkeep b/training_results/.gitkeep
new file mode 100644
index 0000000..e69de29