0013db4
Initial Commit lzfxxx 6 years ago
65 changed file(s) with 5101 addition(s) and 0 deletion(s).
0 .idea/
1 *.pyc
2 *.swp
3 .DS_Store
4 /.localenv/
5 /datasets/
6 /.ipynb_checkpoints/
0 Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
1 All rights reserved.
2
3 Redistribution and use in source and binary forms, with or without
4 modification, are permitted provided that the following conditions are met:
5
6 * Redistributions of source code must retain the above copyright notice, this
7 list of conditions and the following disclaimer.
8
9 * Redistributions in binary form must reproduce the above copyright notice,
10 this list of conditions and the following disclaimer in the documentation
11 and/or other materials provided with the distribution.
12
13 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
16 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
17 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
18 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
19 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
20 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
21 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23
24
25 --------------------------- LICENSE FOR pix2pix --------------------------------
26 BSD License
27
28 For pix2pix software
29 Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
30 All rights reserved.
31
32 Redistribution and use in source and binary forms, with or without
33 modification, are permitted provided that the following conditions are met:
34
35 * Redistributions of source code must retain the above copyright notice, this
36 list of conditions and the following disclaimer.
37
38 * Redistributions in binary form must reproduce the above copyright notice,
39 this list of conditions and the following disclaimer in the documentation
40 and/or other materials provided with the distribution.
41
42 ----------------------------- LICENSE FOR DCGAN --------------------------------
43 BSD License
44
45 For dcgan.torch software
46
47 Copyright (c) 2015, Facebook, Inc. All rights reserved.
48
49 Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
50
51 Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
52
53 Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
54
55 Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
56
57 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
0 <img src='imgs/horse2zebra.gif' align="right" width=384>
1
2 <br><br><br>
3
4 # CycleGAN and pix2pix in PyTorch
5
6 We provide PyTorch implementations for both unpaired and paired image-to-image translation.
7
8 The code was written by [Jun-Yan Zhu](https://github.com/junyanz) and [Taesung Park](https://github.com/taesung), and supported by [Tongzhou Wang](https://ssnl.github.io/).
9
10 This PyTorch implementation produces results comparable to or better than our original Torch software. If you would like to reproduce the same results as in the papers, check out the original [CycleGAN Torch](https://github.com/junyanz/CycleGAN) and [pix2pix Torch](https://github.com/phillipi/pix2pix) code.
11
12 **Note**: The current software works well with PyTorch 0.4.1+. Check out the older [branch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/tree/pytorch0.3.1) that supports PyTorch 0.1-0.3.
13
14 You may find useful information in [training/test tips](docs/tips.md) and [frequently asked questions](docs/qa.md). To implement custom models and datasets, check out our [templates](#custom-model-and-dataset). To help users better understand and adapt our codebase, we provide an [overview](docs/overview.md) of the code structure of this repository.
15
16 **CycleGAN: [Project](https://junyanz.github.io/CycleGAN/) | [Paper](https://arxiv.org/pdf/1703.10593.pdf) | [Torch](https://github.com/junyanz/CycleGAN)**
17 <img src="https://junyanz.github.io/CycleGAN/images/teaser_high_res.jpg" width="800"/>
18
19
20 **Pix2pix: [Project](https://phillipi.github.io/pix2pix/) | [Paper](https://arxiv.org/pdf/1611.07004.pdf) | [Torch](https://github.com/phillipi/pix2pix)**
21
22 <img src="https://phillipi.github.io/pix2pix/images/teaser_v3.png" width="800px"/>
23
24
25 **[EdgesCats Demo](https://affinelayer.com/pixsrv/) | [pix2pix-tensorflow](https://github.com/affinelayer/pix2pix-tensorflow) | by [Christopher Hesse](https://twitter.com/christophrhesse)**
26
27 <img src='imgs/edges2cats.jpg' width="400px"/>
28
29 If you use this code for your research, please cite:
30
31 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.<br>
32 [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](https://people.eecs.berkeley.edu/~isola/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros). In ICCV 2017. (* equal contributions) [[Bibtex]](https://junyanz.github.io/CycleGAN/CycleGAN.txt)
33
34
35 Image-to-Image Translation with Conditional Adversarial Networks.<br>
36 [Phillip Isola](https://people.eecs.berkeley.edu/~isola), [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros). In CVPR 2017. [[Bibtex]](http://people.csail.mit.edu/junyanz/projects/pix2pix/pix2pix.bib)
37
38 ## Talks and Course
39 pix2pix slides: [keynote](http://efrosgans.eecs.berkeley.edu/CVPR18_slides/pix2pix.key) | [pdf](http://efrosgans.eecs.berkeley.edu/CVPR18_slides/pix2pix.pdf),
40 CycleGAN slides: [pptx](http://efrosgans.eecs.berkeley.edu/CVPR18_slides/CycleGAN.pptx) | [pdf](http://efrosgans.eecs.berkeley.edu/CVPR18_slides/CycleGAN.pdf)
41
42 CycleGAN course assignment [code](http://www.cs.toronto.edu/~rgrosse/courses/csc321_2018/assignments/a4-code.zip) and [handout](http://www.cs.toronto.edu/~rgrosse/courses/csc321_2018/assignments/a4-handout.pdf) designed by Prof. [Roger Grosse](http://www.cs.toronto.edu/~rgrosse/) for [CSC321](http://www.cs.toronto.edu/~rgrosse/courses/csc321_2018/) "Intro to Neural Networks and Machine Learning" at University of Toronto. Please contact the instructor if you would like to adopt it in your course.
43
44 ## Other implementations
45 ### CycleGAN
46 <p><a href="https://github.com/leehomyc/cyclegan-1"> [Tensorflow]</a> (by Harry Yang),
47 <a href="https://github.com/architrathore/CycleGAN/">[Tensorflow]</a> (by Archit Rathore),
48 <a href="https://github.com/vanhuyz/CycleGAN-TensorFlow">[Tensorflow]</a> (by Van Huy),
49 <a href="https://github.com/XHUJOY/CycleGAN-tensorflow">[Tensorflow]</a> (by Xiaowei Hu),
50 <a href="https://github.com/LynnHo/CycleGAN-Tensorflow-Simple"> [Tensorflow-simple]</a> (by Zhenliang He),
51 <a href="https://github.com/luoxier/CycleGAN_Tensorlayer"> [TensorLayer]</a> (by luoxier),
52 <a href="https://github.com/Aixile/chainer-cyclegan">[Chainer]</a> (by Yanghua Jin),
53 <a href="https://github.com/yunjey/mnist-svhn-transfer">[Minimal PyTorch]</a> (by yunjey),
54 <a href="https://github.com/Ldpe2G/DeepLearningForFun/tree/master/Mxnet-Scala/CycleGAN">[Mxnet]</a> (by Ldpe2G),
55 <a href="https://github.com/tjwei/GANotebooks">[lasagne/Keras]</a> (by tjwei),
56 <a href="https://github.com/simontomaskarlsson/CycleGAN-Keras">[Keras]</a> (by Simon Karlsson)
57 </p>
59
60 ### pix2pix
61 <p><a href="https://github.com/affinelayer/pix2pix-tensorflow"> [Tensorflow]</a> (by Christopher Hesse),
62 <a href="https://github.com/Eyyub/tensorflow-pix2pix">[Tensorflow]</a> (by Eyyüb Sariu),
63 <a href="https://github.com/datitran/face2face-demo"> [Tensorflow (face2face)]</a> (by Dat Tran),
64 <a href="https://github.com/awjuliani/Pix2Pix-Film"> [Tensorflow (film)]</a> (by Arthur Juliani),
65 <a href="https://github.com/kaonashi-tyc/zi2zi">[Tensorflow (zi2zi)]</a> (by Yuchen Tian),
66 <a href="https://github.com/pfnet-research/chainer-pix2pix">[Chainer]</a> (by mattya),
67 <a href="https://github.com/tjwei/GANotebooks">[tf/torch/keras/lasagne]</a> (by tjwei),
68 <a href="https://github.com/taey16/pix2pixBEGAN.pytorch">[Pytorch]</a> (by taey16)
69 </p>
71
72 ## Prerequisites
73 - Linux or macOS
74 - Python 3
75 - CPU or NVIDIA GPU + CUDA CuDNN
76
77 ## Getting Started
78 ### Installation
79
80 - Clone this repo:
81 ```bash
82 git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
83 cd pytorch-CycleGAN-and-pix2pix
84 ```
85
86 - Install [PyTorch](http://pytorch.org) 0.4+ and other dependencies (e.g., torchvision, [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate)).
87 - For pip users, please type the command `pip install -r requirements.txt`.
88 - For Conda users, we provide an installation script `./scripts/conda_deps.sh`. Alternatively, you can create a new Conda environment using `conda env create -f environment.yml`.
89 - For Docker users, we provide the pre-built Docker image and Dockerfile. Please refer to our [Docker](docs/docker.md) page.
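For reference, both the pip and Conda routes install the same small set of dependencies. A requirements file along these lines should work (the version pins here are illustrative; the authoritative list is the `requirements.txt` shipped in the repo):

```
torch>=0.4.1
torchvision
dominate
visdom
```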
90
91 ### CycleGAN train/test
92 - Download a CycleGAN dataset (e.g. maps):
93 ```bash
94 bash ./datasets/download_cyclegan_dataset.sh maps
95 ```
96 - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097.
97 - Train a model:
98 ```bash
99 #!./scripts/train_cyclegan.sh
100 python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
101 ```
102 To see more intermediate results, check out `./checkpoints/maps_cyclegan/web/index.html`.
103 - Test the model:
104 ```bash
105 #!./scripts/test_cyclegan.sh
106 python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
107 ```
108 - The test results will be saved to an HTML file here: `./results/maps_cyclegan/latest_test/index.html`.
109
110 ### pix2pix train/test
111 - Download a pix2pix dataset (e.g., [facades](http://cmp.felk.cvut.cz/~tylecr1/facade/)):
112 ```bash
113 bash ./datasets/download_pix2pix_dataset.sh facades
114 ```
115 - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097.
116 - Train a model:
117 ```bash
118 #!./scripts/train_pix2pix.sh
119 python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
120 ```
121 To see more intermediate results, check out `./checkpoints/facades_pix2pix/web/index.html`.
122
123 - Test the model (`bash ./scripts/test_pix2pix.sh`):
124 ```bash
125 #!./scripts/test_pix2pix.sh
126 python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
127 ```
128 - The test results will be saved to an HTML file here: `./results/facades_pix2pix/test_latest/index.html`. You can find more scripts in the `scripts` directory.
129 - To train and test pix2pix-based colorization models, please add `--model colorization` and `--dataset_mode colorization`. See our training [tips](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md#notes-on-colorization) for more details.
130
131 ### Apply a pre-trained model (CycleGAN)
132 - You can download a pretrained model (e.g. horse2zebra) with the following script:
133 ```bash
134 bash ./scripts/download_cyclegan_model.sh horse2zebra
135 ```
136 - The pretrained model is saved at `./checkpoints/{name}_pretrained/latest_net_G.pth`. Check [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/scripts/download_cyclegan_model.sh#L3) for all the available CycleGAN models.
137 - To test the model, you also need to download the horse2zebra dataset:
138 ```bash
139 bash ./datasets/download_cyclegan_dataset.sh horse2zebra
140 ```
141
142 - Then generate the results using
143 ```bash
144 python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
145 ```
146 - The option `--model test` is used for generating CycleGAN results for one direction only. This option automatically sets `--dataset_mode single`, which loads images from only one set. In contrast, `--model cycle_gan` requires loading and generating results in both directions, which is sometimes unnecessary. The results will be saved at `./results/`. Use `--results_dir {directory_path_to_save_result}` to specify the results directory.
147
148 - For your own experiments, you might want to specify `--netG`, `--norm`, and `--no_dropout` to match the generator architecture of the trained model.
149
150 ### Apply a pre-trained model (pix2pix)
151 Download a pre-trained model with `./scripts/download_pix2pix_model.sh`.
152
153 - Check [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/scripts/download_pix2pix_model.sh#L3) for all the available pix2pix models. For example, if you would like to download the label2photo model trained on the Facades dataset:
154 ```bash
155 bash ./scripts/download_pix2pix_model.sh facades_label2photo
156 ```
157 - Download the pix2pix facades dataset:
158 ```bash
159 bash ./datasets/download_pix2pix_dataset.sh facades
160 ```
161 - Then generate the results using
162 ```bash
163 python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_label2photo_pretrained
164 ```
165 - Note that we specified `--direction BtoA` because the Facades dataset's A-to-B direction is photos to labels.
166
167 - If you would like to apply a pre-trained model to a collection of input images (rather than image pairs), please use the `--model test` option. See `./scripts/test_single.sh` for how to apply a model to Facades label maps (stored in the directory `facades/testB`).
168
169 - See the list of currently available models in `./scripts/download_pix2pix_model.sh`.
170
171 ## [Docker](docs/docker.md)
172 We provide the pre-built Docker image and Dockerfile that can run this code repo. See [docker](docs/docker.md).
173
174 ## [Datasets](docs/datasets.md)
175 Download pix2pix/CycleGAN datasets and create your own datasets.
176
177 ## [Training/Test Tips](docs/tips.md)
178 Best practice for training and testing your models.
179
180 ## [Frequently Asked Questions](docs/qa.md)
181 Before you post a new question, please first look at the Q&A above and the existing GitHub issues.
182
183 ## Custom Model and Dataset
184 If you plan to implement custom models and dataset for your new applications, we provide a dataset [template](data/template_dataset.py) and a model [template](models/template_model.py) as a starting point.
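As a rough sketch of the dataset contract those templates follow (here `DummyDataset` and the in-line `BaseDataset` stub are hypothetical stand-ins for illustration; the real base class lives in `data/base_dataset.py` and subclasses `torch.utils.data.Dataset`):

```python
# Minimal sketch of the four-method dataset contract; BaseDataset is
# stubbed so the example is self-contained and runs without torch.
class BaseDataset:  # stand-in for data.base_dataset.BaseDataset
    def __init__(self, opt):
        self.opt = opt
        self.root = opt['dataroot']

class DummyDataset(BaseDataset):
    """Would live in data/dummy_dataset.py and be enabled via '--dataset_mode dummy'."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser  # no dataset-specific options for this toy example

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        # pretend the data root holds four images
        self.paths = ['%s/img_%d.jpg' % (self.root, i) for i in range(4)]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # a real dataset would load and transform an image here
        return {'A': None, 'A_paths': self.paths[index]}

ds = DummyDataset({'dataroot': './datasets/dummy'})
print(len(ds))           # 4
print(ds[1]['A_paths'])  # ./datasets/dummy/img_1.jpg
```

Dropping such a class into `data/dummy_dataset.py` is what makes it discoverable by the loader in `data/__init__.py`.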
185
186 ## [Code structure](docs/overview.md)
187 To help users better understand and use our code, we briefly overview the functionality and implementation of each package and each module.
188
189 ## Pull Request
190 You are always welcome to contribute to this repository by sending a [pull request](https://help.github.com/articles/about-pull-requests/).
191 Please run `flake8 --ignore E501 .` and `python ./scripts/test_before_push.py` before you commit the code. Please also update the code structure [overview](docs/overview.md) accordingly if you add or remove files.
192
193 ## Citation
194 If you use this code for your research, please cite our papers.
195 ```
196 @inproceedings{CycleGAN2017,
197 title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
198 author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
199 booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on},
200 year={2017}
201 }
202
203
204 @inproceedings{isola2017image,
205 title={Image-to-Image Translation with Conditional Adversarial Networks},
206 author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
207 booktitle={Computer Vision and Pattern Recognition (CVPR), 2017 IEEE Conference on},
208 year={2017}
209 }
210 ```
211
212
213
214 ## Related Projects
215 **[CycleGAN-Torch](https://github.com/junyanz/CycleGAN) |
216 [pix2pix-Torch](https://github.com/phillipi/pix2pix) | [pix2pixHD](https://github.com/NVIDIA/pix2pixHD)|
217 [BicycleGAN](https://github.com/junyanz/BicycleGAN) | [vid2vid](https://tcwang0509.github.io/vid2vid/) | [SPADE/GauGAN](https://github.com/NVlabs/SPADE)**<br>
218 **[iGAN](https://github.com/junyanz/iGAN) | [GAN Dissection](https://github.com/CSAILVision/GANDissect) | [GAN Paint](http://ganpaint.io/)**
219
220 ## Cat Paper Collection
221 If you love cats, and love reading cool graphics, vision, and learning papers, please check out the Cat Paper [Collection](https://github.com/junyanz/CatPapers).
222
223 ## Acknowledgments
224 Our code is inspired by [pytorch-DCGAN](https://github.com/pytorch/examples/tree/master/dcgan).
0 ## Introduction
1
2 Describe the project's function, usage scenarios, input/output parameters, and other relevant information here.
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "## 1. Project Introduction\n",
7 "\n",
8 " - A project is a program made of modules with a specific function. It can serve users' needs directly, e.g. the [classical poetry generator](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1) or [style transfer](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1).\n",
9 " - While developing a project you can import datasets, or insert [modules](https://momodel.cn/modules) and code blocks directly via the `<+>` in the toolbar above each cell.\n",
10 " - You can [deploy](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89) a finished project; once it is deployed and released as an official version, it appears on the \"Projects\" page, where users can run it online or call it through an API.\n",
11 " \n",
12 " - Project directory structure:\n",
13 " \n",
14 " - ```results``` *----- where result files are stored (if you run a job, make sure its output is written to this directory)*\n",
15 " - ```_OVERVIEW.md``` *----- project introduction*\n",
16 " - ```_README.md``` *----- documentation*\n",
17 " - ```app_spec.yml``` *----- defines the project's inputs and outputs, for deployment*\n",
18 " - ```coding_here.ipynb``` *----- write and run your code*"
19 ]
20 },
21 {
22 "cell_type": "markdown",
23 "metadata": {},
24 "source": [
25 " \n",
26 "## 2. Development Environment\n",
27 "\n",
28 "The Notebook page you are on is an online IDE-like programming environment with embedded JupyterLab; while developing, you can use the API docs on the right side of the page for quick reference. The Notebook offers the following main features:\n",
29 "\n",
30 "- [Using datasets, modules and code-block resources](https://momodel.cn/docs/#/zh-cn/%E5%A6%82%E4%BD%95%E5%AF%BC%E5%85%A5%E5%B9%B6%E4%BD%BF%E7%94%A8%E6%A8%A1%E5%9D%97%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86)\n",
31 "- [Collaborative coding](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8Mo%E8%BF%90%E8%A1%8C%E4%BD%A0%E7%9A%84%E7%AC%AC%E4%B8%80%E6%AE%B5%E4%BB%A3%E7%A0%81?id=_7-%e4%bd%a0%e5%8f%af%e4%bb%a5%e9%82%80%e8%af%b7%e5%a5%bd%e5%8f%8b%e8%bf%9b%e8%a1%8c%e5%8d%8f%e4%bd%9c)\n",
32 "- [Training machine learning models on GPU resources](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8GPU%E6%88%96CPU%E8%B5%84%E6%BA%90%E4%B8%8A%E8%AE%AD%E7%BB%83%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E6%A8%A1%E5%9E%8B)\n",
33 "- [Simple deployment](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89)\n",
34 "\n",
35 "Give it a try! Click the new-file icon in the left toolbar to choose the file type you need.\n",
36 "\n",
37 "<img src='http://ww4.sinaimg.cn/large/006tNc79gy1g61agfcv23j31c30u0789.jpg' width=100% height=100%>\n",
38 " \n",
39 "\n",
40 "\n",
41 "Both the left and right toolbars can be collapsed as needed.\n",
42 "<img src='http://ww4.sinaimg.cn/large/006tNc79gy1g61aw7bj3vg31hc0u0u0x.gif' width=100% height=100%>"
43 ]
44 },
45 {
46 "cell_type": "markdown",
47 "metadata": {},
48 "source": [
49 "## 3. Shortcuts and Code Completion\n",
50 "Mo Notebook fully adopts Jupyter Notebook's native keyboard shortcuts and supports `tab` code completion.\n",
51 "\n",
52 "Run a cell: `shift` + `enter` or `shift` + `return`"
53 ]
54 },
55 {
56 "cell_type": "markdown",
57 "metadata": {},
58 "source": [
59 "## 4. Common Commands\n",
60 "\n",
61 "- Unzip an uploaded file\n",
62 "\n",
63 "Enter and run the following command in a cell:\n",
64 "```!unzip -o file_name.zip```\n",
65 "\n",
66 "- List all installed packages\n",
67 "\n",
68 "`!pip list --format=columns`\n",
69 "\n",
70 "- Check whether a package is installed\n",
71 "\n",
72 "`!pip show package_name`\n",
73 "\n",
74 "- Install a missing package\n",
75 "\n",
76 "`!pip install package_name`\n",
77 "\n",
78 "- Upgrade an installed package\n",
79 "\n",
80 "`!pip install package_name --upgrade`\n",
81 "\n",
82 "\n",
83 "- Use a package\n",
84 "\n",
85 "`import package_name`\n",
86 "\n",
87 "- List files and directories in the current directory\n",
88 "\n",
89 "`ls`\n",
90 "\n",
91 "- Use an imported dataset\n",
92 "\n",
93 "Imported datasets are stored under the datasets folder. Note that this folder is read-only; if you need to modify the data, run\n",
94 "\n",
95 "`!cp -R ./datasets/<imported_dataset_dir> ./<your_folder>`\n",
96 "\n",
97 "to copy it to another folder before editing. For zip files inside an imported dataset, use\n",
98 "\n",
99 "`!unzip ./datasets/<imported_dataset_dir>/<XXX.zip> -d ./<your_folder>`\n",
100 "\n",
101 "to extract them to another folder before use."
102 ]
103 },
104 {
105 "cell_type": "markdown",
106 "metadata": {},
107 "source": [
108 "## 5. Other Resources\n",
109 "- [Help documentation](https://momodel.cn/docs/#/): page overviews and answers to common questions\n",
110 "- [Platform feature tutorial](https://momodel.cn/classroom/class?id=5c5696cd1afd9458d456bf54&type=doc): an illustrated Notebook walking through the basic features and operations of the development environment\n",
111 "- [Andrew Ng's Machine Learning](https://momodel.cn/classroom/class?id=5c5696191afd94720cc94533&type=video): a classic machine learning course\n",
112 "- [Hung-yi Lee's Machine Learning](https://s.momodel.cn/classroom/class?id=5d40fdafb5113408a8dbb4a1&type=video): widely regarded as the best Chinese-language machine learning course\n",
113 "- [Machine Learning in Action](https://momodel.cn/classroom/class?id=5c680b311afd943a9f70901b&type=practice): hands-on guides to building standalone models and mastering the corresponding machine learning concepts\n",
114 "- [Python tutorial](https://momodel.cn/classroom/class?id=5d1f3ab81afd940ab7d298bf&type=notebook): an easy-to-follow Python tutorial for beginners\n",
115 "- [Module development](https://momodel.cn/modules): advanced tutorials on model training, development and deployment"
116 ]
117 }
118 ],
119 "metadata": {
120 "kernelspec": {
121 "display_name": "Python 3",
122 "language": "python",
123 "name": "python3"
124 },
125 "language_info": {
126 "codemirror_mode": {
127 "name": "ipython",
128 "version": 3
129 },
130 "file_extension": ".py",
131 "mimetype": "text/x-python",
132 "name": "python",
133 "nbconvert_exporter": "python",
134 "pygments_lexer": "ipython3",
135 "version": "3.5.2"
136 }
137 },
138 "nbformat": 4,
139 "nbformat_minor": 2
140 }
0
1 {
2 "cells": [
3 {
4 "cell_type": "code",
5 "execution_count": null,
6 "metadata": {},
7 "outputs": [],
8 "source": [
9 "print('Hello Mo!')"
10 ]
11 }
12 ],
13 "metadata": {
14 "kernelspec": {
15 "display_name": "Python 3",
16 "language": "python",
17 "name": "python3"
18 },
19 "language_info": {
20 "codemirror_mode": {
21 "name": "ipython",
22 "version": 3
23 },
24 "file_extension": ".py",
25 "mimetype": "text/x-python",
26 "name": "python",
27 "nbconvert_exporter": "python",
28 "pygments_lexer": "ipython3",
29 "version": "3.5.2"
30 }
31 },
32 "nbformat": 4,
33 "nbformat_minor": 2
34 }
35
0 """This package includes all the modules related to data loading and preprocessing
1
2 To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
3 You need to implement four functions:
4 -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
5 -- <__len__>: return the size of dataset.
6 -- <__getitem__>: get a data point from data loader.
7 -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
8
9 Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
10 See our template dataset class 'template_dataset.py' for more details.
11 """
12 import importlib
13 import torch.utils.data
14 from data.base_dataset import BaseDataset
15
16
17 def find_dataset_using_name(dataset_name):
18 """Import the module "data/[dataset_name]_dataset.py".
19
20 In the file, the class called DatasetNameDataset() will
21 be instantiated. It has to be a subclass of BaseDataset,
22 and the name matching is case-insensitive.
23 """
24 dataset_filename = "data." + dataset_name + "_dataset"
25 datasetlib = importlib.import_module(dataset_filename)
26
27 dataset = None
28 target_dataset_name = dataset_name.replace('_', '') + 'dataset'
29 for name, cls in datasetlib.__dict__.items():
30 if name.lower() == target_dataset_name.lower() \
31 and issubclass(cls, BaseDataset):
32 dataset = cls
33
34 if dataset is None:
35 raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset whose class name matches %s in lowercase." % (dataset_filename, target_dataset_name))
36
37 return dataset
38
39
40 def get_option_setter(dataset_name):
41 """Return the static method <modify_commandline_options> of the dataset class."""
42 dataset_class = find_dataset_using_name(dataset_name)
43 return dataset_class.modify_commandline_options
44
45
46 def create_dataset(opt):
47 """Create a dataset given the option.
48
49 This function wraps the class CustomDatasetDataLoader.
50 This is the main interface between this package and 'train.py'/'test.py'
51
52 Example:
53 >>> from data import create_dataset
54 >>> dataset = create_dataset(opt)
55 """
56 data_loader = CustomDatasetDataLoader(opt)
57 dataset = data_loader.load_data()
58 return dataset
59
60
61 class CustomDatasetDataLoader():
62 """Wrapper class of Dataset class that performs multi-threaded data loading"""
63
64 def __init__(self, opt):
65 """Initialize this class
66
67 Step 1: create a dataset instance given the name [dataset_mode]
68 Step 2: create a multi-threaded data loader.
69 """
70 self.opt = opt
71 dataset_class = find_dataset_using_name(opt.dataset_mode)
72 self.dataset = dataset_class(opt)
73 print("dataset [%s] was created" % type(self.dataset).__name__)
74 self.dataloader = torch.utils.data.DataLoader(
75 self.dataset,
76 batch_size=opt.batch_size,
77 shuffle=not opt.serial_batches,
78 num_workers=int(opt.num_threads))
79
80 def load_data(self):
81 return self
82
83 def __len__(self):
85 """Return the number of data points in the dataset"""
85 return min(len(self.dataset), self.opt.max_dataset_size)
86
87 def __iter__(self):
88 """Return a batch of data"""
89 for i, data in enumerate(self.dataloader):
90 if i * self.opt.batch_size >= self.opt.max_dataset_size:
91 break
92 yield data
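The name-based lookup in `find_dataset_using_name` above can be sketched in isolation. `find_class_by_name` below is a hypothetical, self-contained analogue that scans an imported module for a class whose name matches case-insensitively, demonstrated on the standard library rather than this repo's dataset modules:

```python
import importlib

def find_class_by_name(module_name, target_name):
    """Hypothetical helper mirroring find_dataset_using_name: import a module
    and return the class whose name matches target_name, case-insensitively."""
    module = importlib.import_module(module_name)
    found = None
    for name, obj in module.__dict__.items():
        # name comparison is case-insensitive, just like the dataset lookup;
        # the isinstance guard skips module-level objects that are not classes
        if name.lower() == target_name.lower() and isinstance(obj, type):
            found = obj
    if found is None:
        raise NotImplementedError(
            "No class matching %s found in %s" % (target_name, module_name))
    return found

# 'ordereddict' resolves to collections.OrderedDict despite the casing:
cls = find_class_by_name("collections", "ordereddict")
print(cls.__name__)  # OrderedDict
```

This is the same convention that lets `--dataset_mode dummy` resolve to a `DummyDataset` class in `data/dummy_dataset.py`.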
0 import os.path
1 from data.base_dataset import BaseDataset, get_params, get_transform
2 from data.image_folder import make_dataset
3 from PIL import Image
4
5
6 class AlignedDataset(BaseDataset):
7 """A dataset class for paired image dataset.
8
9 It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
10 During test time, you need to prepare a directory '/path/to/data/test'.
11 """
12
13 def __init__(self, opt):
14 """Initialize this dataset class.
15
16 Parameters:
17 opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
18 """
19 BaseDataset.__init__(self, opt)
20 self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory
21 self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths
22 assert(self.opt.load_size >= self.opt.crop_size) # crop_size should be smaller than the size of loaded image
23 self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
24 self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
25
26 def __getitem__(self, index):
27 """Return a data point and its metadata information.
28
29 Parameters:
30 index - - a random integer for data indexing
31
32 Returns a dictionary that contains A, B, A_paths and B_paths
33 A (tensor) - - an image in the input domain
34 B (tensor) - - its corresponding image in the target domain
35 A_paths (str) - - image paths
36 B_paths (str) - - image paths (same as A_paths)
37 """
38 # read an image given a random integer index
39 AB_path = self.AB_paths[index]
40 AB = Image.open(AB_path).convert('RGB')
41 # split AB image into A and B
42 w, h = AB.size
43 w2 = int(w / 2)
44 A = AB.crop((0, 0, w2, h))
45 B = AB.crop((w2, 0, w, h))
46
47 # apply the same transform to both A and B
48 transform_params = get_params(self.opt, A.size)
49 A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
50 B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))
51
52 A = A_transform(A)
53 B = B_transform(B)
54
55 return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
56
57 def __len__(self):
58 """Return the total number of images in the dataset."""
59 return len(self.AB_paths)
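The left/right split performed in `__getitem__` above can be checked standalone with a synthetic image (assuming Pillow is available; the red and blue halves stand in for an aligned {A,B} pair):

```python
from PIL import Image  # assumes Pillow is installed

# Build a synthetic 512x256 "AB" image: left half red (A), right half blue (B).
AB = Image.new('RGB', (512, 256), (255, 0, 0))
AB.paste(Image.new('RGB', (256, 256), (0, 0, 255)), (256, 0))

# The same split that AlignedDataset.__getitem__ performs:
w, h = AB.size
w2 = int(w / 2)
A = AB.crop((0, 0, w2, h))
B = AB.crop((w2, 0, w, h))

print(A.size, B.size)                           # (256, 256) (256, 256)
print(A.getpixel((0, 0)), B.getpixel((0, 0)))   # (255, 0, 0) (0, 0, 255)
```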
0 """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
1
2 It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
3 """
4 import random
5 import numpy as np
6 import torch.utils.data as data
7 from PIL import Image
8 import torchvision.transforms as transforms
9 from abc import ABC, abstractmethod
10
11
12 class BaseDataset(data.Dataset, ABC):
13 """This class is an abstract base class (ABC) for datasets.
14
15 To create a subclass, you need to implement the following four functions:
16 -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
17 -- <__len__>: return the size of dataset.
18 -- <__getitem__>: get a data point.
19 -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
20 """
21
22 def __init__(self, opt):
23 """Initialize the class; save the options in the class
24
25 Parameters:
26 opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
27 """
28 self.opt = opt
29 self.root = opt.dataroot
30
31 @staticmethod
32 def modify_commandline_options(parser, is_train):
33 """Add new dataset-specific options, and rewrite default values for existing options.
34
35 Parameters:
36 parser -- original option parser
37 is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
38
39 Returns:
40 the modified parser.
41 """
42 return parser
43
44 @abstractmethod
45 def __len__(self):
46 """Return the total number of images in the dataset."""
47 return 0
48
49 @abstractmethod
50 def __getitem__(self, index):
51 """Return a data point and its metadata information.
52
53 Parameters:
54 index - - a random integer for data indexing
55
56 Returns:
57 a dictionary of data with their names. It usually contains the data itself and its metadata information.
58 """
59 pass
60
61
62 def get_params(opt, size):
63 w, h = size
64 new_h = h
65 new_w = w
66 if opt.preprocess == 'resize_and_crop':
67 new_h = new_w = opt.load_size
68 elif opt.preprocess == 'scale_width_and_crop':
69 new_w = opt.load_size
70 new_h = opt.load_size * h // w
71
72 x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
73 y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
74
75 flip = random.random() > 0.5
76
77 return {'crop_pos': (x, y), 'flip': flip}
78
79
80 def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
81 transform_list = []
82 if grayscale:
83 transform_list.append(transforms.Grayscale(1))
84 if 'resize' in opt.preprocess:
85 osize = [opt.load_size, opt.load_size]
86 transform_list.append(transforms.Resize(osize, method))
87 elif 'scale_width' in opt.preprocess:
88 transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
89
90 if 'crop' in opt.preprocess:
91 if params is None:
92 transform_list.append(transforms.RandomCrop(opt.crop_size))
93 else:
94 transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
95
96 if opt.preprocess == 'none':
97 transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
98
99 if not opt.no_flip:
100 if params is None:
101 transform_list.append(transforms.RandomHorizontalFlip())
102 elif params['flip']:
103 transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
104
105 if convert:
106 transform_list += [transforms.ToTensor()]
107 if grayscale:
108 transform_list += [transforms.Normalize((0.5,), (0.5,))]
109 else:
110 transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
111 return transforms.Compose(transform_list)
112
113
114 def __make_power_2(img, base, method=Image.BICUBIC):
115 ow, oh = img.size
116 h = int(round(oh / base) * base)
117 w = int(round(ow / base) * base)
118 if (h == oh) and (w == ow):
119 return img
120
121 __print_size_warning(ow, oh, w, h)
122 return img.resize((w, h), method)
123
124
125 def __scale_width(img, target_width, method=Image.BICUBIC):
126 ow, oh = img.size
127 if (ow == target_width):
128 return img
129 w = target_width
130 h = int(target_width * oh / ow)
131 return img.resize((w, h), method)
132
133
134 def __crop(img, pos, size):
135 ow, oh = img.size
136 x1, y1 = pos
137 tw = th = size
138 if (ow > tw or oh > th):
139 return img.crop((x1, y1, x1 + tw, y1 + th))
140 return img
141
142
143 def __flip(img, flip):
144 if flip:
145 return img.transpose(Image.FLIP_LEFT_RIGHT)
146 return img
147
148
149 def __print_size_warning(ow, oh, w, h):
150 """Print warning information about image size(only print once)"""
151 if not hasattr(__print_size_warning, 'has_printed'):
152 print("The image size needs to be a multiple of 4. "
153 "The loaded image size was (%d, %d), so it was adjusted to "
154 "(%d, %d). This adjustment will be done to all images "
155 "whose sizes are not multiples of 4" % (ow, oh, w, h))
156 __print_size_warning.has_printed = True
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from skimage import color  # requires scikit-image
from PIL import Image
import numpy as np
import torchvision.transforms as transforms


class ColorizationDataset(BaseDataset):
    """This dataset class can load a set of natural images in RGB, and convert the RGB format into (L, ab) pairs in Lab color space.

    This dataset is required by the pix2pix-based colorization model ('--model colorization').
    """
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        By default, the number of channels for the input image is 1 (L) and
        the number of channels for the output image is 2 (ab). The direction is from A to B.
        """
        parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.dir = os.path.join(opt.dataroot, opt.phase)
        self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
        assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
        self.transform = get_transform(self.opt, convert=False)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) -- the L channel of an image
            B (tensor) -- the ab channels of the same image
            A_paths (str) -- image paths
            B_paths (str) -- image paths (same as A_paths)
        """
        path = self.AB_paths[index]
        im = Image.open(path).convert('RGB')
        im = self.transform(im)
        im = np.array(im)
        lab = color.rgb2lab(im).astype(np.float32)
        lab_t = transforms.ToTensor()(lab)
        A = lab_t[[0], ...] / 50.0 - 1.0
        B = lab_t[[1, 2], ...] / 110.0
        return {'A': A, 'B': B, 'A_paths': path, 'B_paths': path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.AB_paths)
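The scaling in `__getitem__` maps both channels into roughly [-1, 1]: the L channel lies in [0, 100], and the ab channels lie roughly in [-110, 110]. A minimal arithmetic sketch of that normalization (plain Python, no scikit-image needed; the function names here are illustrative, not part of the repo):

```python
# Normalization used by ColorizationDataset.__getitem__ (sketch):
#   L in [0, 100]      ->  A = L / 50 - 1   in [-1, 1]
#   ab in [-110, 110]  ->  B = ab / 110     in [-1, 1]

def normalize_L(L):
    return L / 50.0 - 1.0

def normalize_ab(ab):
    return ab / 110.0

print(normalize_L(0.0), normalize_L(50.0), normalize_L(100.0))  # -1.0 0.0 1.0
print(normalize_ab(-110.0), normalize_ab(110.0))                # -1.0 1.0
```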
0 """A modified image folder class
1
2 We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
3 so that this class can load images from both current directory and its subdirectories.
4 """
5
6 import torch.utils.data as data
7
8 from PIL import Image
9 import os
10 import os.path
11
12 IMG_EXTENSIONS = [
13 '.jpg', '.JPG', '.jpeg', '.JPEG',
14 '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
15 ]
16
17
18 def is_image_file(filename):
19 return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
20
21
22 def make_dataset(dir, max_dataset_size=float("inf")):
23 images = []
24 assert os.path.isdir(dir), '%s is not a valid directory' % dir
25
26 for root, _, fnames in sorted(os.walk(dir)):
27 for fname in fnames:
28 if is_image_file(fname):
29 path = os.path.join(root, fname)
30 images.append(path)
31 return images[:min(max_dataset_size, len(images))]
32
33
34 def default_loader(path):
35 return Image.open(path).convert('RGB')
36
37
38 class ImageFolder(data.Dataset):
39
40 def __init__(self, root, transform=None, return_paths=False,
41 loader=default_loader):
42 imgs = make_dataset(root)
43 if len(imgs) == 0:
44 raise(RuntimeError("Found 0 images in: " + root + "\n"
45 "Supported image extensions are: " +
46 ",".join(IMG_EXTENSIONS)))
47
48 self.root = root
49 self.imgs = imgs
50 self.transform = transform
51 self.return_paths = return_paths
52 self.loader = loader
53
54 def __getitem__(self, index):
55 path = self.imgs[index]
56 img = self.loader(path)
57 if self.transform is not None:
58 img = self.transform(img)
59 if self.return_paths:
60 return img, path
61 else:
62 return img
63
64 def __len__(self):
65 return len(self.imgs)
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image


class SingleDataset(BaseDataset):
    """This dataset class can load a set of images specified by the path --dataroot /path/to/data.

    It can be used for generating CycleGAN results for only one side with the model option '--model test'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        self.transform = get_transform(opt, grayscale=(input_nc == 1))

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns a dictionary that contains A and A_paths
            A (tensor) -- an image in one domain
            A_paths (str) -- the path of the image
        """
        A_path = self.A_paths[index]
        A_img = Image.open(A_path).convert('RGB')
        A = self.transform(A_img)
        return {'A': A, 'A_paths': A_path}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)
0 """Dataset class template
1
2 This module provides a template for users to implement custom datasets.
3 You can specify '--dataset_mode template' to use this dataset.
4 The class name should be consistent with both the filename and its dataset_mode option.
5 The filename should be <dataset_mode>_dataset.py
6 The class name should be <Dataset_mode>Dataset.py
7 You need to implement the following functions:
8 -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
9 -- <__init__>: Initialize this dataset class.
10 -- <__getitem__>: Return a data point and its metadata information.
11 -- <__len__>: Return the number of images.
12 """
13 from data.base_dataset import BaseDataset, get_transform
14 # from data.image_folder import make_dataset
15 # from PIL import Image
16
17
18 class TemplateDataset(BaseDataset):
19 """A template dataset class for you to implement custom datasets."""
20 @staticmethod
21 def modify_commandline_options(parser, is_train):
22 """Add new dataset-specific options, and rewrite default values for existing options.
23
24 Parameters:
25 parser -- original option parser
26 is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
27
28 Returns:
29 the modified parser.
30 """
31 parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
32 parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
33 return parser
34
35 def __init__(self, opt):
36 """Initialize this dataset class.
37
38 Parameters:
39 opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
40
41 A few things can be done here.
42 - save the options (have been done in BaseDataset)
43 - get image paths and meta information of the dataset.
44 - define the image transformation.
45 """
46 # save the option and dataset root
47 BaseDataset.__init__(self, opt)
48 # get the image paths of your dataset;
49 self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
50 # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
51 self.transform = get_transform(opt)
52
53 def __getitem__(self, index):
54 """Return a data point and its metadata information.
55
56 Parameters:
57 index -- a random integer for data indexing
58
59 Returns:
60 a dictionary of data with their names. It usually contains the data itself and its metadata information.
61
62 Step 1: get a random image path: e.g., path = self.image_paths[index]
63 Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
64 Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
65 Step 4: return a data point as a dictionary.
66 """
67 path = 'temp' # needs to be a string
68 data_A = None # needs to be a tensor
69 data_B = None # needs to be a tensor
70 return {'data_A': data_A, 'data_B': data_B, 'path': path}
71
72 def __len__(self):
73 """Return the total number of images."""
74 return len(self.image_paths)
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random


class UnalignedDataset(BaseDataset):
    """
    This dataset class can load unaligned/unpaired datasets.

    It requires two directories to host training images from domain A '/path/to/data/trainA'
    and from domain B '/path/to/data/trainB' respectively.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    Similarly, you need to prepare two directories:
    '/path/to/data/testA' and '/path/to/data/testB' during test time.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'

        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))   # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))   # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B
        btoA = self.opt.direction == 'BtoA'
        input_nc = self.opt.output_nc if btoA else self.opt.input_nc    # get the number of channels of the input image
        output_nc = self.opt.input_nc if btoA else self.opt.output_nc   # get the number of channels of the output image
        self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
        self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor) -- an image in the input domain
            B (tensor) -- its corresponding image in the target domain
            A_paths (str) -- image paths
            B_paths (str) -- image paths
        """
        A_path = self.A_paths[index % self.A_size]  # make sure the index is within the range
        if self.opt.serial_batches:   # use a fixed, in-range index for domain B
            index_B = index % self.B_size
        else:   # randomize the index for domain B to avoid fixed pairs.
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')
        # apply image transformation
        A = self.transform_A(A_img)
        B = self.transform_B(B_img)

        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Return the total number of images in the dataset.

        As we have two datasets with potentially different numbers of images,
        we take the maximum of the two.
        """
        return max(self.A_size, self.B_size)
FROM nvidia/cuda:9.0-base

RUN apt update && apt install -y wget unzip curl bzip2 git
RUN curl -LO http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh
RUN bash Miniconda-latest-Linux-x86_64.sh -p /miniconda -b
RUN rm Miniconda-latest-Linux-x86_64.sh
ENV PATH=/miniconda/bin:${PATH}
RUN conda update -y conda

RUN conda install -y pytorch torchvision -c pytorch
RUN mkdir /workspace/ && cd /workspace/ && git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git && cd pytorch-CycleGAN-and-pix2pix && pip install -r requirements.txt

WORKDIR /workspace
### CycleGAN Datasets
Download the CycleGAN datasets using the following script. Some of the datasets were collected by other researchers. Please cite their papers if you use the data.
```bash
bash ./datasets/download_cyclegan_dataset.sh dataset_name
```
- `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](../datasets/bibtex/facades.tex)]
- `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com). [[Citation](../datasets/bibtex/cityscapes.tex)]. Note: Due to licensing issues, we cannot directly provide the Cityscapes dataset. Please download the Cityscapes dataset from [https://cityscapes-dataset.com](https://cityscapes-dataset.com) and use the script `./datasets/prepare_cityscapes_dataset.py`.
- `maps`: 1096 training images scraped from Google Maps.
- `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org) using the keywords `wild horse` and `zebra`.
- `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org) using the keywords `apple` and `navel orange`.
- `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images downloaded using the Flickr API. See more details in our paper.
- `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos were downloaded from Flickr using the combination of the tags *landscape* and *landscapephotography*. The training set size of each class is Monet: 1074, Cezanne: 584, Van Gogh: 401, Ukiyo-e: 1433, Photographs: 6853.
- `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set size of each class is iPhone: 1813, DSLR: 3316. See more details in our paper.

To train a model on your own datasets, you need to create a data folder with two subdirectories `trainA` and `trainB` that contain images from domain A and domain B. You can test your model on your training set by setting `--phase train` in `test.py`. You can also create subdirectories `testA` and `testB` if you have test data.

You should **not** expect our method to work on just any random combination of input and output datasets (e.g., `cats<->keyboards`). From our experiments, we find it works better if the two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting<->landscape photographs`. `zebras<->horses` achieves compelling results while `cats<->dogs` completely fails.

### pix2pix datasets
Download the pix2pix datasets using the following script. Some of the datasets were collected by other researchers. Please cite their papers if you use the data.
```bash
bash ./datasets/download_pix2pix_dataset.sh dataset_name
```
- `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](../datasets/bibtex/facades.tex)]
- `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com). [[Citation](../datasets/bibtex/cityscapes.tex)]
- `maps`: 1096 training images scraped from Google Maps
- `edges2shoes`: 50k training images from the [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/shoes.tex)]
- `edges2handbags`: 137K Amazon handbag images from the [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/handbags.tex)]
- `night2day`: around 20K natural scene images from the [Transient Attributes dataset](http://transattr.cs.brown.edu/) [[Citation](datasets/bibtex/transattr.tex)]. To train a `day2night` pix2pix model, you need to add `--direction BtoA`.

We provide a Python script to generate pix2pix training data in the form of pairs of images {A,B}, where A and B are two different depictions of the same underlying scene. For example, these might be pairs {label map, photo} or {bw image, color image}. Then we can learn to translate A to B or B to A:

Create a folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `val`, `test`, etc. In `/path/to/data/A/train`, put training images in style A. In `/path/to/data/B/train`, put the corresponding images in style B. Repeat the same for the other data splits (`val`, `test`, etc.).

Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g., `/path/to/data/A/train/1.jpg` is considered to correspond to `/path/to/data/B/train/1.jpg`.

Once the data is formatted this way, call:
```bash
python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
```

This will combine each pair of images (A,B) into a single image file, ready for training.
# Docker image with pytorch-CycleGAN-and-pix2pix

We provide both a Dockerfile and a pre-built Docker container that can run this code repo.

## Prerequisites

- Install [docker-ce](https://docs.docker.com/install/linux/docker-ce/ubuntu/)
- Install [nvidia-docker](https://github.com/NVIDIA/nvidia-docker#quickstart)

## Running the pre-built Docker image

- Pull the pre-built Docker image

```bash
docker pull taesungp/pytorch-cyclegan-and-pix2pix
```

- Start an interactive Docker session. The `-p 8097:8097` option is needed if you want to run a `visdom` server in the Docker container.

```bash
nvidia-docker run -it -p 8097:8097 taesungp/pytorch-cyclegan-and-pix2pix
```

- Now you are in the Docker environment. Go to our code repo and start running things.
```bash
cd /workspace/pytorch-CycleGAN-and-pix2pix
bash datasets/download_pix2pix_dataset.sh facades
python -m visdom.server &
bash scripts/train_pix2pix.sh
```

## Building from the Dockerfile

We also posted the [Dockerfile](Dockerfile). To build the image yourself, download the Dockerfile in this directory and run
```bash
docker build -t [target_tag] .
```
in the directory that contains the Dockerfile.
## Overview of Code Structure
To help users better understand and use our codebase, we briefly overview the functionality and implementation of each package and each module. Please see the documentation in each file for more details. If you have questions, you may find useful information in [training/test tips](tips.md) and [frequently asked questions](qa.md).

[train.py](../train.py) is a general-purpose training script. It works for various models (with option `--model`: e.g., `pix2pix`, `cyclegan`, `colorization`) and different datasets (with option `--dataset_mode`: e.g., `aligned`, `unaligned`, `single`, `colorization`). See the main [README](../README.md) and [training/test tips](tips.md) for more details.

[test.py](../test.py) is a general-purpose test script. Once you have trained your model with `train.py`, you can use this script to test the model. It will load a saved model from `--checkpoints_dir` and save the results to `--results_dir`. See the main [README](../README.md) and [training/test tips](tips.md) for more details.


[data](../data) directory contains all the modules related to data loading and preprocessing. To add a custom dataset class called `dummy`, you need to add a file called `dummy_dataset.py` and define a subclass `DummyDataset` inherited from `BaseDataset`. You need to implement four functions: `__init__` (initialize the class; you need to first call `BaseDataset.__init__(self, opt)`), `__len__` (return the size of the dataset), `__getitem__` (get a data point), and optionally `modify_commandline_options` (add dataset-specific options and set default options). Now you can use the dataset class by specifying the flag `--dataset_mode dummy`. See our template dataset [class](../data/template_dataset.py) for an example. Below we explain each file in detail.
9
* [\_\_init\_\_.py](../data/__init__.py) implements the interface between this package and the training and test scripts. `train.py` and `test.py` call `from data import create_dataset` and `dataset = create_dataset(opt)` to create a dataset given the option `opt`.
* [base_dataset.py](../data/base_dataset.py) implements an abstract base class ([ABC](https://docs.python.org/3/library/abc.html)) for datasets. It also includes common transformation functions (e.g., `get_transform`, `__scale_width`), which can later be used in subclasses.
* [image_folder.py](../data/image_folder.py) implements an image folder class. We modify the official PyTorch image folder [code](https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) so that this class can load images from both the current directory and its subdirectories.
* [template_dataset.py](../data/template_dataset.py) provides a dataset template with detailed documentation. Check out this file if you plan to implement your own dataset.
* [aligned_dataset.py](../data/aligned_dataset.py) includes a dataset class that can load image pairs. It assumes a single image directory `/path/to/data/train`, which contains image pairs in the form {A,B}. See [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md#prepare-your-own-datasets-for-pix2pix) for how to prepare aligned datasets. During test time, you need to prepare a directory `/path/to/data/test` as test data.
* [unaligned_dataset.py](../data/unaligned_dataset.py) includes a dataset class that can load unaligned/unpaired datasets. It assumes two directories that host training images from domain A (`/path/to/data/trainA`) and domain B (`/path/to/data/trainB`), respectively. Then you can train the model with the dataset flag `--dataroot /path/to/data`. Similarly, you need to prepare two directories `/path/to/data/testA` and `/path/to/data/testB` during test time.
* [single_dataset.py](../data/single_dataset.py) includes a dataset class that can load a set of single images specified by the path `--dataroot /path/to/data`. It can be used for generating CycleGAN results for only one side with the model option `--model test`.
* [colorization_dataset.py](../data/colorization_dataset.py) implements a dataset class that can load a set of natural images in RGB, and convert the RGB format into (L, ab) pairs in [Lab](https://en.wikipedia.org/wiki/CIELAB_color_space) color space. It is required by the pix2pix-based colorization model (`--model colorization`).
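The dataset contract described above can be sketched as a skeleton. This is a self-contained illustration, not repo code: `BaseDataset` here is a hypothetical stand-in for `data.base_dataset.BaseDataset`, and a real subclass would load images with `make_dataset` and `get_transform`:

```python
class BaseDataset:  # stand-in for data.base_dataset.BaseDataset
    def __init__(self, opt):
        self.opt = opt
        self.root = opt['dataroot']


class DummyDataset(BaseDataset):
    def __init__(self, opt, paths):
        BaseDataset.__init__(self, opt)  # first call the parent initializer
        self.paths = sorted(paths)       # normally sorted(make_dataset(self.root, ...))

    def __getitem__(self, index):
        # normally: open the image at self.paths[index] and apply self.transform
        return {'A': self.paths[index], 'A_paths': self.paths[index]}

    def __len__(self):
        return len(self.paths)


ds = DummyDataset({'dataroot': '/tmp'}, ['b.jpg', 'a.jpg'])
print(len(ds), ds[0]['A'])  # 2 a.jpg
```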
18
19
[models](../models) directory contains modules related to objective functions, optimizations, and network architectures. To add a custom model class called `dummy`, you need to add a file called `dummy_model.py` and define a subclass `DummyModel` inherited from `BaseModel`. You need to implement four functions: `__init__` (initialize the class; you need to first call `BaseModel.__init__(self, opt)`), `set_input` (unpack data from the dataset and apply preprocessing), `forward` (generate intermediate results), `optimize_parameters` (calculate losses and gradients, and update network weights), and optionally `modify_commandline_options` (add model-specific options and set default options). Now you can use the model class by specifying the flag `--model dummy`. See our template model [class](../models/template_model.py) for an example. Below we explain each file in detail.

* [\_\_init\_\_.py](../models/__init__.py) implements the interface between this package and the training and test scripts. `train.py` and `test.py` call `from models import create_model` and `model = create_model(opt)` to create a model given the option `opt`. You also need to call `model.setup(opt)` to properly initialize the model.
* [base_model.py](../models/base_model.py) implements an abstract base class ([ABC](https://docs.python.org/3/library/abc.html)) for models. It also includes commonly used helper functions (e.g., `setup`, `test`, `update_learning_rate`, `save_networks`, `load_networks`), which can later be used in subclasses.
* [template_model.py](../models/template_model.py) provides a model template with detailed documentation. Check out this file if you plan to implement your own model.
* [pix2pix_model.py](../models/pix2pix_model.py) implements the pix2pix [model](https://phillipi.github.io/pix2pix/), for learning a mapping from input images to output images given paired data. The model training requires the `--dataset_mode aligned` dataset. By default, it uses a `--netG unet256` [U-Net](https://arxiv.org/pdf/1505.04597.pdf) generator, a `--netD basic` discriminator (PatchGAN), and a `--gan_mode vanilla` GAN loss (the standard cross-entropy objective).
* [colorization_model.py](../models/colorization_model.py) implements a subclass of `Pix2PixModel` for image colorization (black & white image to colorful image). The model training requires the `--dataset_mode colorization` dataset. It trains a pix2pix model, mapping from the L channel to the ab channels in [Lab](https://en.wikipedia.org/wiki/CIELAB_color_space) color space. By default, the `colorization` dataset will automatically set `--input_nc 1` and `--output_nc 2`.
* [cycle_gan_model.py](../models/cycle_gan_model.py) implements the CycleGAN [model](https://junyanz.github.io/CycleGAN/), for learning image-to-image translation without paired data. The model training requires the `--dataset_mode unaligned` dataset. By default, it uses a `--netG resnet_9blocks` ResNet generator, a `--netD basic` discriminator (the PatchGAN introduced by pix2pix), and a least-squares GAN [objective](https://arxiv.org/abs/1611.04076) (`--gan_mode lsgan`).
* [networks.py](../models/networks.py) implements network architectures (both generators and discriminators), as well as normalization layers, initialization methods, the optimization scheduler (i.e., the learning rate policy), and GAN objective functions (`vanilla`, `lsgan`, `wgangp`).
* [test_model.py](../models/test_model.py) implements a model that can be used to generate CycleGAN results for only one direction. This model will automatically set `--dataset_mode single`, which only loads images from one set. See the test [instructions](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#apply-a-pre-trained-model-cyclegan) for more details.
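The four required model functions can be sketched in the same spirit. Again a self-contained illustration with a hypothetical stand-in `BaseModel`, shown only to make the call order concrete (a real subclass would compute losses and backpropagate in `optimize_parameters`):

```python
class BaseModel:  # stand-in for models.base_model.BaseModel
    def __init__(self, opt):
        self.opt = opt


class DummyModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)  # first call the parent initializer
        self.loss = None

    def set_input(self, data):
        # unpack data from the dataloader and apply any preprocessing
        self.real_A = data['A']

    def forward(self):
        # generate intermediate results; here an illustrative identity map
        self.fake_B = self.real_A

    def optimize_parameters(self):
        # compute losses, backpropagate, and update weights (omitted in this sketch)
        self.forward()
        self.loss = 0.0


m = DummyModel(opt={})
m.set_input({'A': 3})
m.optimize_parameters()
print(m.fake_B)  # 3
```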
30
[options](../options) directory includes our option modules: training options, test options, and basic options (used in both training and test). `TrainOptions` and `TestOptions` are both subclasses of `BaseOptions`. They reuse the options defined in `BaseOptions`.
* [\_\_init\_\_.py](../options/__init__.py) is required to make Python treat the directory `options` as a package.
* [base_options.py](../options/base_options.py) includes options that are used in both training and test. It also implements a few helper functions such as parsing, printing, and saving the options. It also gathers additional options defined in the `modify_commandline_options` functions of both the dataset class and the model class.
* [train_options.py](../options/train_options.py) includes options that are only used during training time.
* [test_options.py](../options/test_options.py) includes options that are only used during test time.


[util](../util) directory includes a miscellaneous collection of useful helper functions.
* [\_\_init\_\_.py](../util/__init__.py) is required to make Python treat the directory `util` as a package.
* [get_data.py](../util/get_data.py) provides a Python script for downloading CycleGAN and pix2pix datasets. Alternatively, you can use bash scripts such as [download_pix2pix_model.sh](../scripts/download_pix2pix_model.sh) and [download_cyclegan_model.sh](../scripts/download_cyclegan_model.sh).
* [html.py](../util/html.py) implements a module that saves images into a single HTML file. It consists of functions such as `add_header` (add a text header to the HTML file), `add_images` (add a row of images to the HTML file), and `save` (save the HTML to disk). It is based on `dominate`, a Python library for creating and manipulating HTML documents using a DOM API.
* [image_pool.py](../util/image_pool.py) implements an image buffer that stores previously generated images. This buffer enables us to update discriminators using a history of generated images rather than only the ones produced by the latest generators. The original idea was discussed in this [paper](http://openaccess.thecvf.com/content_cvpr_2017/papers/Shrivastava_Learning_From_Simulated_CVPR_2017_paper.pdf). The size of the buffer is controlled by the flag `--pool_size`.
* [visualizer.py](../util/visualizer.py) includes several functions that can display/save images and print/save logging information. It uses the Python library `visdom` for display and `dominate` (wrapped in `HTML`) for creating HTML files with images.
* [util.py](../util/util.py) consists of simple helper functions such as `tensor2im` (convert a tensor array to a numpy image array), `diagnose_network` (calculate and print the mean of the average absolute values of gradients), and `mkdirs` (create multiple directories).
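Since the dataset transforms normalize images to [-1, 1], a `tensor2im`-style conversion back to an 8-bit image essentially inverts that normalization. A sketch of the idea with NumPy (not the exact repo function; `tensor_to_image` is an illustrative name):

```python
import numpy as np


def tensor_to_image(arr):
    """Map a (C, H, W) float array in [-1, 1] back to a (H, W, C) uint8 image."""
    arr = (np.transpose(arr, (1, 2, 0)) + 1) / 2.0 * 255.0  # [-1, 1] -> [0, 255]
    return arr.astype(np.uint8)


x = np.zeros((3, 2, 2), dtype=np.float32)  # mid-gray in normalized space
print(tensor_to_image(x)[0, 0])  # [127 127 127]
```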
0 ## Frequently Asked Questions
1 Before you post a new question, please first look at the following Q & A and existing GitHub issues. You may also want to read [Training/Test tips](docs/tips.md) for more suggestions.
2
3 #### Connection Error: HTTPConnectionPool ([#230](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/230), [#24](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/24), [#38](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/38))
4 Similar error messages include “Failed to establish a new connection/Connection refused”.
5
6 Please start the visdom server before starting the training:
7 ```bash
8 python -m visdom.server
9 ```
10 To install visdom, you can use the following command:
11 ```bash
12 pip install visdom
13 ```
14 You can also disable visdom by setting `--display_id 0`.
15
16 #### My PyTorch errors on CUDA related code.
17 Try running the following code snippet to make sure that CUDA is working (assuming you are using PyTorch >= 0.4):
18 ```python
19 import torch
20 torch.cuda.init()
21 print(torch.randn(1, device='cuda'))
22 ```
23
24 If you encounter an error, it is likely that your PyTorch build does not work with CUDA, e.g., it was installed from the official macOS binary, or your GPU is too old and no longer supported. You can run the code on the CPU with `--gpu_ids -1`.
25
26 #### TypeError: Object of type 'Tensor' is not JSON serializable ([#258](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/258))
27 Similar errors: AttributeError: module 'torch' has no attribute 'device' ([#314](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/314))
28
29 The current code only works with PyTorch 0.4+. Earlier PyTorch versions often cause the above errors.
30
31 #### ValueError: empty range for randrange() ([#390](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/390), [#376](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/376), [#194](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/194))
32 Similar error messages include "ConnectionRefusedError: [Errno 111] Connection refused"
33
34 This error is related to the data augmentation step. It often happens when you use `--preprocess crop`. The program crops random `crop_size x crop_size` patches out of the input training images, so if some of your images (e.g., `256x384`) are smaller than the `crop_size` (e.g., 512), you will get this error. A simple fix is to use another data augmentation method such as `resize_and_crop` or `scale_width_and_crop`. Our program will automatically resize the images according to `load_size` before applying the `crop_size x crop_size` cropping. Make sure that `load_size >= crop_size`.
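
The failure mode can be reproduced in isolation: the cropper picks a random crop origin, and that range is empty when the image is smaller than `crop_size` (the sizes below are hypothetical):

```python
import random

crop_size = 512
w = 256  # image width smaller than --crop_size

try:
    x = random.randint(0, w - crop_size)  # random crop origin, as the cropper computes it
    ok = True
except ValueError as err:
    ok = False
    print('crop failed:', err)  # ValueError: empty range for randrange()
```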
35
36
37 #### Can I continue/resume my training? ([#350](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/350), [#275](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/275), [#234](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/234), [#87](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/87))
38 You can use the option `--continue_train`. Also set `--epoch_count` to specify a different starting epoch count. See more discussion in [training/test tips](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md#trainingtest-tips).
39
40 #### Why does my training loss not converge? ([#335](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/335), [#164](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/164), [#30](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/30))
41 Many GAN losses do not converge (exceptions: WGAN, WGAN-GP, etc.) due to the nature of minimax optimization. For the DCGAN and LSGAN objectives, it is quite normal for the G and D losses to go up and down. It should be fine as long as they do not blow up.
42
43 #### How can I make it work for my own data (e.g., 16-bit png, tiff, hyperspectral images)? ([#309](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/309), [#320](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/), [#202](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/202))
44 The current code only supports RGB and grayscale images. If you would like to train the model on other data types, please take the following steps:
45
46 - change the parameters `--input_nc` and `--output_nc` to the number of channels in your input/output images.
47 - Write your own custom data loader (this is straightforward as long as you know how to load your data with Python). If you write a new data loader class, change the flag `--dataset_mode` accordingly. Alternatively, you can modify an existing data loader: for aligned datasets, change this [line](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/data/aligned_dataset.py#L41); for unaligned datasets, change these two [lines](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/data/unaligned_dataset.py#L57).
48
49 - If you use visdom and HTML to visualize the results, you may also need to change the visualization code.
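
For example, a custom loader for 16-bit data mainly needs to normalize the extended value range to `[-1, 1]`, which is what the bundled transforms do for 8-bit RGB. A minimal numpy sketch (the hand-built `img16` array stands in for an image loaded with, e.g., `tifffile` or `imageio`):

```python
import numpy as np

# Stand-in for a 16-bit single-channel image loaded from disk
img16 = np.array([[0, 32768], [49152, 65535]], dtype=np.uint16)

# Normalize to [-1, 1], the range the generators expect
img = img16.astype(np.float32) / 65535.0 * 2.0 - 1.0
print(img.min(), img.max())  # -1.0 1.0
```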
50
51 #### Multi-GPU Training ([#327](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/327), [#292](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/292), [#137](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/137), [#35](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/35))
52 You can use multi-GPU training by setting `--gpu_ids` (e.g., `--gpu_ids 0,1,2,3` for the first four GPUs on your machine). To fully utilize all the GPUs, you need to increase your batch size: try `--batch_size 4`, `--batch_size 16`, or even larger. Each GPU will process batch_size/#GPUs images. The optimal batch size depends on the number of GPUs, the memory per GPU, and the resolution of your training images.
53
54 We also recommend using instance normalization for multi-GPU training by setting `--norm instance`. The current batch normalization might not work well with multiple GPUs because the batchnorm parameters are not shared across GPUs. Advanced users can try [synchronized batchnorm](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch).
55
56
57 #### Can I run the model on CPU? ([#310](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/310))
58 Yes, you can set `--gpu_ids -1`. See [training/test tips](tips.md) for more details.
59
60
61 #### Are pre-trained models available? ([#10](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/10))
62 Yes, you can download pretrained models with the bash script `./scripts/download_cyclegan_model.sh`. See [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#apply-a-pre-trained-model-cyclegan) for more details. We are slowly adding more models to the repo.
63
64 #### Out of memory ([#174](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/174))
65 CycleGAN is more memory-intensive than pix2pix as it requires two generators and two discriminators. If you would like to produce high-resolution images, you can do the following.
66
67 - During training, train CycleGAN on cropped images of the training set. Please be careful not to change the aspect ratio or the scale of the original image, as this can lead to a training/test gap. You can usually do this with the `--preprocess crop` or `--preprocess scale_width_and_crop` option.
68
69 - Then at test time, you can load only one generator to produce results in a single direction. This greatly saves GPU memory because you are not loading the discriminators or the generator for the opposite direction. You can probably take the whole image as input. You can do this using `--model test --dataroot [path to the directory that contains your test images (e.g., ./datasets/horse2zebra/trainA)] --model_suffix _A --preprocess none`. You can use either `--preprocess none` or `--preprocess scale_width --crop_size [your_desired_image_width]`. Please see [model_suffix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/test_model.py#L16) and [preprocess](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/data/base_dataset.py#L24) for more details.
70
71 #### What is the identity loss? ([#322](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/322), [#373](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/373), [#362](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/pull/362))
72 We use the identity loss for our photo-to-painting application. The identity loss regularizes the generator to stay close to an identity mapping when fed real samples from the *target* domain: if an image already looks like it comes from the target domain, the generator should preserve it without additional changes. A generator trained with this loss is often more conservative on unknown content. Please see more details in Sec 5.2 "Photo generation from paintings" and Figure 12 of the CycleGAN [paper](https://arxiv.org/pdf/1703.10593.pdf). The loss was first proposed in Equation 6 of the prior work [[Taigman et al., 2017]](https://arxiv.org/pdf/1611.02200.pdf).
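
Numerically, the identity term is an L1 penalty between the generator's output on a target-domain image and that image itself, scaled by `lambda_identity` and the cycle weight. A numpy sketch using the repo's default weights (`--lambda_identity 0.5`, `--lambda_B 10.0`); a perfect identity mapping incurs zero penalty:

```python
import numpy as np

def l1(a, b):
    return float(np.mean(np.abs(a - b)))

lambda_identity, lambda_B = 0.5, 10.0  # repo defaults
real_B = np.random.rand(1, 3, 8, 8)    # a real sample from the target domain
idt_B = real_B.copy()                  # pretend the generator is a perfect identity map

loss_idt = l1(idt_B, real_B) * lambda_B * lambda_identity
print(loss_idt)  # 0.0
```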
73
74 #### The color gets inverted from the beginning of training ([#249](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/249))
75 The authors also observe that the generator unnecessarily inverts the color of the input image early in training, and then never learns to undo the inversion. In this case, you can try two things.
76
77 - First, try using the identity loss: `--lambda_identity 1.0` or `--lambda_identity 0.1`. We observe that the identity loss makes the generator more conservative and leads to fewer unnecessary changes. However, because of this, the change may not be as dramatic.
78
79 - Second, try a smaller variance when initializing weights by changing `--init_gain`. We observe that a smaller variance in weight initialization results in less color inversion.
80
81 #### For labels2photo Cityscapes evaluation, why does the pretrained FCN-8s model not work well on the original Cityscapes input images? ([#150](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/150))
82 The model was trained on 256x256 images that were resized/upsampled to 1024x2048, so the expected inputs to the network are very blurry. The purpose of the resizing was to 1) keep the label maps untouched at the original high resolution and 2) avoid the need to change the standard FCN training code for Cityscapes.
83
84 #### How do I get the `ground-truth` numbers on the labels2photo Cityscapes evaluation? ([#150](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/150))
85 You need to resize the original Cityscapes images to 256x256 before running the evaluation code.
86
87
88 #### Using resize-conv to reduce checkerboard artifacts ([#190](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190), [#64](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/64))
89 This Distill [blog](https://distill.pub/2016/deconv-checkerboard/) discusses one of the potential causes of checkerboard artifacts. You can fix the issue by switching from "deconvolution" to nearest-neighbor (or bilinear) upsampling followed by a regular convolution. Here is one implementation provided by [@SsnL](https://github.com/SsnL). You can replace `ConvTranspose2d` with the following layers.
90 ```python
91 nn.Upsample(scale_factor = 2, mode='bilinear'),
92 nn.ReflectionPad2d(1),
93 nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=0),
94 ```
95 We have also noticed that the checkerboard artifacts sometimes go away if you train long enough, so you may try training your model a bit longer.
96
97 #### pix2pix/CycleGAN has no random noise z ([#152](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/152))
98 The current pix2pix/CycleGAN model does not take z as input. In both pix2pix and CycleGAN, we tried adding z to the generator (e.g., adding z to a latent state, concatenating with a latent state, applying dropout, etc.), but often found that the output did not vary significantly as a function of z. Conditional GANs do not need noise as long as the input is sufficiently complex, so that the input itself can effectively play the role of noise. Without noise, the mapping is deterministic.
99
100 Please check out the following papers that show ways of getting z to actually have a substantial effect: e.g., [BicycleGAN](https://github.com/junyanz/BicycleGAN), [AugmentedCycleGAN](https://arxiv.org/abs/1802.10151), [MUNIT](https://arxiv.org/abs/1804.04732), [DRIT](https://arxiv.org/pdf/1808.00948.pdf), etc.
101
102 #### Experiment details (e.g., BW->color) ([#306](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/306))
103 You can find more training details and hyperparameter settings in the appendix of [CycleGAN](https://arxiv.org/abs/1703.10593) and [pix2pix](https://arxiv.org/abs/1611.07004) papers.
104
105 #### Results with [Cycada](https://arxiv.org/pdf/1711.03213.pdf)
106 We generated the [result of translating GTA images to Cityscapes-style images](https://junyanz.github.io/CycleGAN/) using our Torch repo. Our PyTorch and Torch implementations seem to produce slightly different results, although we have not measured the FCN score using the PyTorch-trained model. To reproduce the Cycada results, please use the Torch repo for now.
0 ## Training/test Tips
1 #### Training/test options
2 Please see `options/train_options.py` and `options/base_options.py` for the training flags; see `options/test_options.py` and `options/base_options.py` for the test flags. There are some model-specific flags as well, which are added in the model files, such as the `--lambda_A` option in `model/cycle_gan_model.py`. The default values of these options are also adjusted in the model files.
3 #### CPU/GPU (default `--gpu_ids 0`)
4 Please set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g., `--batch_size 32`) to benefit from multiple GPUs.
5
6 #### Visualization
7 During training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running via `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID shown on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom`, set `--display_id -1`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`.
8
9 #### Preprocessing
10 Images can be resized and cropped in different ways using the `--preprocess` option. The default option `'resize_and_crop'` resizes the image to `(opt.load_size, opt.load_size)` and then takes a random crop of size `(opt.crop_size, opt.crop_size)`. `'crop'` skips the resizing step and only performs random cropping. `'scale_width'` resizes the image to have width `opt.crop_size` while keeping the aspect ratio. `'scale_width_and_crop'` first resizes the image to have width `opt.load_size` and then does a random crop of size `(opt.crop_size, opt.crop_size)`. `'none'` tries to skip all these preprocessing steps. However, if the image size is not a multiple of some number (which depends on the number of downsampling operations in the generator), you will get an error because the size of the output image may differ from the size of the input image. Therefore, the `'none'` option still adjusts the image size to be a multiple of 4. You might need a bigger adjustment if you change the generator architecture. Please see `data/base_dataset.py` to see how all of this is implemented.
11
12 #### Fine-tuning/resume training
13 To fine-tune a pre-trained model, or resume the previous training, use the `--continue_train` flag. The program will then load the model based on `epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count <int>` to specify a different starting epoch count.
14
15
16 #### Prepare your own datasets for CycleGAN
17 You need to create two directories to host images from domain A `/path/to/data/trainA` and from domain B `/path/to/data/trainB`. Then you can train the model with the dataset flag `--dataroot /path/to/data`. Optionally, you can create hold-out test datasets at `/path/to/data/testA` and `/path/to/data/testB` to test your model on unseen images.
18
19 #### Prepare your own datasets for pix2pix
20 Pix2pix's training requires paired data. We provide a Python script to generate training data in the form of pairs of images {A,B}, where A and B are two different depictions of the same underlying scene, for example {label map, photo} or {bw image, color image}. We can then learn to translate A to B or B to A:
21
22 Create folder `/path/to/data` with subdirectories `A` and `B`. `A` and `B` should each have their own subdirectories `train`, `val`, `test`, etc. In `/path/to/data/A/train`, put training images in style A. In `/path/to/data/B/train`, put the corresponding images in style B. Repeat the same for other data splits (`val`, `test`, etc.).
23
24 Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g., `/path/to/data/A/train/1.jpg` is considered to correspond to `/path/to/data/B/train/1.jpg`.
25
26 Once the data is formatted this way, call:
27 ```bash
28 python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
29 ```
30
31 This will combine each pair of images (A,B) into a single image file, ready for training.
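
The resulting format is simple: A and B are concatenated side by side into one image, which the aligned dataset later splits back in half. A numpy sketch with two hypothetical 256x256 RGB images:

```python
import numpy as np

A = np.zeros((256, 256, 3), dtype=np.uint8)      # e.g., a label map
B = np.full((256, 256, 3), 255, dtype=np.uint8)  # e.g., the corresponding photo

AB = np.concatenate([A, B], axis=1)  # side by side: height 256, width 512
print(AB.shape)  # (256, 512, 3)
```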
32
33
34 #### About image size
35 Since the generator architecture in CycleGAN involves a series of downsampling/upsampling operations, the sizes of the input and output images may not match if the input image size is not a multiple of 4. As a result, you may get a runtime error because the L1 identity loss cannot be computed on images of different sizes. Therefore, we slightly resize the image to a multiple of 4 even with the `--preprocess none` option. For the same reason, `--crop_size` needs to be a multiple of 4.
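
The adjustment amounts to rounding each dimension to the nearest multiple of 4 (see `data/base_dataset.py` for the actual implementation); a sketch:

```python
def make_multiple_of(x, base=4):
    # round a dimension to the nearest multiple of `base` (illustrative)
    return max(base, round(x / base) * base)

print(make_multiple_of(1023))  # 1024
print(make_multiple_of(383))   # 384
print(make_multiple_of(256))   # 256 (already a multiple of 4)
```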
36
37 #### Training/Testing with high res images
38 CycleGAN is quite memory-intensive, as four networks (two generators and two discriminators) need to be loaded on one GPU, so a large image cannot be loaded in its entirety. In this case, we recommend training with cropped images. For example, to generate 1024px results, you can train with `--preprocess scale_width_and_crop --load_size 1024 --crop_size 360`, and test with `--preprocess scale_width --load_size 1024`. This ensures that training and testing happen at the same scale. At test time, you can afford a higher resolution because you don't need to load all the networks.
39
40 #### About loss curve
41 Unfortunately, the loss curve does not reveal much information in training GANs, and CycleGAN is no exception. To check whether the training has converged or not, we recommend periodically generating a few samples and looking at them.
42
43 #### About batch size
44 For all experiments in the paper, we set the batch size to 1. If there is room in memory, you can use a higher batch size with batch norm or instance norm. (Note that the default batchnorm does not work well with multi-GPU training; you may consider [synchronized batchnorm](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) instead.) But please be aware that the batch size can affect training: even with instance normalization, different batch sizes can lead to different results. Moreover, increasing `--crop_size` may be a good alternative to increasing the batch size.
45
46
47 #### Notes on Colorization
48 There is no need to run `combine_A_and_B.py` for colorization. Instead, prepare natural images and set `--dataset_mode colorization` and `--model colorization` in the script. The program will automatically convert each RGB image into Lab color space and create `L -> ab` image pairs during training. Also set `--input_nc 1` and `--output_nc 2`. The training and test directories should be organized as `/your/data/train` and `/your/data/test`. See the example scripts `scripts/train_colorization.sh` and `scripts/test_colorization` for more details.
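
After the Lab conversion, the split into input and target is just a slice along the channel axis, which is why `--input_nc 1` and `--output_nc 2` are needed. A numpy sketch on a stand-in Lab array (the RGB-to-Lab conversion itself is not shown):

```python
import numpy as np

lab = np.random.rand(8, 8, 3).astype(np.float32)  # stand-in for a converted Lab image

L_chan = lab[..., :1]  # 1-channel input  (--input_nc 1)
ab = lab[..., 1:]      # 2-channel target (--output_nc 2)
print(L_chan.shape, ab.shape)  # (8, 8, 1) (8, 8, 2)
```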
49
50 #### Notes on Extracting Edges
51 We provide Python and MATLAB scripts to extract coarse edges from photos. Run `scripts/edges/batch_hed.py` to compute [HED](https://github.com/s9xie/hed) edges. Run `scripts/edges/PostprocessHED.m` to simplify the edges with additional post-processing steps. Check the code documentation for more details.
52
53 #### Evaluating Labels2Photos on Cityscapes
54 We provide scripts for running the evaluation of the Labels2Photos task on the Cityscapes **validation** set. We assume that you have installed `caffe` (and `pycaffe`) in your system. If not, see the [official website](http://caffe.berkeleyvision.org/installation.html) for installation instructions. Once `caffe` is successfully installed, download the pre-trained FCN-8s semantic segmentation model (512MB) by running
55 ```bash
56 bash ./scripts/eval_cityscapes/download_fcn8s.sh
57 ```
58 Then make sure `./scripts/eval_cityscapes/` is in your system's Python path. If not, run the following command to add it:
59 ```bash
60 export PYTHONPATH=${PYTHONPATH}:./scripts/eval_cityscapes/
61 ```
62 Now you can run the following command to evaluate your predictions:
63 ```bash
64 python ./scripts/eval_cityscapes/evaluate.py --cityscapes_dir /path/to/original/cityscapes/dataset/ --result_dir /path/to/your/predictions/ --output_dir /path/to/output/directory/
65 ```
66 Images stored under `--result_dir` should contain your model predictions on the Cityscapes **validation** split, and have the original Cityscapes naming convention (e.g., `frankfurt_000001_038418_leftImg8bit.png`). The script will output a text file under `--output_dir` containing the metric.
67
68 **Further notes**: The pre-trained model is **not** supposed to work on Cityscapes at the original resolution (1024x2048), as it was trained on 256x256 images that are upsampled to 1024x2048. The purpose of the resizing was to 1) keep the label maps untouched at the original high resolution and 2) avoid the need to change the standard FCN training code for Cityscapes. To get the *ground-truth* numbers in the paper, you need to resize the original Cityscapes images to 256x256 before running the evaluation code.
0 name: pytorch-CycleGAN-and-pix2pix
1 channels:
2 - peterjc123
3 - defaults
4 dependencies:
5 - python=3.5.5
6 - pytorch=0.4.1
7 - scipy
8 - pip:
9 - dominate==2.3.1
10 - git+https://github.com/pytorch/vision.git
11 - Pillow==5.0.0
12 - numpy==1.14.1
13 - visdom==0.1.7
0 """This package contains modules related to objective functions, optimizations, and network architectures.
1
2 To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
3 You need to implement the following five functions:
4 -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
5 -- <set_input>: unpack data from dataset and apply preprocessing.
6 -- <forward>: produce intermediate results.
7 -- <optimize_parameters>: calculate loss, gradients, and update network weights.
8 -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
9
10 In the function <__init__>, you need to define four lists:
11 -- self.loss_names (str list): specify the training losses that you want to plot and save.
12 -- self.model_names (str list): define networks used in our training.
13 -- self.visual_names (str list): specify the images that you want to display and save.
14 -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
15
16 Now you can use the model class by specifying flag '--model dummy'.
17 See our template model class 'template_model.py' for more details.
18 """
19
20 import importlib
21 from models.base_model import BaseModel
22
23
24 def find_model_using_name(model_name):
25 """Import the module "models/[model_name]_model.py".
26
27     In the file, the class called [ModelName]Model will
28 be instantiated. It has to be a subclass of BaseModel,
29     and the name match is case-insensitive.
30 """
31 model_filename = "models." + model_name + "_model"
32 modellib = importlib.import_module(model_filename)
33 model = None
34 target_model_name = model_name.replace('_', '') + 'model'
35 for name, cls in modellib.__dict__.items():
36 if name.lower() == target_model_name.lower() \
37 and issubclass(cls, BaseModel):
38 model = cls
39
40 if model is None:
41 print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
42         exit(1)  # exit with a non-zero status to signal the error
43
44 return model
45
46
47 def get_option_setter(model_name):
48 """Return the static method <modify_commandline_options> of the model class."""
49 model_class = find_model_using_name(model_name)
50 return model_class.modify_commandline_options
51
52
53 def create_model(opt):
54 """Create a model given the option.
55
56     This function instantiates the model class specified by opt.model.
57 This is the main interface between this package and 'train.py'/'test.py'
58
59 Example:
60 >>> from models import create_model
61 >>> model = create_model(opt)
62 """
63 model = find_model_using_name(opt.model)
64 instance = model(opt)
65 print("model [%s] was created" % type(instance).__name__)
66 return instance
0 import os
1 import torch
2 from collections import OrderedDict
3 from abc import ABC, abstractmethod
4 from . import networks
5
6
7 class BaseModel(ABC):
8 """This class is an abstract base class (ABC) for models.
9 To create a subclass, you need to implement the following five functions:
10 -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
11 -- <set_input>: unpack data from dataset and apply preprocessing.
12 -- <forward>: produce intermediate results.
13 -- <optimize_parameters>: calculate losses, gradients, and update network weights.
14 -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
15 """
16
17 def __init__(self, opt):
18 """Initialize the BaseModel class.
19
20 Parameters:
21 opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
22
23 When creating your custom class, you need to implement your own initialization.
24         In this function, you should first call <BaseModel.__init__(self, opt)>
25 Then, you need to define four lists:
26 -- self.loss_names (str list): specify the training losses that you want to plot and save.
27         -- self.model_names (str list): define networks used in our training.
28         -- self.visual_names (str list): specify the images that you want to display and save.
29 -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
30 """
31 self.opt = opt
32 self.gpu_ids = opt.gpu_ids
33 self.isTrain = opt.isTrain
34 self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
35 self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
36 if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
37 torch.backends.cudnn.benchmark = True
38 self.loss_names = []
39 self.model_names = []
40 self.visual_names = []
41 self.optimizers = []
42 self.image_paths = []
43 self.metric = 0 # used for learning rate policy 'plateau'
44
45 @staticmethod
46 def modify_commandline_options(parser, is_train):
47 """Add new model-specific options, and rewrite default values for existing options.
48
49 Parameters:
50 parser -- original option parser
51 is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
52
53 Returns:
54 the modified parser.
55 """
56 return parser
57
58 @abstractmethod
59 def set_input(self, input):
60 """Unpack input data from the dataloader and perform necessary pre-processing steps.
61
62 Parameters:
63 input (dict): includes the data itself and its metadata information.
64 """
65 pass
66
67 @abstractmethod
68 def forward(self):
69 """Run forward pass; called by both functions <optimize_parameters> and <test>."""
70 pass
71
72 @abstractmethod
73 def optimize_parameters(self):
74 """Calculate losses, gradients, and update network weights; called in every training iteration"""
75 pass
76
77 def setup(self, opt):
78 """Load and print networks; create schedulers
79
80 Parameters:
81 opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
82 """
83 if self.isTrain:
84 self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
85 if not self.isTrain or opt.continue_train:
86 load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
87 self.load_networks(load_suffix)
88 self.print_networks(opt.verbose)
89
90 def eval(self):
91 """Make models eval mode during test time"""
92 for name in self.model_names:
93 if isinstance(name, str):
94 net = getattr(self, 'net' + name)
95 net.eval()
96
97 def test(self):
98 """Forward function used in test time.
99
100 This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
101 It also calls <compute_visuals> to produce additional visualization results
102 """
103 with torch.no_grad():
104 self.forward()
105 self.compute_visuals()
106
107 def compute_visuals(self):
108 """Calculate additional output images for visdom and HTML visualization"""
109 pass
110
111 def get_image_paths(self):
112 """ Return image paths that are used to load current data"""
113 return self.image_paths
114
115 def update_learning_rate(self):
116 """Update learning rates for all the networks; called at the end of every epoch"""
117 for scheduler in self.schedulers:
118 if self.opt.lr_policy == 'plateau':
119 scheduler.step(self.metric)
120 else:
121 scheduler.step()
122
123 lr = self.optimizers[0].param_groups[0]['lr']
124 print('learning rate = %.7f' % lr)
125
126 def get_current_visuals(self):
127 """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
128 visual_ret = OrderedDict()
129 for name in self.visual_names:
130 if isinstance(name, str):
131 visual_ret[name] = getattr(self, name)
132 return visual_ret
133
134 def get_current_losses(self):
135 """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
136 errors_ret = OrderedDict()
137 for name in self.loss_names:
138 if isinstance(name, str):
139 errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
140 return errors_ret
141
142 def save_networks(self, epoch):
143 """Save all the networks to the disk.
144
145 Parameters:
146 epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
147 """
148 for name in self.model_names:
149 if isinstance(name, str):
150 save_filename = '%s_net_%s.pth' % (epoch, name)
151 save_path = os.path.join(self.save_dir, save_filename)
152 net = getattr(self, 'net' + name)
153
154 if len(self.gpu_ids) > 0 and torch.cuda.is_available():
155 torch.save(net.module.cpu().state_dict(), save_path)
156 net.cuda(self.gpu_ids[0])
157 else:
158 torch.save(net.cpu().state_dict(), save_path)
159
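The checkpoint filename convention used by save_networks (and mirrored by load_networks below) can be sketched as a tiny standalone helper; this is an illustrative stand-in for the inline `'%s_net_%s.pth'` formatting, not part of the class:

```python
def checkpoint_name(epoch, name):
    # Mirrors the '%s_net_%s.pth' pattern used by save_networks/load_networks:
    # epoch may be an int or a tag like 'latest'; name is e.g. 'G_A' or 'D'.
    return '%s_net_%s.pth' % (epoch, name)
```

For example, `checkpoint_name('latest', 'G_A')` yields `'latest_net_G_A.pth'`, the file produced for the latest checkpoint of `netG_A`.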
160 def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
161 """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
162 key = keys[i]
163 if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
164 if module.__class__.__name__.startswith('InstanceNorm') and \
165 (key == 'running_mean' or key == 'running_var'):
166 if getattr(module, key) is None:
167 state_dict.pop('.'.join(keys))
168 if module.__class__.__name__.startswith('InstanceNorm') and \
169 (key == 'num_batches_tracked'):
170 state_dict.pop('.'.join(keys))
171 else:
172 self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
173
174 def load_networks(self, epoch):
175 """Load all the networks from the disk.
176
177 Parameters:
178 epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
179 """
180 for name in self.model_names:
181 if isinstance(name, str):
182 load_filename = '%s_net_%s.pth' % (epoch, name)
183 load_path = os.path.join(self.save_dir, load_filename)
184 net = getattr(self, 'net' + name)
185 if isinstance(net, torch.nn.DataParallel):
186 net = net.module
187 print('loading the model from %s' % load_path)
188 # if you are using PyTorch newer than 0.4 (e.g., built from
189 # GitHub source), you can remove str() on self.device
190 state_dict = torch.load(load_path, map_location=str(self.device))
191 if hasattr(state_dict, '_metadata'):
192 del state_dict._metadata
193
194 # patch InstanceNorm checkpoints prior to 0.4
195 for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
196 self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
197 net.load_state_dict(state_dict)
198
199 def print_networks(self, verbose):
200 """Print the total number of parameters in the network and (if verbose) network architecture
201
202 Parameters:
203 verbose (bool) -- if verbose: print the network architecture
204 """
205 print('---------- Networks initialized -------------')
206 for name in self.model_names:
207 if isinstance(name, str):
208 net = getattr(self, 'net' + name)
209 num_params = 0
210 for param in net.parameters():
211 num_params += param.numel()
212 if verbose:
213 print(net)
214 print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
215 print('-----------------------------------------------')
216
217 def set_requires_grad(self, nets, requires_grad=False):
218 """Set requires_grad=False for all the networks to avoid unnecessary computations
219 Parameters:
220 nets (network list) -- a list of networks
221 requires_grad (bool) -- whether the networks require gradients or not
222 """
223 if not isinstance(nets, list):
224 nets = [nets]
225 for net in nets:
226 if net is not None:
227 for param in net.parameters():
228 param.requires_grad = requires_grad
0 from .pix2pix_model import Pix2PixModel
1 import torch
2 from skimage import color # used for lab2rgb
3 import numpy as np
4
5
6 class ColorizationModel(Pix2PixModel):
7 """This is a subclass of Pix2PixModel for image colorization (black & white image -> colorful images).
8
9 The model training requires the '--dataset_mode colorization' dataset.
10 It trains a pix2pix model, mapping from L channel to ab channels in Lab color space.
11 By default, the colorization dataset will automatically set '--input_nc 1' and '--output_nc 2'.
12 """
13 @staticmethod
14 def modify_commandline_options(parser, is_train=True):
15 """Add new dataset-specific options, and rewrite default values for existing options.
16
17 Parameters:
18 parser -- original option parser
19 is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
20
21 Returns:
22 the modified parser.
23
24 By default, we use 'colorization' dataset for this model.
25 See the original pix2pix paper (https://arxiv.org/pdf/1611.07004.pdf) and colorization results (Figure 9 in the paper)
26 """
27 Pix2PixModel.modify_commandline_options(parser, is_train)
28 parser.set_defaults(dataset_mode='colorization')
29 return parser
30
31 def __init__(self, opt):
32 """Initialize the class.
33
34 Parameters:
35 opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
36
37 For visualization, we set 'visual_names' as 'real_A' (input real image),
38 'real_B_rgb' (ground truth RGB image), and 'fake_B_rgb' (predicted RGB image)
39 We convert the Lab image 'real_B' (inherited from Pix2pixModel) to an RGB image 'real_B_rgb'.
40 We convert the Lab image 'fake_B' (inherited from Pix2pixModel) to an RGB image 'fake_B_rgb'.
41 """
42 # reuse the pix2pix model
43 Pix2PixModel.__init__(self, opt)
44 # specify the images to be visualized.
45 self.visual_names = ['real_A', 'real_B_rgb', 'fake_B_rgb']
46
47 def lab2rgb(self, L, AB):
48 """Convert a Lab tensor image to an RGB numpy output
49 Parameters:
50 L (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
51 AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array)
52
53 Returns:
54 rgb (RGB numpy image): rgb output images (range: [0, 255], numpy array)
55 """
56 AB2 = AB * 110.0
57 L2 = (L + 1.0) * 50.0
58 Lab = torch.cat([L2, AB2], dim=1)
59 Lab = Lab[0].data.cpu().float().numpy()
60 Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
61 rgb = color.lab2rgb(Lab) * 255
62 return rgb
63
64 def compute_visuals(self):
65 """Calculate additional output images for visdom and HTML visualization"""
66 self.real_B_rgb = self.lab2rgb(self.real_A, self.real_B)
67 self.fake_B_rgb = self.lab2rgb(self.real_A, self.fake_B)
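lab2rgb above first maps network outputs from the normalized [-1, 1] range back to Lab ranges (L in [0, 100], ab in roughly [-110, 110]) before calling skimage. A minimal numpy sketch of just that denormalization step (a stand-in on plain arrays, not the torch path):

```python
import numpy as np

def denormalize_lab(L, AB):
    # Same scaling as in ColorizationModel.lab2rgb:
    # L: [-1, 1] -> [0, 100]; ab: [-1, 1] -> [-110, 110].
    return (L + 1.0) * 50.0, AB * 110.0

L2, AB2 = denormalize_lab(np.array([-1.0, 0.0, 1.0]), np.array([-1.0, 1.0]))
```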
0 import torch
1 import itertools
2 from util.image_pool import ImagePool
3 from .base_model import BaseModel
4 from . import networks
5
6
7 class CycleGANModel(BaseModel):
8 """
9 This class implements the CycleGAN model, for learning image-to-image translation without paired data.
10
11 The model training requires '--dataset_mode unaligned' dataset.
12 By default, it uses a '--netG resnet_9blocks' ResNet generator,
13 a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
14 and a least-square GANs objective ('--gan_mode lsgan').
15
16 CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
17 """
18 @staticmethod
19 def modify_commandline_options(parser, is_train=True):
20 """Add new dataset-specific options, and rewrite default values for existing options.
21
22 Parameters:
23 parser -- original option parser
24 is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
25
26 Returns:
27 the modified parser.
28
29 For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
30 A (source domain), B (target domain).
31 Generators: G_A: A -> B; G_B: B -> A.
32 Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
33 Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
34 Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
35 Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
36 Dropout is not used in the original CycleGAN paper.
37 """
38 parser.set_defaults(no_dropout=True) # default CycleGAN did not use dropout
39 if is_train:
40 parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
41 parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
42 parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
43
44 return parser
45
46 def __init__(self, opt):
47 """Initialize the CycleGAN class.
48
49 Parameters:
50 opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
51 """
52 BaseModel.__init__(self, opt)
53 # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
54 self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
55 # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
56 visual_names_A = ['real_A', 'fake_B', 'rec_A']
57 visual_names_B = ['real_B', 'fake_A', 'rec_B']
58 if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
59 visual_names_A.append('idt_B')
60 visual_names_B.append('idt_A')
61
62 self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
63 # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
64 if self.isTrain:
65 self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
66 else: # during test time, only load Gs
67 self.model_names = ['G_A', 'G_B']
68
69 # define networks (both Generators and discriminators)
70 # The naming is different from those used in the paper.
71 # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
72 self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
73 not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
74 self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
75 not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
76
77 if self.isTrain: # define discriminators
78 self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
79 opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
80 self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
81 opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
82
83 if self.isTrain:
84 if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
85 assert(opt.input_nc == opt.output_nc)
86 self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
87 self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
88 # define loss functions
89 self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
90 self.criterionCycle = torch.nn.L1Loss()
91 self.criterionIdt = torch.nn.L1Loss()
92 # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
93 self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
94 self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
95 self.optimizers.append(self.optimizer_G)
96 self.optimizers.append(self.optimizer_D)
97
98 def set_input(self, input):
99 """Unpack input data from the dataloader and perform necessary pre-processing steps.
100
101 Parameters:
102 input (dict): include the data itself and its metadata information.
103
104 The option 'direction' can be used to swap domain A and domain B.
105 """
106 AtoB = self.opt.direction == 'AtoB'
107 self.real_A = input['A' if AtoB else 'B'].to(self.device)
108 self.real_B = input['B' if AtoB else 'A'].to(self.device)
109 self.image_paths = input['A_paths' if AtoB else 'B_paths']
110
111 def forward(self):
112 """Run forward pass; called by both functions <optimize_parameters> and <test>."""
113 self.fake_B = self.netG_A(self.real_A) # G_A(A)
114 self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
115 self.fake_A = self.netG_B(self.real_B) # G_B(B)
116 self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
117
118 def backward_D_basic(self, netD, real, fake):
119 """Calculate GAN loss for the discriminator
120
121 Parameters:
122 netD (network) -- the discriminator D
123 real (tensor array) -- real images
124 fake (tensor array) -- images generated by a generator
125
126 Return the discriminator loss.
127 We also call loss_D.backward() to calculate the gradients.
128 """
129 # Real
130 pred_real = netD(real)
131 loss_D_real = self.criterionGAN(pred_real, True)
132 # Fake
133 pred_fake = netD(fake.detach())
134 loss_D_fake = self.criterionGAN(pred_fake, False)
135 # Combined loss and calculate gradients
136 loss_D = (loss_D_real + loss_D_fake) * 0.5
137 loss_D.backward()
138 return loss_D
139
140 def backward_D_A(self):
141 """Calculate GAN loss for discriminator D_A"""
142 fake_B = self.fake_B_pool.query(self.fake_B)
143 self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
144
145 def backward_D_B(self):
146 """Calculate GAN loss for discriminator D_B"""
147 fake_A = self.fake_A_pool.query(self.fake_A)
148 self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
149
150 def backward_G(self):
151 """Calculate the loss for generators G_A and G_B"""
152 lambda_idt = self.opt.lambda_identity
153 lambda_A = self.opt.lambda_A
154 lambda_B = self.opt.lambda_B
155 # Identity loss
156 if lambda_idt > 0:
157 # G_A should be identity if real_B is fed: ||G_A(B) - B||
158 self.idt_A = self.netG_A(self.real_B)
159 self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
160 # G_B should be identity if real_A is fed: ||G_B(A) - A||
161 self.idt_B = self.netG_B(self.real_A)
162 self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
163 else:
164 self.loss_idt_A = 0
165 self.loss_idt_B = 0
166
167 # GAN loss D_A(G_A(A))
168 self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
169 # GAN loss D_B(G_B(B))
170 self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
171 # Forward cycle loss || G_B(G_A(A)) - A||
172 self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
173 # Backward cycle loss || G_A(G_B(B)) - B||
174 self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
175 # combined loss and calculate gradients
176 self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
177 self.loss_G.backward()
178
179 def optimize_parameters(self):
180 """Calculate losses, gradients, and update network weights; called in every training iteration"""
181 # forward
182 self.forward() # compute fake images and reconstruction images.
183 # G_A and G_B
184 self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs
185 self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
186 self.backward_G() # calculate gradients for G_A and G_B
187 self.optimizer_G.step() # update G_A and G_B's weights
188 # D_A and D_B
189 self.set_requires_grad([self.netD_A, self.netD_B], True)
190 self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
191 self.backward_D_A() # calculate gradients for D_A
192 self.backward_D_B()      # calculate gradients for D_B
193 self.optimizer_D.step() # update D_A and D_B's weights
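The six-term generator objective in backward_G can be sanity-checked with plain numbers; a hypothetical sketch using scalar stand-ins for the individual raw losses (defaults lambda_A = lambda_B = 10.0, lambda_identity = 0.5):

```python
def total_generator_loss(gan_A, gan_B, cyc_A, cyc_B, idt_A, idt_B,
                         lambda_A=10.0, lambda_B=10.0, lambda_idt=0.5):
    # Mirrors backward_G: identity terms are weighted by the opposite
    # domain's cycle weight times lambda_identity.
    return (gan_A + gan_B
            + cyc_A * lambda_A + cyc_B * lambda_B
            + idt_A * lambda_B * lambda_idt + idt_B * lambda_A * lambda_idt)
```

With unit raw losses this gives 1 + 1 + 10 + 10 + 5 + 5 = 32, showing that the cycle-consistency terms dominate the objective at the default weights.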
0 import torch
1 import torch.nn as nn
2 from torch.nn import init
3 import functools
4 from torch.optim import lr_scheduler
5
6
7 ###############################################################################
8 # Helper Functions
9 ###############################################################################
10
11
12 class Identity(nn.Module):
13 def forward(self, x):
14 return x
15
16
17 def get_norm_layer(norm_type='instance'):
18 """Return a normalization layer
19
20 Parameters:
21 norm_type (str) -- the name of the normalization layer: batch | instance | none
22
23 For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
24 For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
25 """
26 if norm_type == 'batch':
27 norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
28 elif norm_type == 'instance':
29 norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
30 elif norm_type == 'none':
31 norm_layer = lambda x: Identity()
32 else:
33 raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
34 return norm_layer
35
36
37 def get_scheduler(optimizer, opt):
38 """Return a learning rate scheduler
39
40 Parameters:
41 optimizer -- the optimizer of the network
42 opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
43 opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
44
45 For 'linear', we keep the same learning rate for the first <opt.niter> epochs
46 and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
47 For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
48 See https://pytorch.org/docs/stable/optim.html for more details.
49 """
50 if opt.lr_policy == 'linear':
51 def lambda_rule(epoch):
52 lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
53 return lr_l
54 scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
55 elif opt.lr_policy == 'step':
56 scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
57 elif opt.lr_policy == 'plateau':
58 scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
59 elif opt.lr_policy == 'cosine':
60 scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
61 else:
62 raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
63 return scheduler
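The 'linear' branch above keeps the factor at 1.0 for the first opt.niter epochs and then decays it toward zero over opt.niter_decay epochs; a standalone sketch of lambda_rule, assuming the repo's usual defaults (epoch_count=1, niter=100, niter_decay=100):

```python
def linear_lr_factor(epoch, epoch_count=1, niter=100, niter_decay=100):
    # Multiplicative factor applied to the initial lr, as in lambda_rule.
    return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
```

The factor stays exactly 1.0 through epoch niter - epoch_count and reaches 1 / (niter_decay + 1) on the final decay epoch, so the learning rate never quite hits zero while training is still stepping.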
64
65
66 def init_weights(net, init_type='normal', init_gain=0.02):
67 """Initialize network weights.
68
69 Parameters:
70 net (network) -- network to be initialized
71 init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
72 init_gain (float) -- scaling factor for normal, xavier and orthogonal.
73
74 We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
75 work better for some applications. Feel free to try yourself.
76 """
77 def init_func(m): # define the initialization function
78 classname = m.__class__.__name__
79 if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
80 if init_type == 'normal':
81 init.normal_(m.weight.data, 0.0, init_gain)
82 elif init_type == 'xavier':
83 init.xavier_normal_(m.weight.data, gain=init_gain)
84 elif init_type == 'kaiming':
85 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
86 elif init_type == 'orthogonal':
87 init.orthogonal_(m.weight.data, gain=init_gain)
88 else:
89 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
90 if hasattr(m, 'bias') and m.bias is not None:
91 init.constant_(m.bias.data, 0.0)
92 elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
93 init.normal_(m.weight.data, 1.0, init_gain)
94 init.constant_(m.bias.data, 0.0)
95
96 print('initialize network with %s' % init_type)
97 net.apply(init_func) # apply the initialization function <init_func>
98
99
100 def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
101 """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
102 Parameters:
103 net (network) -- the network to be initialized
104 init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
105 init_gain (float) -- scaling factor for normal, xavier and orthogonal.
106 gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
107
108 Return an initialized network.
109 """
110 if len(gpu_ids) > 0:
111 assert(torch.cuda.is_available())
112 net.to(gpu_ids[0])
113 net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
114 init_weights(net, init_type, init_gain=init_gain)
115 return net
116
117
118 def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
119 """Create a generator
120
121 Parameters:
122 input_nc (int) -- the number of channels in input images
123 output_nc (int) -- the number of channels in output images
124 ngf (int) -- the number of filters in the last conv layer
125 netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
126 norm (str) -- the name of normalization layers used in the network: batch | instance | none
127 use_dropout (bool) -- if use dropout layers.
128 init_type (str) -- the name of our initialization method.
129 init_gain (float) -- scaling factor for normal, xavier and orthogonal.
130 gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
131
132 Returns a generator
133
134 Our current implementation provides two types of generators:
135 U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
136 The original U-Net paper: https://arxiv.org/abs/1505.04597
137
138 Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
139 Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
140 We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
141
142
143 The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
144 """
145 net = None
146 norm_layer = get_norm_layer(norm_type=norm)
147
148 if netG == 'resnet_9blocks':
149 net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
150 elif netG == 'resnet_6blocks':
151 net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
152 elif netG == 'unet_128':
153 net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
154 elif netG == 'unet_256':
155 net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
156 else:
157 raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
158 return init_net(net, init_type, init_gain, gpu_ids)
159
160
161 def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
162 """Create a discriminator
163
164 Parameters:
165 input_nc (int) -- the number of channels in input images
166 ndf (int) -- the number of filters in the first conv layer
167 netD (str) -- the architecture's name: basic | n_layers | pixel
168 n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
169 norm (str) -- the type of normalization layers used in the network.
170 init_type (str) -- the name of the initialization method.
171 init_gain (float) -- scaling factor for normal, xavier and orthogonal.
172 gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
173
174 Returns a discriminator
175
176 Our current implementation provides three types of discriminators:
177 [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
178 It can classify whether 70×70 overlapping patches are real or fake.
179 Such a patch-level discriminator architecture has fewer parameters
180 than a full-image discriminator and can work on arbitrarily-sized images
181 in a fully convolutional fashion.
182
183 [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
184 with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
185
186 [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
187 It encourages greater color diversity but has no effect on spatial statistics.
188
189 The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
190 """
191 net = None
192 norm_layer = get_norm_layer(norm_type=norm)
193
194 if netD == 'basic': # default PatchGAN classifier
195 net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
196 elif netD == 'n_layers': # more options
197 net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
198 elif netD == 'pixel': # classify if each pixel is real or fake
199 net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
200 else:
201 raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
202 return init_net(net, init_type, init_gain, gpu_ids)
203
204
205 ##############################################################################
206 # Classes
207 ##############################################################################
208 class GANLoss(nn.Module):
209 """Define different GAN objectives.
210
211 The GANLoss class abstracts away the need to create the target label tensor
212 that has the same size as the input.
213 """
214
215 def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
216 """ Initialize the GANLoss class.
217
218 Parameters:
219 gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
220 target_real_label (float) - - label for a real image
221 target_fake_label (float) - - label for a fake image
222
223 Note: Do not use sigmoid as the last layer of Discriminator.
224 LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
225 """
226 super(GANLoss, self).__init__()
227 self.register_buffer('real_label', torch.tensor(target_real_label))
228 self.register_buffer('fake_label', torch.tensor(target_fake_label))
229 self.gan_mode = gan_mode
230 if gan_mode == 'lsgan':
231 self.loss = nn.MSELoss()
232 elif gan_mode == 'vanilla':
233 self.loss = nn.BCEWithLogitsLoss()
234 elif gan_mode in ['wgangp']:
235 self.loss = None
236 else:
237 raise NotImplementedError('gan mode %s not implemented' % gan_mode)
238
239 def get_target_tensor(self, prediction, target_is_real):
240 """Create label tensors with the same size as the input.
241
242 Parameters:
243 prediction (tensor) - - typically the prediction from a discriminator
244 target_is_real (bool) - - if the ground truth label is for real images or fake images
245
246 Returns:
247 A label tensor filled with ground truth label, and with the size of the input
248 """
249
250 if target_is_real:
251 target_tensor = self.real_label
252 else:
253 target_tensor = self.fake_label
254 return target_tensor.expand_as(prediction)
255
256 def __call__(self, prediction, target_is_real):
257 """Calculate loss given Discriminator's output and ground truth labels.
258
259 Parameters:
260 prediction (tensor) - - typically the prediction output from a discriminator
261 target_is_real (bool) - - if the ground truth label is for real images or fake images
262
263 Returns:
264 the calculated loss.
265 """
266 if self.gan_mode in ['lsgan', 'vanilla']:
267 target_tensor = self.get_target_tensor(prediction, target_is_real)
268 loss = self.loss(prediction, target_tensor)
269 elif self.gan_mode == 'wgangp':
270 if target_is_real:
271 loss = -prediction.mean()
272 else:
273 loss = prediction.mean()
274 return loss
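For 'wgangp' there is no target tensor; the loss is just a signed mean of the critic's output. A minimal plain-Python sketch of that branch, with lists standing in for prediction tensors:

```python
def wgangp_loss(predictions, target_is_real):
    # Mirrors the wgangp branch of GANLoss.__call__: real targets
    # minimize -mean(D(x)); fake targets minimize +mean(D(x)).
    mean = sum(predictions) / len(predictions)
    return -mean if target_is_real else mean
```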
275
276
277 def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
278 """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
279
280 Arguments:
281 netD (network) -- discriminator network
282 real_data (tensor array) -- real images
283 fake_data (tensor array) -- generated images from the generator
284 device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
285 type (str) -- if we mix real and fake data or not [real | fake | mixed].
286 constant (float) -- the constant used in formula (||gradient||_2 - constant)^2
287 lambda_gp (float) -- weight for this loss
288
289 Returns the gradient penalty loss
290 """
291 if lambda_gp > 0.0:
292 if type == 'real': # either use real images, fake images, or a linear interpolation of two.
293 interpolatesv = real_data
294 elif type == 'fake':
295 interpolatesv = fake_data
296 elif type == 'mixed':
297 alpha = torch.rand(real_data.shape[0], 1, device=device)
298 alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
299 interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
300 else:
301 raise NotImplementedError('{} not implemented'.format(type))
302 interpolatesv.requires_grad_(True)
303 disc_interpolates = netD(interpolatesv)
304 gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
305 grad_outputs=torch.ones(disc_interpolates.size()).to(device),
306 create_graph=True, retain_graph=True, only_inputs=True)
307 gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
308 gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
309 return gradient_penalty, gradients
310 else:
311 return 0.0, None
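The penalty itself reduces per-sample gradient rows via ((||g||_2 - constant)^2).mean() * lambda_gp; a numpy sketch of just that reduction, with hypothetical gradient rows standing in for the autograd output:

```python
import numpy as np

def gp_term(grad_rows, constant=1.0, lambda_gp=10.0, eps=1e-16):
    # Same reduction as in cal_gradient_penalty; eps keeps the norm
    # well-behaved when a gradient row is all zeros.
    norms = np.linalg.norm(grad_rows + eps, ord=2, axis=1)
    return ((norms - constant) ** 2).mean() * lambda_gp

penalty = gp_term(np.array([[1.0, 0.0], [0.0, 2.0]]))  # rows with norms 1 and 2
```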
312
313
314 class ResnetGenerator(nn.Module):
315 """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
316
317 We adapt Torch code and idea from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style)
318 """
319
320 def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
321 """Construct a Resnet-based generator
322
323 Parameters:
324 input_nc (int) -- the number of channels in input images
325 output_nc (int) -- the number of channels in output images
326 ngf (int) -- the number of filters in the last conv layer
327 norm_layer -- normalization layer
328 use_dropout (bool) -- if use dropout layers
329 n_blocks (int) -- the number of ResNet blocks
330 padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
331 """
332 assert(n_blocks >= 0)
333 super(ResnetGenerator, self).__init__()
334 if type(norm_layer) == functools.partial:
335 use_bias = norm_layer.func == nn.InstanceNorm2d
336 else:
337 use_bias = norm_layer == nn.InstanceNorm2d
338
339 model = [nn.ReflectionPad2d(3),
340 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
341 norm_layer(ngf),
342 nn.ReLU(True)]
343
344 n_downsampling = 2
345 for i in range(n_downsampling): # add downsampling layers
346 mult = 2 ** i
347 model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
348 norm_layer(ngf * mult * 2),
349 nn.ReLU(True)]
350
351 mult = 2 ** n_downsampling
352 for i in range(n_blocks): # add ResNet blocks
353
354 model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
355
356 for i in range(n_downsampling): # add upsampling layers
357 mult = 2 ** (n_downsampling - i)
358 model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
359 kernel_size=3, stride=2,
360 padding=1, output_padding=1,
361 bias=use_bias),
362 norm_layer(int(ngf * mult / 2)),
363 nn.ReLU(True)]
364 model += [nn.ReflectionPad2d(3)]
365 model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
366 model += [nn.Tanh()]
367
368 self.model = nn.Sequential(*model)
369
370 def forward(self, input):
371 """Standard forward"""
372 return self.model(input)
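With the defaults above (ngf=64, n_downsampling=2) the feature widths double at each downsampling conv and halve on the way back up; a small hypothetical helper, not part of the class, sketching that progression:

```python
def resnet_g_channels(ngf=64, n_downsampling=2):
    # Encoder widths from the stem through the downsampling convs;
    # the decoder mirrors them in reverse before the final 7x7 conv.
    down = [ngf * (2 ** i) for i in range(n_downsampling + 1)]
    return down, list(reversed(down))

down, up = resnet_g_channels()
```

So the ResNet blocks in the middle all operate at ngf * 4 channels (256 by default).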
373
374
375 class ResnetBlock(nn.Module):
376 """Define a Resnet block"""
377
378 def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
379 """Initialize the Resnet block
380
381 A resnet block is a conv block with skip connections
382 We construct a conv block with build_conv_block function,
383 and implement skip connections in <forward> function.
384 Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
385 """
386 super(ResnetBlock, self).__init__()
387 self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
388
389 def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
390 """Construct a convolutional block.
391
392 Parameters:
393 dim (int) -- the number of channels in the conv layer.
394 padding_type (str) -- the name of padding layer: reflect | replicate | zero
395 norm_layer -- normalization layer
396 use_dropout (bool) -- whether to use dropout layers.
397 use_bias (bool) -- whether the conv layer uses bias
398
399 Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
400 """
401 conv_block = []
402 p = 0
403 if padding_type == 'reflect':
404 conv_block += [nn.ReflectionPad2d(1)]
405 elif padding_type == 'replicate':
406 conv_block += [nn.ReplicationPad2d(1)]
407 elif padding_type == 'zero':
408 p = 1
409 else:
410 raise NotImplementedError('padding [%s] is not implemented' % padding_type)
411
412 conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
413 if use_dropout:
414 conv_block += [nn.Dropout(0.5)]
415
416 p = 0
417 if padding_type == 'reflect':
418 conv_block += [nn.ReflectionPad2d(1)]
419 elif padding_type == 'replicate':
420 conv_block += [nn.ReplicationPad2d(1)]
421 elif padding_type == 'zero':
422 p = 1
423 else:
424 raise NotImplementedError('padding [%s] is not implemented' % padding_type)
425 conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
426
427 return nn.Sequential(*conv_block)
428
429 def forward(self, x):
430 """Forward function (with skip connections)"""
431 out = x + self.conv_block(x) # add skip connections
432 return out
433
434
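A quick sanity check on why the residual sum in `forward` is shape-compatible (the helper name is ours, for illustration only): each branch of the block pads by one pixel, reflect, replicate, or zero, before a 3x3 stride-1 conv, which leaves the spatial size unchanged.

```python
def padded_conv_out(n, pad=1, k=3, s=1):
    # pad by 1 (reflect/replicate/zero), then a 3x3 stride-1 conv
    return (n + 2 * pad - k) // s + 1

# size in == size out, so x + conv_block(x) is element-wise valid
print(padded_conv_out(64))  # 64
```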
435 class UnetGenerator(nn.Module):
436 """Create a Unet-based generator"""
437
438 def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
439 """Construct a Unet generator
440 Parameters:
441 input_nc (int) -- the number of channels in input images
442 output_nc (int) -- the number of channels in output images
443 num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
444 an image of size 128x128 will become of size 1x1 at the bottleneck
445 ngf (int) -- the number of filters in the last conv layer
446 norm_layer -- normalization layer
447
448 We construct the U-Net from the innermost layer to the outermost layer.
449 It is a recursive process.
450 """
451 super(UnetGenerator, self).__init__()
452 # construct unet structure
453 unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
454 for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
455 unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
456 # gradually reduce the number of filters from ngf * 8 to ngf
457 unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
458 unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
459 unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
460 self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
461
462 def forward(self, input):
463 """Standard forward"""
464 return self.model(input)
465
466
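The recursive construction above caps filter widths at ngf * 8. As an illustrative sketch (the `widths` list is our own name, assuming the unet_256 configuration with num_downs=8), the per-depth filter counts and the bottleneck size can be checked with plain arithmetic:

```python
ngf, num_downs = 64, 8  # assumed defaults for a unet_256 generator
# filter width at each depth, from the outermost downconv to the bottleneck
widths = [ngf * min(2 ** d, 8) for d in range(num_downs)]
print(widths)  # [64, 128, 256, 512, 512, 512, 512, 512]
# eight stride-2 downsamplings shrink a 256x256 input to 1x1
print(256 // 2 ** num_downs)  # 1
```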
467 class UnetSkipConnectionBlock(nn.Module):
468 """Defines the Unet submodule with skip connection.
469 X -------------------identity----------------------
470 |-- downsampling -- |submodule| -- upsampling --|
471 """
472
473 def __init__(self, outer_nc, inner_nc, input_nc=None,
474 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
475 """Construct a Unet submodule with skip connections.
476
477 Parameters:
478 outer_nc (int) -- the number of filters in the outer conv layer
479 inner_nc (int) -- the number of filters in the inner conv layer
480 input_nc (int) -- the number of channels in input images/features
481 submodule (UnetSkipConnectionBlock) -- previously defined submodules
482 outermost (bool) -- if this module is the outermost module
483 innermost (bool) -- if this module is the innermost module
484 norm_layer -- normalization layer
485 use_dropout (bool) -- whether to use dropout layers.
486 """
487 super(UnetSkipConnectionBlock, self).__init__()
488 self.outermost = outermost
489 if type(norm_layer) == functools.partial:
490 use_bias = norm_layer.func == nn.InstanceNorm2d
491 else:
492 use_bias = norm_layer == nn.InstanceNorm2d
493 if input_nc is None:
494 input_nc = outer_nc
495 downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
496 stride=2, padding=1, bias=use_bias)
497 downrelu = nn.LeakyReLU(0.2, True)
498 downnorm = norm_layer(inner_nc)
499 uprelu = nn.ReLU(True)
500 upnorm = norm_layer(outer_nc)
501
502 if outermost:
503 upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
504 kernel_size=4, stride=2,
505 padding=1)
506 down = [downconv]
507 up = [uprelu, upconv, nn.Tanh()]
508 model = down + [submodule] + up
509 elif innermost:
510 upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
511 kernel_size=4, stride=2,
512 padding=1, bias=use_bias)
513 down = [downrelu, downconv]
514 up = [uprelu, upconv, upnorm]
515 model = down + up
516 else:
517 upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
518 kernel_size=4, stride=2,
519 padding=1, bias=use_bias)
520 down = [downrelu, downconv, downnorm]
521 up = [uprelu, upconv, upnorm]
522
523 if use_dropout:
524 model = down + [submodule] + up + [nn.Dropout(0.5)]
525 else:
526 model = down + [submodule] + up
527
528 self.model = nn.Sequential(*model)
529
530 def forward(self, x):
531 if self.outermost:
532 return self.model(x)
533 else: # add skip connections
534 return torch.cat([x, self.model(x)], 1)
535
536
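The concatenation in `forward` is why the non-innermost upconvs take inner_nc * 2 input channels. A minimal channel-count sketch (the helper is illustrative, not part of networks.py):

```python
def block_out_channels(input_nc, outer_nc, outermost=False):
    # forward returns model(x) when outermost, else cat([x, model(x)], dim=1)
    return outer_nc if outermost else input_nc + outer_nc

# a middle block has input_nc == outer_nc, so it emits 2 * outer_nc channels,
# which is exactly the inner_nc * 2 its parent block's upconv expects
print(block_out_channels(128, 128))  # 256
print(block_out_channels(64, 3, outermost=True))  # 3
```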
537 class NLayerDiscriminator(nn.Module):
538 """Defines a PatchGAN discriminator"""
539
540 def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
541 """Construct a PatchGAN discriminator
542
543 Parameters:
544 input_nc (int) -- the number of channels in input images
545 ndf (int) -- the number of filters in the last conv layer
546 n_layers (int) -- the number of conv layers in the discriminator
547 norm_layer -- normalization layer
548 """
549 super(NLayerDiscriminator, self).__init__()
550 if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
551 use_bias = norm_layer.func == nn.InstanceNorm2d
552 else:
553 use_bias = norm_layer == nn.InstanceNorm2d
554
555 kw = 4
556 padw = 1
557 sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
558 nf_mult = 1
559 nf_mult_prev = 1
560 for n in range(1, n_layers): # gradually increase the number of filters
561 nf_mult_prev = nf_mult
562 nf_mult = min(2 ** n, 8)
563 sequence += [
564 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
565 norm_layer(ndf * nf_mult),
566 nn.LeakyReLU(0.2, True)
567 ]
568
569 nf_mult_prev = nf_mult
570 nf_mult = min(2 ** n_layers, 8)
571 sequence += [
572 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
573 norm_layer(ndf * nf_mult),
574 nn.LeakyReLU(0.2, True)
575 ]
576
577 sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
578 self.model = nn.Sequential(*sequence)
579
580 def forward(self, input):
581 """Standard forward."""
582 return self.model(input)
583
584
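With the default n_layers=3, the discriminator above is the classic 70x70 PatchGAN: three stride-2 convs followed by two stride-1 convs, all with 4x4 kernels. A hedged sketch of the receptive-field arithmetic (the helper is ours, for illustration):

```python
def receptive_field(layers):
    # walk from the output back to the input: r_in = (r_out - 1) * s + k
    r = 1
    for k, s in reversed(layers):
        r = (r - 1) * s + k
    return r

patch_gan = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]  # (kernel, stride) pairs
print(receptive_field(patch_gan))  # 70
# the PixelDiscriminator below uses only 1x1 stride-1 convs, hence a 1x1 patch
print(receptive_field([(1, 1), (1, 1), (1, 1)]))  # 1
```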
585 class PixelDiscriminator(nn.Module):
586 """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
587
588 def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
589 """Construct a 1x1 PatchGAN discriminator
590
591 Parameters:
592 input_nc (int) -- the number of channels in input images
593 ndf (int) -- the number of filters in the last conv layer
594 norm_layer -- normalization layer
595 """
596 super(PixelDiscriminator, self).__init__()
597 if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
598 use_bias = norm_layer.func == nn.InstanceNorm2d
599 else:
600 use_bias = norm_layer == nn.InstanceNorm2d
601
602 self.net = [
603 nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
604 nn.LeakyReLU(0.2, True),
605 nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
606 norm_layer(ndf * 2),
607 nn.LeakyReLU(0.2, True),
608 nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
609
610 self.net = nn.Sequential(*self.net)
611
612 def forward(self, input):
613 """Standard forward."""
614 return self.net(input)
0 import torch
1 from .base_model import BaseModel
2 from . import networks
3
4
5 class Pix2PixModel(BaseModel):
6 """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
7
8 The model training requires '--dataset_mode aligned' dataset.
9 By default, it uses a '--netG unet256' U-Net generator,
10 a '--netD basic' discriminator (PatchGAN),
11 and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the original GAN paper).
12
13 pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
14 """
15 @staticmethod
16 def modify_commandline_options(parser, is_train=True):
17 """Add new dataset-specific options, and rewrite default values for existing options.
18
19 Parameters:
20 parser -- original option parser
21 is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
22
23 Returns:
24 the modified parser.
25
26 For pix2pix, we do not use an image buffer.
27 The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
28 By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
29 """
30 # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
31 parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
32 if is_train:
33 parser.set_defaults(pool_size=0, gan_mode='vanilla')
34 parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
35
36 return parser
37
38 def __init__(self, opt):
39 """Initialize the pix2pix class.
40
41 Parameters:
42 opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
43 """
44 BaseModel.__init__(self, opt)
45 # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
46 self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
47 # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
48 self.visual_names = ['real_A', 'fake_B', 'real_B']
49 # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
50 if self.isTrain:
51 self.model_names = ['G', 'D']
52 else: # during test time, only load G
53 self.model_names = ['G']
54 # define networks (both generator and discriminator)
55 self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
56 not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
57
58 if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
59 self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
60 opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
61
62 if self.isTrain:
63 # define loss functions
64 self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
65 self.criterionL1 = torch.nn.L1Loss()
66 # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
67 self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
68 self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
69 self.optimizers.append(self.optimizer_G)
70 self.optimizers.append(self.optimizer_D)
71
72 def set_input(self, input):
73 """Unpack input data from the dataloader and perform necessary pre-processing steps.
74
75 Parameters:
76 input (dict): include the data itself and its metadata information.
77
78 The option 'direction' can be used to swap images in domain A and domain B.
79 """
80 AtoB = self.opt.direction == 'AtoB'
81 self.real_A = input['A' if AtoB else 'B'].to(self.device)
82 self.real_B = input['B' if AtoB else 'A'].to(self.device)
83 self.image_paths = input['A_paths' if AtoB else 'B_paths']
84
85 def forward(self):
86 """Run forward pass; called by both functions <optimize_parameters> and <test>."""
87 self.fake_B = self.netG(self.real_A) # G(A)
88
89 def backward_D(self):
90 """Calculate GAN loss for the discriminator"""
91 # Fake; stop backprop to the generator by detaching fake_B
92 fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator
93 pred_fake = self.netD(fake_AB.detach())
94 self.loss_D_fake = self.criterionGAN(pred_fake, False)
95 # Real
96 real_AB = torch.cat((self.real_A, self.real_B), 1)
97 pred_real = self.netD(real_AB)
98 self.loss_D_real = self.criterionGAN(pred_real, True)
99 # combine loss and calculate gradients
100 self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
101 self.loss_D.backward()
102
103 def backward_G(self):
104 """Calculate GAN and L1 loss for the generator"""
105 # First, G(A) should fake the discriminator
106 fake_AB = torch.cat((self.real_A, self.fake_B), 1)
107 pred_fake = self.netD(fake_AB)
108 self.loss_G_GAN = self.criterionGAN(pred_fake, True)
109 # Second, G(A) = B
110 self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
111 # combine loss and calculate gradients
112 self.loss_G = self.loss_G_GAN + self.loss_G_L1
113 self.loss_G.backward()
114
115 def optimize_parameters(self):
116 self.forward() # compute fake images: G(A)
117 # update D
118 self.set_requires_grad(self.netD, True) # enable backprop for D
119 self.optimizer_D.zero_grad() # set D's gradients to zero
120 self.backward_D() # calculate gradients for D
121 self.optimizer_D.step() # update D's weights
122 # update G
123 self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G
124 self.optimizer_G.zero_grad() # set G's gradients to zero
125 self.backward_G() # calculate gradients for G
126 self.optimizer_G.step() # update G's weights
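The loss bookkeeping above reduces to simple arithmetic. As a toy sketch (function names and the sample loss values are ours, purely illustrative), G combines the GAN term with an L1 term weighted by lambda_L1, while D averages its real and fake losses:

```python
def generator_loss(gan_loss, l1_loss, lambda_L1=100.0):
    # G's objective: fool D, plus an L1 term keeping G(A) close to B
    return gan_loss + lambda_L1 * l1_loss

def discriminator_loss(loss_fake, loss_real):
    # averaging halves D's effective learning rate relative to G
    return 0.5 * (loss_fake + loss_real)

print(generator_loss(0.7, 0.05))     # about 5.7 -- dominated by the L1 term
print(discriminator_loss(0.6, 0.4))  # 0.5
```

The default lambda_L1=100 means the L1 term usually dominates, which is what pushes pix2pix outputs to stay structurally close to the target.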
0 """Model class template
1
2 This module provides a template for users to implement custom models.
3 You can specify '--model template' to use this model.
4 The class name should be consistent with both the filename and its model option.
5 The filename should be <model>_model.py
6 The class name should be <Model>Model
7 It implements a simple image-to-image translation baseline based on regression loss.
8 Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
9 min_<netG> ||netG(data_A) - data_B||_1
10 You need to implement the following functions:
11 <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
12 <__init__>: Initialize this model class.
13 <set_input>: Unpack input data and perform data pre-processing.
14 <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
15 <optimize_parameters>: Update network weights; it will be called in every training iteration.
16 """
17 import torch
18 from .base_model import BaseModel
19 from . import networks
20
21
22 class TemplateModel(BaseModel):
23 @staticmethod
24 def modify_commandline_options(parser, is_train=True):
25 """Add new model-specific options and rewrite default values for existing options.
26
27 Parameters:
28 parser -- the option parser
29 is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
30
31 Returns:
32 the modified parser.
33 """
34 parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
35 if is_train:
36 parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
37
38 return parser
39
40 def __init__(self, opt):
41 """Initialize this model class.
42
43 Parameters:
44 opt -- training/test options
45
46 A few things can be done here.
47 - (required) call the initialization function of BaseModel
48 - define loss function, visualization images, model names, and optimizers
49 """
50 BaseModel.__init__(self, opt) # call the initialization method of BaseModel
51 # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
52 self.loss_names = ['loss_G']
53 # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
54 self.visual_names = ['data_A', 'data_B', 'output']
55 # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
56 # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
57 self.model_names = ['G']
58 # define networks; you can use opt.isTrain to specify different behaviors for training and test.
59 self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
60 if self.isTrain: # only defined during training time
61 # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
62 # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
63 self.criterionLoss = torch.nn.L1Loss()
64 # define and initialize optimizers. You can define one optimizer for each network.
65 # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
66 self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
67 self.optimizers = [self.optimizer]
68
69 # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
70
71 def set_input(self, input):
72 """Unpack input data from the dataloader and perform necessary pre-processing steps.
73
74 Parameters:
75 input: a dictionary that contains the data itself and its metadata information.
76 """
77 AtoB = self.opt.direction == 'AtoB' # use <direction> to swap data_A and data_B
78 self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
79 self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
80 self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
81
82 def forward(self):
83 """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
84 self.output = self.netG(self.data_A) # generate output image given the input data_A
85
86 def backward(self):
87 """Calculate losses, gradients, and update network weights; called in every training iteration"""
88 # calculate the intermediate results if necessary; here self.output has been computed during function <forward>
89 # calculate loss given the input and intermediate results
90 self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
91 self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
92
93 def optimize_parameters(self):
94 """Update network weights; it will be called in every training iteration."""
95 self.forward() # first call forward to calculate intermediate results
96 self.optimizer.zero_grad() # clear network G's existing gradients
97 self.backward() # calculate gradients for network G
98 self.optimizer.step() # update network G's weights
0 from .base_model import BaseModel
1 from . import networks
2
3
4 class TestModel(BaseModel):
5 """ This TestModel can be used to generate CycleGAN results for only one direction.
6 This model will automatically set '--dataset_mode single', which only loads the images from one collection.
7
8 See the test instruction for more details.
9 """
10 @staticmethod
11 def modify_commandline_options(parser, is_train=True):
12 """Add new dataset-specific options, and rewrite default values for existing options.
13
14 Parameters:
15 parser -- original option parser
16 is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
17
18 Returns:
19 the modified parser.
20
21 The model can only be used during test time. It requires '--dataset_mode single'.
22 You need to specify the network using the option '--model_suffix'.
23 """
24 assert not is_train, 'TestModel cannot be used during training time'
25 parser.set_defaults(dataset_mode='single')
26 parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
27
28 return parser
29
30 def __init__(self, opt):
31 """Initialize the TestModel class.
32
33 Parameters:
34 opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
35 """
36 assert(not opt.isTrain)
37 BaseModel.__init__(self, opt)
38 # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
39 self.loss_names = []
40 # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
41 self.visual_names = ['real', 'fake']
42 # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
43 self.model_names = ['G' + opt.model_suffix] # only generator is needed.
44 self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
45 opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
46
47 # assigns the model to self.netG_[suffix] so that it can be loaded
48 # please see <BaseModel.load_networks>
49 setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self.
50
51 def set_input(self, input):
52 """Unpack input data from the dataloader and perform necessary pre-processing steps.
53
54 Parameters:
55 input: a dictionary that contains the data itself and its metadata information.
56
57 We need to use 'single_dataset' dataset mode. It only loads images from one domain.
58 """
59 self.real = input['A'].to(self.device)
60 self.image_paths = input['A_paths']
61
62 def forward(self):
63 """Run forward pass."""
64 self.fake = self.netG(self.real) # G(real)
65
66 def optimize_parameters(self):
67 """No optimization for test model."""
68 pass
0 """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
0 import argparse
1 import os
2 from util import util
3 import torch
4 import models
5 import data
6
7
8 class BaseOptions():
9 """This class defines options used during both training and test time.
10
11 It also implements several helper functions such as parsing, printing, and saving the options.
12 It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
13 """
14
15 def __init__(self):
16 """Reset the class; indicates the class hasn't been initialized"""
17 self.initialized = False
18
19 def initialize(self, parser):
20 """Define the common options that are used in both training and test."""
21 # basic parameters
22 parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
23 parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
24 parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0; 0,1,2; 0,2. use -1 for CPU')
25 parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
26 # model parameters
27 parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
28 parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
29 parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
30 parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
31 parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
32 parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
33 parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
34 parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
35 parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
36 parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
37 parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
38 parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
39 # dataset parameters
40 parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
41 parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
42 parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
43 parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
44 parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
45 parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
46 parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
47 parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
48 parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
49 parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
50 parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
51 # additional parameters
52 parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
53 parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
54 parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
55 parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
56 self.initialized = True
57 return parser
58
59 def gather_options(self):
60 """Initialize our parser with basic options (only once).
61 Add additional model-specific and dataset-specific options.
62 These options are defined in the <modify_commandline_options> function
63 in model and dataset classes.
64 """
65 if not self.initialized: # check if it has been initialized
66 parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
67 parser = self.initialize(parser)
68
69 # get the basic options
70 opt, _ = parser.parse_known_args()
71
72 # modify model-related parser options
73 model_name = opt.model
74 model_option_setter = models.get_option_setter(model_name)
75 parser = model_option_setter(parser, self.isTrain)
76 opt, _ = parser.parse_known_args() # parse again with new defaults
77
78 # modify dataset-related parser options
79 dataset_name = opt.dataset_mode
80 dataset_option_setter = data.get_option_setter(dataset_name)
81 parser = dataset_option_setter(parser, self.isTrain)
82
83 # save and return the parser
84 self.parser = parser
85 return parser.parse_args()
86
87 def print_options(self, opt):
88 """Print and save options
89
90 It will print both current options and default values (if different).
91 It will save options into a text file: [checkpoints_dir]/[name]/[phase]_opt.txt
92 """
93 message = ''
94 message += '----------------- Options ---------------\n'
95 for k, v in sorted(vars(opt).items()):
96 comment = ''
97 default = self.parser.get_default(k)
98 if v != default:
99 comment = '\t[default: %s]' % str(default)
100 message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
101 message += '----------------- End -------------------'
102 print(message)
103
104 # save to the disk
105 expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
106 util.mkdirs(expr_dir)
107 file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
108 with open(file_name, 'wt') as opt_file:
109 opt_file.write(message)
110 opt_file.write('\n')
111
112 def parse(self):
113 """Parse our options, create checkpoints directory suffix, and set up gpu device."""
114 opt = self.gather_options()
115 opt.isTrain = self.isTrain # train or test
116
117 # process opt.suffix
118 if opt.suffix:
119 suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
120 opt.name = opt.name + suffix
121
122 self.print_options(opt)
123
124 # set gpu ids
125 str_ids = opt.gpu_ids.split(',')
126 opt.gpu_ids = []
127 for str_id in str_ids:
128 id = int(str_id)
129 if id >= 0:
130 opt.gpu_ids.append(id)
131 if len(opt.gpu_ids) > 0:
132 torch.cuda.set_device(opt.gpu_ids[0])
133
134 self.opt = opt
135 return self.opt
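The gpu-id handling at the end of `parse` can be mirrored in a standalone helper (the name `parse_gpu_ids` is ours, a sketch rather than repo API): comma-separated ids become a list of non-negative ints, and '-1' yields an empty list, i.e. CPU-only.

```python
def parse_gpu_ids(gpu_ids_str):
    # mirrors BaseOptions.parse: any negative id (e.g. '-1') is dropped
    return [int(s) for s in gpu_ids_str.split(',') if int(s) >= 0]

print(parse_gpu_ids('0,1,2'))  # [0, 1, 2]
print(parse_gpu_ids('-1'))     # [] -> run on CPU
```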
0 from .base_options import BaseOptions
1
2
3 class TestOptions(BaseOptions):
4 """This class includes test options.
5
6 It also includes shared options defined in BaseOptions.
7 """
8
9 def initialize(self, parser):
10 parser = BaseOptions.initialize(self, parser) # define shared options
11 parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
12 parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
13 parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
14 parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
15 # Dropout and BatchNorm have different behavior during training and test.
16 parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
17 parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
18 # rewrite default values
19 parser.set_defaults(model='test')
20 # To avoid cropping, the load_size should be the same as crop_size
21 parser.set_defaults(load_size=parser.get_default('crop_size'))
22 self.isTrain = False
23 return parser
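The `set_defaults` calls above override defaults inherited from BaseOptions without redefining the arguments. A minimal self-contained illustration of that argparse pattern (the toy parser below is a sketch, not the real option set):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='cycle_gan')
parser.add_argument('--crop_size', type=int, default=256)
parser.add_argument('--load_size', type=int, default=286)

# Override inherited defaults, mirroring TestOptions.initialize:
parser.set_defaults(model='test')
# To avoid cropping, make load_size equal to crop_size's default
parser.set_defaults(load_size=parser.get_default('crop_size'))

opt = parser.parse_args([])
print(opt.model, opt.load_size)
```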
0 from .base_options import BaseOptions
1
2
3 class TrainOptions(BaseOptions):
4 """This class includes training options.
5
6 It also includes shared options defined in BaseOptions.
7 """
8
9 def initialize(self, parser):
10 parser = BaseOptions.initialize(self, parser)
11 # visdom and HTML visualization parameters
12 parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
13 parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with a certain number of images per row.')
14 parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
15 parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
16 parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
17 parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
18 parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
19 parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
20 parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
21 # network saving and loading parameters
22 parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
23 parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
24 parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
25 parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
26 parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
27 parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
28 # training parameters
29 parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
30 parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
31 parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
32 parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
33 parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
34 parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
35 parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
36 parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
37
38 self.isTrain = True
39 return parser
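With the default `--lr_policy linear`, the learning rate is held at `--lr` for `--niter` epochs and then decays linearly toward zero over `--niter_decay` epochs. A hedged sketch of that schedule as a multiplier on the initial learning rate (`linear_lr_factor` is a hypothetical helper name; the actual scheduler lives elsewhere in the repo):

```python
def linear_lr_factor(epoch, epoch_count=1, niter=100, niter_decay=100):
    """Multiplier applied to the initial lr: 1.0 during the first
    `niter` epochs, then a linear ramp down over `niter_decay` epochs."""
    return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)

print(linear_lr_factor(1))    # full lr at the start
print(linear_lr_factor(150))  # partway through the decay
print(linear_lr_factor(200))  # reaches zero at the end
```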
0 torch>=0.4.1
1 torchvision>=0.2.1
2 dominate>=2.3.1
3 visdom>=0.1.8.3
0 Please store your training checkpoints or results here
0 Please store your TensorBoard results here
0 set -ex
1 conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
2 conda install pytorch torchvision -c pytorch # add cuda90 if CUDA 9
3 conda install visdom dominate -c conda-forge # install visdom and dominate
0 FILE=$1
1
2 echo "Note: available models are apple2orange, orange2apple, summer2winter_yosemite, winter2summer_yosemite, horse2zebra, zebra2horse, monet2photo, style_monet, style_cezanne, style_ukiyoe, style_vangogh, sat2map, map2sat, cityscapes_photo2label, cityscapes_label2photo, facades_photo2label, facades_label2photo, iphone2dslr_flower"
3
4 echo "Specified [$FILE]"
5
6 mkdir -p ./checkpoints/${FILE}_pretrained
7 MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.pth
8 URL=http://efrosgans.eecs.berkeley.edu/cyclegan/pretrained_models/$FILE.pth
9
10 wget -N $URL -O $MODEL_FILE
0 FILE=$1
1
2 echo "Note: available models are edges2shoes, sat2map, map2sat, facades_label2photo, and day2night"
3 echo "Specified [$FILE]"
4
5 mkdir -p ./checkpoints/${FILE}_pretrained
6 MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.pth
7 URL=http://efrosgans.eecs.berkeley.edu/pix2pix/models-pytorch/$FILE.pth
8
9 wget -N $URL -O $MODEL_FILE
0 %%% Prerequisites
1 % You need to get the cpp file edgesNmsMex.cpp from https://raw.githubusercontent.com/pdollar/edges/master/private/edgesNmsMex.cpp
2 % and compile it in Matlab: mex edgesNmsMex.cpp
3 % You also need to download and install Piotr's Computer Vision Matlab Toolbox: https://pdollar.github.io/toolbox/
4
5 %%% parameters
6 % hed_mat_dir: the hed mat file directory (the output of 'batch_hed.py')
7 % edge_dir: the output HED edges directory
8 % image_width: resize the edge map to [image_width, image_width]
9 % threshold: threshold for image binarization (default 25.0/255.0)
10 % small_edge: remove small edges (default 5)
11
12 function [] = PostprocessHED(hed_mat_dir, edge_dir, image_width, threshold, small_edge)
13
14 if ~exist(edge_dir, 'dir')
15 mkdir(edge_dir);
16 end
17 fileList = dir(fullfile(hed_mat_dir, '*.mat'));
18 nFiles = numel(fileList);
19 fprintf('find %d mat files\n', nFiles);
20
21 for n = 1 : nFiles
22 if mod(n, 1000) == 0
23 fprintf('process %d/%d images\n', n, nFiles);
24 end
25 fileName = fileList(n).name;
26 filePath = fullfile(hed_mat_dir, fileName);
27 jpgName = strrep(fileName, '.mat', '.jpg');
28 edge_path = fullfile(edge_dir, jpgName);
29
30 if ~exist(edge_path, 'file')
31 E = GetEdge(filePath);
32 E = imresize(E,[image_width,image_width]);
33 E_simple = SimpleEdge(E, threshold, small_edge);
34 E_simple = uint8(E_simple*255);
35 imwrite(E_simple, edge_path, 'Quality',100);
36 end
37 end
38 end
39
40
41
42
43 function [E] = GetEdge(filePath)
44 load(filePath);
45 E = 1-predict;
46 end
47
48 function [E4] = SimpleEdge(E, threshold, small_edge)
49 if nargin <= 1
50 threshold = 25.0/255.0;
51 end
52
53 if nargin <= 2
54 small_edge = 5;
55 end
56
57 if ndims(E) == 3
58 E = E(:,:,1);
59 end
60
61 E1 = 1 - E;
62 E2 = EdgeNMS(E1);
63 E3 = double(E2>=max(eps,threshold));
64 E3 = bwmorph(E3,'thin',inf);
65 E4 = bwareaopen(E3, small_edge);
66 E4=1-E4;
67 end
68
69 function [E_nms] = EdgeNMS( E )
70 E=single(E);
71 [Ox,Oy] = gradient2(convTri(E,4));
72 [Oxx,~] = gradient2(Ox);
73 [Oxy,Oyy] = gradient2(Oy);
74 O = mod(atan(Oyy.*sign(-Oxy)./(Oxx+1e-5)),pi);
75 E_nms = edgesNmsMex(E,O,1,5,1.01,1);
76 end
0 # HED batch processing script; modified from https://github.com/s9xie/hed/blob/master/examples/hed/HED-tutorial.ipynb
1 # Step 1: download the hed repo: https://github.com/s9xie/hed
2 # Step 2: download the models and protoxt, and put them under {caffe_root}/examples/hed/
3 # Step 3: put this script under {caffe_root}/examples/hed/
4 # Step 4: run the following script:
5 # python batch_hed.py --images_dir=/data/to/path/photos/ --hed_mat_dir=/data/to/path/hed_mat_files/
6 # The script sometimes crashes after the computation finishes, with an error like "Check failed: ... driver shutting down"; you can safely kill the job at that point.
7 # Large images can exhaust GPU memory, so resize them before running this script.
8 # Step 5: run the MATLAB post-processing script "PostprocessHED.m"
9
10
11 import numpy as np
12 from PIL import Image
13 import os
14 import argparse
15 import sys
16 import scipy.io as sio
17
18
19 def parse_args():
20     parser = argparse.ArgumentParser(description='batch processing: photos->edges')
21 parser.add_argument('--caffe_root', dest='caffe_root', help='caffe root', default='../../', type=str)
22 parser.add_argument('--caffemodel', dest='caffemodel', help='caffemodel', default='./hed_pretrained_bsds.caffemodel', type=str)
23 parser.add_argument('--prototxt', dest='prototxt', help='caffe prototxt file', default='./deploy.prototxt', type=str)
24 parser.add_argument('--images_dir', dest='images_dir', help='directory to store input photos', type=str)
25 parser.add_argument('--hed_mat_dir', dest='hed_mat_dir', help='directory to store output hed edges in mat file', type=str)
26 parser.add_argument('--border', dest='border', help='padding border', type=int, default=128)
27 parser.add_argument('--gpu_id', dest='gpu_id', help='gpu id', type=int, default=1)
28 args = parser.parse_args()
29 return args
30
31
32 args = parse_args()
33 for arg in vars(args):
34 print('[%s] =' % arg, getattr(args, arg))
35 # Make sure that caffe is on the python path:
36 caffe_root = args.caffe_root # this file is expected to be in {caffe_root}/examples/hed/
37 sys.path.insert(0, caffe_root + 'python')
38 import caffe
39
40
41 if not os.path.exists(args.hed_mat_dir):
42 print('create output directory %s' % args.hed_mat_dir)
43 os.makedirs(args.hed_mat_dir)
44
45 imgList = os.listdir(args.images_dir)
46 nImgs = len(imgList)
47 print('#images = %d' % nImgs)
48
49 caffe.set_mode_gpu()
50 caffe.set_device(args.gpu_id)
51 # load net
52 net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST)
53 # pad border
54 border = args.border
55
56 for i in range(nImgs):
57 if i % 500 == 0:
58 print('processing image %d/%d' % (i, nImgs))
59 im = Image.open(os.path.join(args.images_dir, imgList[i]))
60
61 in_ = np.array(im, dtype=np.float32)
62 in_ = np.pad(in_, ((border, border), (border, border), (0, 0)), 'reflect')
63
64 in_ = in_[:, :, 0:3]
65 in_ = in_[:, :, ::-1]
66 in_ -= np.array((104.00698793, 116.66876762, 122.67891434))
67 in_ = in_.transpose((2, 0, 1))
68 # (to run on CPU, use caffe.set_mode_cpu() instead of the set_mode_gpu()/set_device() calls above)
69
70 # shape for input (data blob is N x C x H x W), set data
71 net.blobs['data'].reshape(1, *in_.shape)
72 net.blobs['data'].data[...] = in_
73 # run the net and take the fused sigmoid output as the prediction
74 net.forward()
75 fuse = net.blobs['sigmoid-fuse'].data[0][0, :, :]
76 # get rid of the border
77 fuse = fuse[border:-border, border:-border]
78 # save hed file to the disk
79 name, ext = os.path.splitext(imgList[i])
80 sio.savemat(os.path.join(args.hed_mat_dir, name + '.mat'), {'predict': fuse})
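Each output .mat file stores the fused edge map under the key 'predict', which the MATLAB post-processing script later loads and inverts. A minimal sketch of the round trip (demonstrated on a synthetic array in memory, so it runs without HED or Caffe):

```python
import io

import numpy as np
import scipy.io as sio

# Write a synthetic .mat in the same layout batch_hed.py produces
buf = io.BytesIO()
fuse = np.random.rand(256, 256).astype(np.float32)
sio.savemat(buf, {'predict': fuse})

# Read it back and invert, as PostprocessHED.m does (E = 1 - predict)
buf.seek(0)
mat = sio.loadmat(buf)
edge = 1.0 - mat['predict']
print(edge.shape)
```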
0 layer {
1 name: "data"
2 type: "Input"
3 top: "data"
4 input_param {
5 shape {
6 dim: 1
7 dim: 3
8 dim: 500
9 dim: 500
10 }
11 }
12 }
13 layer {
14 name: "conv1_1"
15 type: "Convolution"
16 bottom: "data"
17 top: "conv1_1"
18 param {
19 lr_mult: 1
20 decay_mult: 1
21 }
22 param {
23 lr_mult: 2
24 decay_mult: 0
25 }
26 convolution_param {
27 num_output: 64
28 pad: 100
29 kernel_size: 3
30 stride: 1
31 weight_filler {
32 type: "gaussian"
33 std: 0.01
34 }
35 bias_filler {
36 type: "constant"
37 value: 0
38 }
39 }
40 }
41 layer {
42 name: "relu1_1"
43 type: "ReLU"
44 bottom: "conv1_1"
45 top: "conv1_1"
46 }
47 layer {
48 name: "conv1_2"
49 type: "Convolution"
50 bottom: "conv1_1"
51 top: "conv1_2"
52 param {
53 lr_mult: 1
54 decay_mult: 1
55 }
56 param {
57 lr_mult: 2
58 decay_mult: 0
59 }
60 convolution_param {
61 num_output: 64
62 pad: 1
63 kernel_size: 3
64 stride: 1
65 weight_filler {
66 type: "gaussian"
67 std: 0.01
68 }
69 bias_filler {
70 type: "constant"
71 value: 0
72 }
73 }
74 }
75 layer {
76 name: "relu1_2"
77 type: "ReLU"
78 bottom: "conv1_2"
79 top: "conv1_2"
80 }
81 layer {
82 name: "pool1"
83 type: "Pooling"
84 bottom: "conv1_2"
85 top: "pool1"
86 pooling_param {
87 pool: MAX
88 kernel_size: 2
89 stride: 2
90 }
91 }
92 layer {
93 name: "conv2_1"
94 type: "Convolution"
95 bottom: "pool1"
96 top: "conv2_1"
97 param {
98 lr_mult: 1
99 decay_mult: 1
100 }
101 param {
102 lr_mult: 2
103 decay_mult: 0
104 }
105 convolution_param {
106 num_output: 128
107 pad: 1
108 kernel_size: 3
109 stride: 1
110 weight_filler {
111 type: "gaussian"
112 std: 0.01
113 }
114 bias_filler {
115 type: "constant"
116 value: 0
117 }
118 }
119 }
120 layer {
121 name: "relu2_1"
122 type: "ReLU"
123 bottom: "conv2_1"
124 top: "conv2_1"
125 }
126 layer {
127 name: "conv2_2"
128 type: "Convolution"
129 bottom: "conv2_1"
130 top: "conv2_2"
131 param {
132 lr_mult: 1
133 decay_mult: 1
134 }
135 param {
136 lr_mult: 2
137 decay_mult: 0
138 }
139 convolution_param {
140 num_output: 128
141 pad: 1
142 kernel_size: 3
143 stride: 1
144 weight_filler {
145 type: "gaussian"
146 std: 0.01
147 }
148 bias_filler {
149 type: "constant"
150 value: 0
151 }
152 }
153 }
154 layer {
155 name: "relu2_2"
156 type: "ReLU"
157 bottom: "conv2_2"
158 top: "conv2_2"
159 }
160 layer {
161 name: "pool2"
162 type: "Pooling"
163 bottom: "conv2_2"
164 top: "pool2"
165 pooling_param {
166 pool: MAX
167 kernel_size: 2
168 stride: 2
169 }
170 }
171 layer {
172 name: "conv3_1"
173 type: "Convolution"
174 bottom: "pool2"
175 top: "conv3_1"
176 param {
177 lr_mult: 1
178 decay_mult: 1
179 }
180 param {
181 lr_mult: 2
182 decay_mult: 0
183 }
184 convolution_param {
185 num_output: 256
186 pad: 1
187 kernel_size: 3
188 stride: 1
189 weight_filler {
190 type: "gaussian"
191 std: 0.01
192 }
193 bias_filler {
194 type: "constant"
195 value: 0
196 }
197 }
198 }
199 layer {
200 name: "relu3_1"
201 type: "ReLU"
202 bottom: "conv3_1"
203 top: "conv3_1"
204 }
205 layer {
206 name: "conv3_2"
207 type: "Convolution"
208 bottom: "conv3_1"
209 top: "conv3_2"
210 param {
211 lr_mult: 1
212 decay_mult: 1
213 }
214 param {
215 lr_mult: 2
216 decay_mult: 0
217 }
218 convolution_param {
219 num_output: 256
220 pad: 1
221 kernel_size: 3
222 stride: 1
223 weight_filler {
224 type: "gaussian"
225 std: 0.01
226 }
227 bias_filler {
228 type: "constant"
229 value: 0
230 }
231 }
232 }
233 layer {
234 name: "relu3_2"
235 type: "ReLU"
236 bottom: "conv3_2"
237 top: "conv3_2"
238 }
239 layer {
240 name: "conv3_3"
241 type: "Convolution"
242 bottom: "conv3_2"
243 top: "conv3_3"
244 param {
245 lr_mult: 1
246 decay_mult: 1
247 }
248 param {
249 lr_mult: 2
250 decay_mult: 0
251 }
252 convolution_param {
253 num_output: 256
254 pad: 1
255 kernel_size: 3
256 stride: 1
257 weight_filler {
258 type: "gaussian"
259 std: 0.01
260 }
261 bias_filler {
262 type: "constant"
263 value: 0
264 }
265 }
266 }
267 layer {
268 name: "relu3_3"
269 type: "ReLU"
270 bottom: "conv3_3"
271 top: "conv3_3"
272 }
273 layer {
274 name: "pool3"
275 type: "Pooling"
276 bottom: "conv3_3"
277 top: "pool3"
278 pooling_param {
279 pool: MAX
280 kernel_size: 2
281 stride: 2
282 }
283 }
284 layer {
285 name: "conv4_1"
286 type: "Convolution"
287 bottom: "pool3"
288 top: "conv4_1"
289 param {
290 lr_mult: 1
291 decay_mult: 1
292 }
293 param {
294 lr_mult: 2
295 decay_mult: 0
296 }
297 convolution_param {
298 num_output: 512
299 pad: 1
300 kernel_size: 3
301 stride: 1
302 weight_filler {
303 type: "gaussian"
304 std: 0.01
305 }
306 bias_filler {
307 type: "constant"
308 value: 0
309 }
310 }
311 }
312 layer {
313 name: "relu4_1"
314 type: "ReLU"
315 bottom: "conv4_1"
316 top: "conv4_1"
317 }
318 layer {
319 name: "conv4_2"
320 type: "Convolution"
321 bottom: "conv4_1"
322 top: "conv4_2"
323 param {
324 lr_mult: 1
325 decay_mult: 1
326 }
327 param {
328 lr_mult: 2
329 decay_mult: 0
330 }
331 convolution_param {
332 num_output: 512
333 pad: 1
334 kernel_size: 3
335 stride: 1
336 weight_filler {
337 type: "gaussian"
338 std: 0.01
339 }
340 bias_filler {
341 type: "constant"
342 value: 0
343 }
344 }
345 }
346 layer {
347 name: "relu4_2"
348 type: "ReLU"
349 bottom: "conv4_2"
350 top: "conv4_2"
351 }
352 layer {
353 name: "conv4_3"
354 type: "Convolution"
355 bottom: "conv4_2"
356 top: "conv4_3"
357 param {
358 lr_mult: 1
359 decay_mult: 1
360 }
361 param {
362 lr_mult: 2
363 decay_mult: 0
364 }
365 convolution_param {
366 num_output: 512
367 pad: 1
368 kernel_size: 3
369 stride: 1
370 weight_filler {
371 type: "gaussian"
372 std: 0.01
373 }
374 bias_filler {
375 type: "constant"
376 value: 0
377 }
378 }
379 }
380 layer {
381 name: "relu4_3"
382 type: "ReLU"
383 bottom: "conv4_3"
384 top: "conv4_3"
385 }
386 layer {
387 name: "pool4"
388 type: "Pooling"
389 bottom: "conv4_3"
390 top: "pool4"
391 pooling_param {
392 pool: MAX
393 kernel_size: 2
394 stride: 2
395 }
396 }
397 layer {
398 name: "conv5_1"
399 type: "Convolution"
400 bottom: "pool4"
401 top: "conv5_1"
402 param {
403 lr_mult: 1
404 decay_mult: 1
405 }
406 param {
407 lr_mult: 2
408 decay_mult: 0
409 }
410 convolution_param {
411 num_output: 512
412 pad: 1
413 kernel_size: 3
414 stride: 1
415 weight_filler {
416 type: "gaussian"
417 std: 0.01
418 }
419 bias_filler {
420 type: "constant"
421 value: 0
422 }
423 }
424 }
425 layer {
426 name: "relu5_1"
427 type: "ReLU"
428 bottom: "conv5_1"
429 top: "conv5_1"
430 }
431 layer {
432 name: "conv5_2"
433 type: "Convolution"
434 bottom: "conv5_1"
435 top: "conv5_2"
436 param {
437 lr_mult: 1
438 decay_mult: 1
439 }
440 param {
441 lr_mult: 2
442 decay_mult: 0
443 }
444 convolution_param {
445 num_output: 512
446 pad: 1
447 kernel_size: 3
448 stride: 1
449 weight_filler {
450 type: "gaussian"
451 std: 0.01
452 }
453 bias_filler {
454 type: "constant"
455 value: 0
456 }
457 }
458 }
459 layer {
460 name: "relu5_2"
461 type: "ReLU"
462 bottom: "conv5_2"
463 top: "conv5_2"
464 }
465 layer {
466 name: "conv5_3"
467 type: "Convolution"
468 bottom: "conv5_2"
469 top: "conv5_3"
470 param {
471 lr_mult: 1
472 decay_mult: 1
473 }
474 param {
475 lr_mult: 2
476 decay_mult: 0
477 }
478 convolution_param {
479 num_output: 512
480 pad: 1
481 kernel_size: 3
482 stride: 1
483 weight_filler {
484 type: "gaussian"
485 std: 0.01
486 }
487 bias_filler {
488 type: "constant"
489 value: 0
490 }
491 }
492 }
493 layer {
494 name: "relu5_3"
495 type: "ReLU"
496 bottom: "conv5_3"
497 top: "conv5_3"
498 }
499 layer {
500 name: "pool5"
501 type: "Pooling"
502 bottom: "conv5_3"
503 top: "pool5"
504 pooling_param {
505 pool: MAX
506 kernel_size: 2
507 stride: 2
508 }
509 }
510 layer {
511 name: "fc6_cs"
512 type: "Convolution"
513 bottom: "pool5"
514 top: "fc6_cs"
515 param {
516 lr_mult: 1
517 decay_mult: 1
518 }
519 param {
520 lr_mult: 2
521 decay_mult: 0
522 }
523 convolution_param {
524 num_output: 4096
525 pad: 0
526 kernel_size: 7
527 stride: 1
528 weight_filler {
529 type: "gaussian"
530 std: 0.01
531 }
532 bias_filler {
533 type: "constant"
534 value: 0
535 }
536 }
537 }
538 layer {
539 name: "relu6_cs"
540 type: "ReLU"
541 bottom: "fc6_cs"
542 top: "fc6_cs"
543 }
544 layer {
545 name: "fc7_cs"
546 type: "Convolution"
547 bottom: "fc6_cs"
548 top: "fc7_cs"
549 param {
550 lr_mult: 1
551 decay_mult: 1
552 }
553 param {
554 lr_mult: 2
555 decay_mult: 0
556 }
557 convolution_param {
558 num_output: 4096
559 pad: 0
560 kernel_size: 1
561 stride: 1
562 weight_filler {
563 type: "gaussian"
564 std: 0.01
565 }
566 bias_filler {
567 type: "constant"
568 value: 0
569 }
570 }
571 }
572 layer {
573 name: "relu7_cs"
574 type: "ReLU"
575 bottom: "fc7_cs"
576 top: "fc7_cs"
577 }
578 layer {
579 name: "score_fr"
580 type: "Convolution"
581 bottom: "fc7_cs"
582 top: "score_fr"
583 param {
584 lr_mult: 1
585 decay_mult: 1
586 }
587 param {
588 lr_mult: 2
589 decay_mult: 0
590 }
591 convolution_param {
592 num_output: 20
593 pad: 0
594 kernel_size: 1
595 weight_filler {
596 type: "xavier"
597 }
598 bias_filler {
599 type: "constant"
600 }
601 }
602 }
603 layer {
604 name: "upscore2"
605 type: "Deconvolution"
606 bottom: "score_fr"
607 top: "upscore2"
608 param {
609 lr_mult: 1
610 }
611 convolution_param {
612 num_output: 20
613 bias_term: false
614 kernel_size: 4
615 stride: 2
616 weight_filler {
617 type: "xavier"
618 }
619 bias_filler {
620 type: "constant"
621 }
622 }
623 }
624 layer {
625 name: "score_pool4"
626 type: "Convolution"
627 bottom: "pool4"
628 top: "score_pool4"
629 param {
630 lr_mult: 1
631 decay_mult: 1
632 }
633 param {
634 lr_mult: 2
635 decay_mult: 0
636 }
637 convolution_param {
638 num_output: 20
639 pad: 0
640 kernel_size: 1
641 weight_filler {
642 type: "xavier"
643 }
644 bias_filler {
645 type: "constant"
646 }
647 }
648 }
649 layer {
650 name: "score_pool4c"
651 type: "Crop"
652 bottom: "score_pool4"
653 bottom: "upscore2"
654 top: "score_pool4c"
655 crop_param {
656 axis: 2
657 offset: 5
658 }
659 }
660 layer {
661 name: "fuse_pool4"
662 type: "Eltwise"
663 bottom: "upscore2"
664 bottom: "score_pool4c"
665 top: "fuse_pool4"
666 eltwise_param {
667 operation: SUM
668 }
669 }
670 layer {
671 name: "upscore_pool4"
672 type: "Deconvolution"
673 bottom: "fuse_pool4"
674 top: "upscore_pool4"
675 param {
676 lr_mult: 1
677 }
678 convolution_param {
679 num_output: 20
680 bias_term: false
681 kernel_size: 4
682 stride: 2
683 weight_filler {
684 type: "xavier"
685 }
686 bias_filler {
687 type: "constant"
688 }
689 }
690 }
691 layer {
692 name: "score_pool3"
693 type: "Convolution"
694 bottom: "pool3"
695 top: "score_pool3"
696 param {
697 lr_mult: 1
698 decay_mult: 1
699 }
700 param {
701 lr_mult: 2
702 decay_mult: 0
703 }
704 convolution_param {
705 num_output: 20
706 pad: 0
707 kernel_size: 1
708 weight_filler {
709 type: "xavier"
710 }
711 bias_filler {
712 type: "constant"
713 }
714 }
715 }
716 layer {
717 name: "score_pool3c"
718 type: "Crop"
719 bottom: "score_pool3"
720 bottom: "upscore_pool4"
721 top: "score_pool3c"
722 crop_param {
723 axis: 2
724 offset: 9
725 }
726 }
727 layer {
728 name: "fuse_pool3"
729 type: "Eltwise"
730 bottom: "upscore_pool4"
731 bottom: "score_pool3c"
732 top: "fuse_pool3"
733 eltwise_param {
734 operation: SUM
735 }
736 }
737 layer {
738 name: "upscore8"
739 type: "Deconvolution"
740 bottom: "fuse_pool3"
741 top: "upscore8"
742 param {
743 lr_mult: 1
744 }
745 convolution_param {
746 num_output: 20
747 bias_term: false
748 kernel_size: 16
749 stride: 8
750 weight_filler {
751 type: "xavier"
752 }
753 bias_filler {
754 type: "constant"
755 }
756 }
757 }
758 layer {
759 name: "score"
760 type: "Crop"
761 bottom: "upscore8"
762 bottom: "data"
763 top: "score"
764 crop_param {
765 axis: 2
766 offset: 31
767 }
768 }
0 # The following code is modified from https://github.com/shelhamer/clockwork-fcn
1 import sys
2 import os
3 import glob
4 import numpy as np
5 from PIL import Image
6
7
8 class cityscapes:
9 def __init__(self, data_path):
10 # data_path something like /data2/cityscapes
11 self.dir = data_path
12 self.classes = ['road', 'sidewalk', 'building', 'wall', 'fence',
13 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain',
14 'sky', 'person', 'rider', 'car', 'truck',
15 'bus', 'train', 'motorcycle', 'bicycle']
16 self.mean = np.array((72.78044, 83.21195, 73.45286), dtype=np.float32)
17 # import cityscapes label helper and set up label mappings
18 sys.path.insert(0, '{}/scripts/helpers/'.format(self.dir))
19 labels = __import__('labels')
20 self.id2trainId = {label.id: label.trainId for label in labels.labels} # dictionary mapping from raw IDs to train IDs
21 self.trainId2color = {label.trainId: label.color for label in labels.labels} # dictionary mapping train IDs to colors as 3-tuples
22
23 def get_dset(self, split):
24 '''
25 List images as (city, id) for the specified split
26
27 TODO(shelhamer) generate splits from cityscapes itself, instead of
28 relying on these separately made text files.
29 '''
30 if split == 'train':
31 dataset = open('{}/ImageSets/segFine/train.txt'.format(self.dir)).read().splitlines()
32 else:
33 dataset = open('{}/ImageSets/segFine/val.txt'.format(self.dir)).read().splitlines()
34 return [(item.split('/')[0], item.split('/')[1]) for item in dataset]
35
36 def load_image(self, split, city, idx):
37 im = Image.open('{}/leftImg8bit_sequence/{}/{}/{}_leftImg8bit.png'.format(self.dir, split, city, idx))
38 return im
39
40 def assign_trainIds(self, label):
41 """
42 Map the given label IDs to the train IDs appropriate for training
43 Use the label mapping provided in labels.py from the cityscapes scripts
44 """
45 label = np.array(label, dtype=np.float32)
46 if sys.version_info[0] < 3:
47 for k, v in self.id2trainId.iteritems():
48 label[label == k] = v
49 else:
50 for k, v in self.id2trainId.items():
51 label[label == k] = v
52 return label
53
54 def load_label(self, split, city, idx):
55 """
56 Load label image as 1 x height x width integer array of label indices.
57 The leading singleton dimension is required by the loss.
58 """
59 label = Image.open('{}/gtFine/{}/{}/{}_gtFine_labelIds.png'.format(self.dir, split, city, idx))
60 label = self.assign_trainIds(label) # get proper labels for eval
61 label = np.array(label, dtype=np.uint8)
62 label = label[np.newaxis, ...]
63 return label
64
65 def preprocess(self, im):
66 """
67 Preprocess loaded image (by load_image) for Caffe:
68 - cast to float
69 - switch channels RGB -> BGR
70 - subtract mean
71 - transpose to channel x height x width order
72 """
73 in_ = np.array(im, dtype=np.float32)
74 in_ = in_[:, :, ::-1]
75 in_ -= self.mean
76 in_ = in_.transpose((2, 0, 1))
77 return in_
78
79 def palette(self, label):
80 '''
81 Map trainIds to colors as specified in labels.py
82 '''
83 if label.ndim == 3:
84 label = label[0]
85 color = np.empty((label.shape[0], label.shape[1], 3))
86 if sys.version_info[0] < 3:
87 for k, v in self.trainId2color.iteritems():
88 color[label == k, :] = v
89 else:
90 for k, v in self.trainId2color.items():
91 color[label == k, :] = v
92 return color
93
94 def make_boundaries(self, label, thickness=None):
95 """
96 Input is an image label, output is a numpy array mask encoding the boundaries of the objects
97 Extract pixels at the true boundary by dilation - erosion of label.
98 Don't just pick the void label as it is not exclusive to the boundaries.
99 """
100 assert(thickness is not None)
101 import skimage.morphology as skm
102 void = 255
103 mask = np.logical_and(label > 0, label != void)[0]
104 selem = skm.disk(thickness)
105 boundaries = np.logical_xor(skm.dilation(mask, selem),
106 skm.erosion(mask, selem))
107 return boundaries
108
109 def list_label_frames(self, split):
110 """
111 Select labeled frames from a split for evaluation
112 collected as (city, shot, idx) tuples
113 """
114 def file2idx(f):
115 """Helper to convert file path into frame ID"""
116 city, shot, frame = (os.path.basename(f).split('_')[:3])
117 return "_".join([city, shot, frame])
118 frames = []
119 cities = [os.path.basename(f) for f in glob.glob('{}/gtFine/{}/*'.format(self.dir, split))]
120 for c in cities:
121 files = sorted(glob.glob('{}/gtFine/{}/{}/*labelIds.png'.format(self.dir, split, c)))
122 frames.extend([file2idx(f) for f in files])
123 return frames
124
125 def collect_frame_sequence(self, split, idx, length):
126 """
127 Collect sequence of frames preceding (and including) a labeled frame
128 as a list of Images.
129
130 Note: 19 preceding frames are provided for each labeled frame.
131 """
132 SEQ_LEN = length
133 city, shot, frame = idx.split('_')
134 frame = int(frame)
135 frame_seq = []
136 for i in range(frame - SEQ_LEN, frame + 1):
137 frame_path = '{0}/leftImg8bit_sequence/val/{1}/{1}_{2}_{3:0>6d}_leftImg8bit.png'.format(
138 self.dir, city, shot, i)
139 frame_seq.append(Image.open(frame_path))
140 return frame_seq
0 URL=http://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/fcn-8s-cityscapes/fcn-8s-cityscapes.caffemodel
1 OUTPUT_FILE=./scripts/eval_cityscapes/caffemodel/fcn-8s-cityscapes.caffemodel
2 wget -N $URL -O $OUTPUT_FILE
0 import os
1 import caffe
2 import argparse
3 import numpy as np
4 import scipy.misc
5 from PIL import Image
6 from util import segrun, fast_hist, get_scores
7 from cityscapes import cityscapes
8
9 parser = argparse.ArgumentParser()
10 parser.add_argument("--cityscapes_dir", type=str, required=True, help="Path to the original cityscapes dataset")
11 parser.add_argument("--result_dir", type=str, required=True, help="Path to the generated images to be evaluated")
12 parser.add_argument("--output_dir", type=str, required=True, help="Where to save the evaluation results")
13 parser.add_argument("--caffemodel_dir", type=str, default='./scripts/eval_cityscapes/caffemodel/', help="Where the FCN-8s caffemodel is stored")
14 parser.add_argument("--gpu_id", type=int, default=0, help="Which gpu id to use")
15 parser.add_argument("--split", type=str, default='val', help="Data split to be evaluated")
16 parser.add_argument("--save_output_images", type=int, default=0, help="Whether to save the FCN output images")
17 args = parser.parse_args()
18
19
20 def main():
21 if not os.path.isdir(args.output_dir):
22 os.makedirs(args.output_dir)
23 if args.save_output_images > 0:
24 output_image_dir = args.output_dir + 'image_outputs/'
25 if not os.path.isdir(output_image_dir):
26 os.makedirs(output_image_dir)
27 CS = cityscapes(args.cityscapes_dir)
28 n_cl = len(CS.classes)
29 label_frames = CS.list_label_frames(args.split)
30 caffe.set_device(args.gpu_id)
31 caffe.set_mode_gpu()
32 net = caffe.Net(args.caffemodel_dir + '/deploy.prototxt',
33 args.caffemodel_dir + 'fcn-8s-cityscapes.caffemodel',
34 caffe.TEST)
35
36 hist_perframe = np.zeros((n_cl, n_cl))
37 for i, idx in enumerate(label_frames):
38 if i % 10 == 0:
39 print('Evaluating: %d/%d' % (i, len(label_frames)))
40 city = idx.split('_')[0]
41 # idx is city_shot_frame
42 label = CS.load_label(args.split, city, idx)
43 im_file = args.result_dir + '/' + idx + '_leftImg8bit.png'
44 im = np.array(Image.open(im_file))
45 im = scipy.misc.imresize(im, (label.shape[1], label.shape[2]))
46 out = segrun(net, CS.preprocess(im))
47 hist_perframe += fast_hist(label.flatten(), out.flatten(), n_cl)
48 if args.save_output_images > 0:
49 label_im = CS.palette(label)
50 pred_im = CS.palette(out)
51 scipy.misc.imsave(output_image_dir + '/' + str(i) + '_pred.jpg', pred_im)
52 scipy.misc.imsave(output_image_dir + '/' + str(i) + '_gt.jpg', label_im)
53 scipy.misc.imsave(output_image_dir + '/' + str(i) + '_input.jpg', im)
54
55 mean_pixel_acc, mean_class_acc, mean_class_iou, per_class_acc, per_class_iou = get_scores(hist_perframe)
56 with open(args.output_dir + '/evaluation_results.txt', 'w') as f:
57 f.write('Mean pixel accuracy: %f\n' % mean_pixel_acc)
58 f.write('Mean class accuracy: %f\n' % mean_class_acc)
59 f.write('Mean class IoU: %f\n' % mean_class_iou)
60 f.write('************ Per class numbers below ************\n')
61 for i, cl in enumerate(CS.classes):
62 while len(cl) < 15:
63 cl = cl + ' '
64 f.write('%s: acc = %f, iou = %f\n' % (cl, per_class_acc[i], per_class_iou[i]))
65
66
67 main()
0 # The following code is modified from https://github.com/shelhamer/clockwork-fcn
1 import numpy as np
2
3
4 def get_out_scoremap(net):
5 return net.blobs['score'].data[0].argmax(axis=0).astype(np.uint8)
6
7
8 def feed_net(net, in_):
9 """
10 Load prepared input into net.
11 """
12 net.blobs['data'].reshape(1, *in_.shape)
13 net.blobs['data'].data[...] = in_
14
15
16 def segrun(net, in_):
17 feed_net(net, in_)
18 net.forward()
19 return get_out_scoremap(net)
20
21
22 def fast_hist(a, b, n):
23 k = np.where((a >= 0) & (a < n))[0]
24 bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2)
25 if len(bc) != n**2:
26 # ignore this example if dimension mismatch
27 return 0
28 return bc.reshape(n, n)
29
30
31 def get_scores(hist):
32 # Mean pixel accuracy
33 acc = np.diag(hist).sum() / (hist.sum() + 1e-12)
34
35 # Per class accuracy
36 cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12)
37
38 # Per class IoU
39 iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12)
40
41 return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu
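The confusion-matrix helper above can be exercised on a toy prediction; this sketch reproduces fast_hist's logic on a 3-class example to show how pixel accuracy falls out of the histogram:

```python
import numpy as np

def fast_hist(a, b, n):
    """n x n confusion matrix between ground-truth labels a and predictions b."""
    k = np.where((a >= 0) & (a < n))[0]
    bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2)
    return bc.reshape(n, n)

gt = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([0, 1, 1, 1, 2, 0])
hist = fast_hist(gt, pred, 3)
acc = np.diag(hist).sum() / hist.sum()
print(hist)
print(acc)  # 4 of 6 pixels are classified correctly
```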
0 set -ex
1 pip install visdom
2 pip install dominate
0 # Simple script to make sure basic usage
1 # such as training, testing, saving and loading
2 # runs without errors.
3 import os
4
5
6 def run(command):
7 print(command)
8 exit_status = os.system(command)
9 if exit_status > 0:
10 exit(1)
11
12
13 if __name__ == '__main__':
14 # download mini datasets
15 if not os.path.exists('./datasets/mini'):
16 run('bash ./datasets/download_cyclegan_dataset.sh mini')
17
18 if not os.path.exists('./datasets/mini_pix2pix'):
19 run('bash ./datasets/download_cyclegan_dataset.sh mini_pix2pix')
20
21 # pretrained cyclegan model
22 if not os.path.exists('./checkpoints/horse2zebra_pretrained/latest_net_G.pth'):
23 run('bash ./scripts/download_cyclegan_model.sh horse2zebra')
24 run('python test.py --model test --dataroot ./datasets/mini --name horse2zebra_pretrained --no_dropout --num_test 1')
25
26 # pretrained pix2pix model
27 if not os.path.exists('./checkpoints/facades_label2photo_pretrained/latest_net_G.pth'):
28 run('bash ./scripts/download_pix2pix_model.sh facades_label2photo')
29 if not os.path.exists('./datasets/facades'):
30 run('bash ./datasets/download_pix2pix_dataset.sh facades')
31 run('python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_label2photo_pretrained --num_test 1')
32
33 # cyclegan train/test
34 run('python train.py --model cycle_gan --name temp_cyclegan --dataroot ./datasets/mini --niter 1 --niter_decay 0 --save_latest_freq 10 --print_freq 1 --display_id -1')
35 run('python test.py --model test --name temp_cyclegan --dataroot ./datasets/mini --num_test 1 --model_suffix "_A" --no_dropout')
36
37 # pix2pix train/test
38 run('python train.py --model pix2pix --name temp_pix2pix --dataroot ./datasets/mini_pix2pix --niter 1 --niter_decay 5 --save_latest_freq 10 --display_id -1')
39 run('python test.py --model pix2pix --name temp_pix2pix --dataroot ./datasets/mini_pix2pix --num_test 1')
40
41 # template train/test
42 run('python train.py --model template --name temp2 --dataroot ./datasets/mini_pix2pix --niter 1 --niter_decay 0 --save_latest_freq 10 --display_id -1')
43 run('python test.py --model template --name temp2 --dataroot ./datasets/mini_pix2pix --num_test 1')
44
45 # colorization train/test (optional)
46 if not os.path.exists('./datasets/mini_colorization'):
47 run('bash ./datasets/download_cyclegan_dataset.sh mini_colorization')
48
49 run('python train.py --model colorization --name temp_color --dataroot ./datasets/mini_colorization --niter 1 --niter_decay 0 --save_latest_freq 5 --display_id -1')
50 run('python test.py --model colorization --name temp_color --dataroot ./datasets/mini_colorization --num_test 1')
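The `run` helper above shells out with `os.system` and aborts on a nonzero exit status; an equivalent sketch using `subprocess` (not part of the script, shown only for comparison) looks like:

```python
import subprocess

def run(command):
    """Run a shell command and abort on failure, mirroring the os.system helper above."""
    print(command)
    result = subprocess.run(command, shell=True)
    if result.returncode != 0:
        raise SystemExit(1)

run('echo smoke test')                                      # succeeds
status = subprocess.run('exit 3', shell=True).returncode    # a nonzero exit status is surfaced
```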
0 set -ex
1 python test.py --dataroot ./datasets/colorization --name color_pix2pix --model colorization
0 set -ex
1 python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test --no_dropout
0 set -ex
1 python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --netG unet_256 --direction BtoA --dataset_mode aligned --norm batch
0 set -ex
1 python test.py --dataroot ./datasets/facades/testB/ --name facades_pix2pix --model test --netG unet_256 --direction BtoA --dataset_mode single --norm batch
0 set -ex
1 python train.py --dataroot ./datasets/colorization --name color_pix2pix --model colorization
0 set -ex
1 python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --pool_size 50 --no_dropout
0 set -ex
1 python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --netG unet_256 --direction BtoA --lambda_L1 100 --dataset_mode aligned --norm batch --pool_size 0
0 """General-purpose test script for image-to-image translation.
1
2 Once you have trained your model with train.py, you can use this script to test the model.
3 It will load a saved model from --checkpoints_dir and save the results to --results_dir.
4
5 It first creates the model and dataset given the options, hard-coding some parameters for testing.
6 It then runs inference for --num_test images and saves the results to an HTML file.
7
8 Example (You need to train models first or download pre-trained models from our website):
9 Test a CycleGAN model (both sides):
10 python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
11
12 Test a CycleGAN model (one side only):
13 python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
14
15 The option '--model test' is used for generating CycleGAN results only for one side.
16 This option will automatically set '--dataset_mode single', which only loads the images from one set.
17 In contrast, using '--model cycle_gan' requires loading the generators and producing results in both directions,
18 which is sometimes unnecessary. The results will be saved at ./results/.
19 Use '--results_dir <directory_path_to_save_result>' to specify the results directory.
20
21 Test a pix2pix model:
22 python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
23
24 See options/base_options.py and options/test_options.py for more test options.
25 See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
26 See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
27 """
28 import os
29 from options.test_options import TestOptions
30 from data import create_dataset
31 from models import create_model
32 from util.visualizer import save_images
33 from util import html
34
35
36 if __name__ == '__main__':
37 opt = TestOptions().parse() # get test options
38 # hard-code some parameters for test
39 opt.num_threads = 0   # test code only supports num_threads = 0
40 opt.batch_size = 1 # test code only supports batch_size = 1
41 opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
42 opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
43 opt.display_id = -1   # no visdom display; the test code saves the results to an HTML file.
44 dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
45 model = create_model(opt) # create a model given opt.model and other options
46 model.setup(opt) # regular setup: load and print networks; create schedulers
47 # create a website
48 web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory
49 webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
50 # test with eval mode. This only affects layers like batchnorm and dropout.
51 # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment with and without eval() mode.
52 # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
53 if opt.eval:
54 model.eval()
55 for i, data in enumerate(dataset):
56 if i >= opt.num_test: # only apply our model to opt.num_test images.
57 break
58 model.set_input(data) # unpack data from data loader
59 model.test() # run inference
60 visuals = model.get_current_visuals() # get image results
61 img_path = model.get_image_paths() # get image paths
62 if i % 5 == 0:  # print a progress message every 5 images
63 print('processing (%04d)-th image... %s' % (i, img_path))
64 save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
65 webpage.save() # save the HTML
0 """General-purpose training script for image-to-image translation.
1
2 This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
3 different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
4 You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
5
6 It first creates model, dataset, and visualizer given the option.
7 It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves the models.
8 The script supports continue/resume training. Use '--continue_train' to resume your previous training.
9
10 Example:
11 Train a CycleGAN model:
12 python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
13 Train a pix2pix model:
14 python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
15
16 See options/base_options.py and options/train_options.py for more training options.
17 See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
18 See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
19 """
20 import time
21 from options.train_options import TrainOptions
22 from data import create_dataset
23 from models import create_model
24 from util.visualizer import Visualizer
25
26 if __name__ == '__main__':
27 opt = TrainOptions().parse() # get training options
28 dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
29 dataset_size = len(dataset) # get the number of images in the dataset.
30 print('The number of training images = %d' % dataset_size)
31
32 model = create_model(opt) # create a model given opt.model and other options
33 model.setup(opt) # regular setup: load and print networks; create schedulers
34 visualizer = Visualizer(opt) # create a visualizer that display/save images and plots
35 total_iters = 0 # the total number of training iterations
36
37 for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):    # outer loop over epochs; the model is saved every <save_epoch_freq> epochs and every <save_latest_freq> iterations
38 epoch_start_time = time.time() # timer for entire epoch
39 iter_data_time = time.time() # timer for data loading per iteration
40 epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
41
42 for i, data in enumerate(dataset): # inner loop within one epoch
43 iter_start_time = time.time() # timer for computation per iteration
44 if total_iters % opt.print_freq == 0:
45 t_data = iter_start_time - iter_data_time
46 visualizer.reset()
47 total_iters += opt.batch_size
48 epoch_iter += opt.batch_size
49 model.set_input(data) # unpack data from dataset and apply preprocessing
50 model.optimize_parameters() # calculate loss functions, get gradients, update network weights
51
52 if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file
53 save_result = total_iters % opt.update_html_freq == 0
54 model.compute_visuals()
55 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
56
57 if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk
58 losses = model.get_current_losses()
59 t_comp = (time.time() - iter_start_time) / opt.batch_size
60 visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
61 if opt.display_id > 0:
62 visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
63
64 if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations
65 print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
66 save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
67 model.save_networks(save_suffix)
68
69 iter_data_time = time.time()
70 if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs
71 print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
72 model.save_networks('latest')
73 model.save_networks(epoch)
74
75 print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
76 model.update_learning_rate() # update learning rates at the end of every epoch.
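The bookkeeping in the loop above (counting `total_iters` in steps of `batch_size` and firing actions on `print_freq` boundaries) can be sketched in isolation; the frequencies below are illustrative values, not the script's defaults:

```python
# sketch of the iteration counters used by the training loop above
batch_size, print_freq = 1, 100   # illustrative values
total_iters = 0
print_steps = []
for i in range(250):              # pretend dataset of 250 samples, one epoch
    total_iters += batch_size
    if total_iters % print_freq == 0:
        print_steps.append(total_iters)   # a loss message would be printed here
```

With `batch_size=1` and `print_freq=100`, printing triggers at iterations 100 and 200 of the 250.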
0 """This package includes a miscellaneous collection of useful helper functions."""
0 from __future__ import print_function
1 import os
2 import tarfile
3 import requests
4 from warnings import warn
5 from zipfile import ZipFile
6 from bs4 import BeautifulSoup
7 from os.path import abspath, isdir, join, basename
8
9
10 class GetData(object):
11 """A helper class for downloading CycleGAN or pix2pix datasets.
12
13 Parameters:
14 technique (str) -- One of: 'cyclegan' or 'pix2pix'.
15 verbose (bool) -- If True, print additional information.
16
17 Examples:
18 >>> from util.get_data import GetData
19 >>> gd = GetData(technique='cyclegan')
20 >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
21
22 Alternatively, you can use the bash scripts 'datasets/download_pix2pix_dataset.sh'
23 and 'datasets/download_cyclegan_dataset.sh'.
24 """
25
26 def __init__(self, technique='cyclegan', verbose=True):
27 url_dict = {
28 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
29 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
30 }
31 self.url = url_dict.get(technique.lower())
32 self._verbose = verbose
33
34 def _print(self, text):
35 if self._verbose:
36 print(text)
37
38 @staticmethod
39 def _get_options(r):
40 soup = BeautifulSoup(r.text, 'lxml')
41 options = [h.text for h in soup.find_all('a', href=True)
42 if h.text.endswith(('.zip', 'tar.gz'))]
43 return options
44
45 def _present_options(self):
46 r = requests.get(self.url)
47 options = self._get_options(r)
48 print('Options:\n')
49 for i, o in enumerate(options):
50 print("{0}: {1}".format(i, o))
51 choice = input("\nPlease enter the number of the "
52 "dataset above you wish to download:")
53 return options[int(choice)]
54
55 def _download_data(self, dataset_url, save_path):
56 if not isdir(save_path):
57 os.makedirs(save_path)
58
59 base = basename(dataset_url)
60 temp_save_path = join(save_path, base)
61
62 with open(temp_save_path, "wb") as f:
63 r = requests.get(dataset_url)
64 f.write(r.content)
65
66 if base.endswith('.tar.gz'):
67 obj = tarfile.open(temp_save_path)
68 elif base.endswith('.zip'):
69 obj = ZipFile(temp_save_path, 'r')
70 else:
71 raise ValueError("Unknown File Type: {0}.".format(base))
72
73 self._print("Unpacking Data...")
74 obj.extractall(save_path)
75 obj.close()
76 os.remove(temp_save_path)
77
78 def get(self, save_path, dataset=None):
79 """
80
81 Download a dataset.
82
83 Parameters:
84 save_path (str) -- A directory to save the data to.
85 dataset (str) -- (optional). A specific dataset to download.
86 Note: this must include the file extension.
87 If None, options will be presented for you
88 to choose from.
89
90 Returns:
91 save_path_full (str) -- the absolute path to the downloaded data.
92
93 """
94 if dataset is None:
95 selected_dataset = self._present_options()
96 else:
97 selected_dataset = dataset
98
99 save_path_full = join(save_path, selected_dataset.split('.')[0])
100
101 if isdir(save_path_full):
102 warn("\n'{0}' already exists. Skipping download.".format(
103 save_path_full))
104 else:
105 self._print('Downloading Data...')
106 url = "{0}/{1}".format(self.url, selected_dataset)
107 self._download_data(url, save_path=save_path)
108
109 return abspath(save_path_full)
0 import dominate
1 from dominate.tags import meta, h3, table, tr, td, p, a, img, br
2 import os
3
4
5 class HTML:
6 """This HTML class allows us to save images and write texts into a single HTML file.
7
8 It consists of functions such as <add_header> (add a text header to the HTML file),
9 <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
10 It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
11 """
12
13 def __init__(self, web_dir, title, refresh=0):
14 """Initialize the HTML class
15
16 Parameters:
17 web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
18 title (str) -- the webpage name
19 refresh (int) -- how often the website refreshes itself; if 0, no refreshing
20 """
21 self.title = title
22 self.web_dir = web_dir
23 self.img_dir = os.path.join(self.web_dir, 'images')
24 if not os.path.exists(self.web_dir):
25 os.makedirs(self.web_dir)
26 if not os.path.exists(self.img_dir):
27 os.makedirs(self.img_dir)
28
29 self.doc = dominate.document(title=title)
30 if refresh > 0:
31 with self.doc.head:
32 meta(http_equiv="refresh", content=str(refresh))
33
34 def get_image_dir(self):
35 """Return the directory that stores images"""
36 return self.img_dir
37
38 def add_header(self, text):
39 """Insert a header to the HTML file
40
41 Parameters:
42 text (str) -- the header text
43 """
44 with self.doc:
45 h3(text)
46
47 def add_images(self, ims, txts, links, width=400):
48 """add images to the HTML file
49
50 Parameters:
51 ims (str list) -- a list of image paths
52 txts (str list) -- a list of image names shown on the website
53 links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
54 """
55 self.t = table(border=1, style="table-layout: fixed;") # Insert a table
56 self.doc.add(self.t)
57 with self.t:
58 with tr():
59 for im, txt, link in zip(ims, txts, links):
60 with td(style="word-wrap: break-word;", halign="center", valign="top"):
61 with p():
62 with a(href=os.path.join('images', link)):
63 img(style="width:%dpx" % width, src=os.path.join('images', im))
64 br()
65 p(txt)
66
67 def save(self):
68 """save the current content to the HTML file"""
69 html_file = '%s/index.html' % self.web_dir
70 f = open(html_file, 'wt')
71 f.write(self.doc.render())
72 f.close()
73
74
75 if __name__ == '__main__': # we show an example usage here.
76 html = HTML('web/', 'test_html')
77 html.add_header('hello world')
78
79 ims, txts, links = [], [], []
80 for n in range(4):
81 ims.append('image_%d.png' % n)
82 txts.append('text_%d' % n)
83 links.append('image_%d.png' % n)
84 html.add_images(ims, txts, links)
85 html.save()
0 import random
1 import torch
2
3
4 class ImagePool():
5 """This class implements an image buffer that stores previously generated images.
6
7 This buffer enables us to update discriminators using a history of generated images
8 rather than the ones produced by the latest generators.
9 """
10
11 def __init__(self, pool_size):
12 """Initialize the ImagePool class
13
14 Parameters:
15 pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
16 """
17 self.pool_size = pool_size
18 if self.pool_size > 0: # create an empty pool
19 self.num_imgs = 0
20 self.images = []
21
22 def query(self, images):
23 """Return an image from the pool.
24
25 Parameters:
26 images: the latest generated images from the generator
27
28 Returns images from the buffer.
29
30 With probability 0.5, the buffer returns the input images.
31 With probability 0.5, it returns images previously stored in the buffer
32 and inserts the current images into the buffer.
33 """
34 if self.pool_size == 0: # if the buffer size is 0, do nothing
35 return images
36 return_images = []
37 for image in images:
38 image = torch.unsqueeze(image.data, 0)
39 if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
40 self.num_imgs = self.num_imgs + 1
41 self.images.append(image)
42 return_images.append(image)
43 else:
44 p = random.uniform(0, 1)
45 if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
46 random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
47 tmp = self.images[random_id].clone()
48 self.images[random_id] = image
49 return_images.append(tmp)
50 else: # by another 50% chance, the buffer will return the current image
51 return_images.append(image)
52 return_images = torch.cat(return_images, 0) # collect all the images and return
53 return return_images
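The 50/50 replacement policy in `query` can be sketched with a simplified, list-based stand-in. `query_pool` is a hypothetical single-image helper for illustration only (plain Python objects instead of torch tensors):

```python
import random

def query_pool(pool, pool_size, image):
    # simplified single-image version of the ImagePool.query logic above
    if pool_size == 0:
        return image                       # no buffering
    if len(pool) < pool_size:
        pool.append(image)                 # fill the buffer first
        return image
    if random.random() > 0.5:              # 50%: swap in the new image, return an old one
        idx = random.randint(0, pool_size - 1)
        old, pool[idx] = pool[idx], image
        return old
    return image                           # 50%: return the new image unchanged

pool = []
returned = [query_pool(pool, 2, i) for i in range(10)]
# the buffer never grows past pool_size, and images pass straight through until it is full
```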
0 """This module contains simple helper functions """
1 from __future__ import print_function
2 import torch
3 import numpy as np
4 from PIL import Image
5 import os
6
7
8 def tensor2im(input_image, imtype=np.uint8):
9 """Convert a Tensor array into a numpy image array.
10
11 Parameters:
12 input_image (tensor) -- the input image tensor array
13 imtype (type) -- the desired type of the converted numpy array
14 """
15 if not isinstance(input_image, np.ndarray):
16 if isinstance(input_image, torch.Tensor): # get the data from a variable
17 image_tensor = input_image.data
18 else:
19 return input_image
20 image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
21 if image_numpy.shape[0] == 1: # grayscale to RGB
22 image_numpy = np.tile(image_numpy, (3, 1, 1))
23 image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
24 else: # if it is a numpy array, do nothing
25 image_numpy = input_image
26 return image_numpy.astype(imtype)
27
28
29 def diagnose_network(net, name='network'):
30 """Calculate and print the mean of average absolute(gradients)
31
32 Parameters:
33 net (torch network) -- Torch network
34 name (str) -- the name of the network
35 """
36 mean = 0.0
37 count = 0
38 for param in net.parameters():
39 if param.grad is not None:
40 mean += torch.mean(torch.abs(param.grad.data))
41 count += 1
42 if count > 0:
43 mean = mean / count
44 print(name)
45 print(mean)
46
47
48 def save_image(image_numpy, image_path, aspect_ratio=1.0):
49 """Save a numpy image to the disk
50
51 Parameters:
52 image_numpy (numpy array) -- input numpy array
53 image_path (str) -- the path of the image
54 """
55
56 image_pil = Image.fromarray(image_numpy)
57 h, w, _ = image_numpy.shape
58
59 if aspect_ratio > 1.0:
60 image_pil = image_pil.resize((int(w * aspect_ratio), h), Image.BICUBIC)  # PIL size is (width, height)
61 if aspect_ratio < 1.0:
62 image_pil = image_pil.resize((w, int(h / aspect_ratio)), Image.BICUBIC)
63 image_pil.save(image_path)
64
65
66 def print_numpy(x, val=True, shp=False):
67 """Print the mean, min, max, median, std, and size of a numpy array
68
69 Parameters:
70 val (bool) -- if print the values of the numpy array
71 shp (bool) -- if print the shape of the numpy array
72 """
73 x = x.astype(np.float64)
74 if shp:
75 print('shape,', x.shape)
76 if val:
77 x = x.flatten()
78 print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
79 np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
80
81
82 def mkdirs(paths):
83 """create empty directories if they don't exist
84
85 Parameters:
86 paths (str list) -- a list of directory paths
87 """
88 if isinstance(paths, list) and not isinstance(paths, str):
89 for path in paths:
90 mkdir(path)
91 else:
92 mkdir(paths)
93
94
95 def mkdir(path):
96 """create a single empty directory if it doesn't exist
97
98 Parameters:
99 path (str) -- a single directory path
100 """
101 if not os.path.exists(path):
102 os.makedirs(path)
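The [-1, 1] to [0, 255] post-processing inside `tensor2im` can be illustrated with a plain numpy array in CHW layout (a standalone sketch, no torch required):

```python
import numpy as np

chw = np.zeros((3, 4, 4), dtype=np.float32)               # pretend generator output in [-1, 1]
hwc = (np.transpose(chw, (1, 2, 0)) + 1) / 2.0 * 255.0    # same transform as in tensor2im
out = hwc.astype(np.uint8)
# CHW becomes HWC, and 0.0 (mid-range) maps to 127 after the uint8 cast truncates 127.5
```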
0 import numpy as np
1 import os
2 import sys
3 import ntpath
4 import time
5 from . import util, html
6 from subprocess import Popen, PIPE
7
8
9 if sys.version_info[0] == 2:
10 VisdomExceptionBase = Exception
11 else:
12 VisdomExceptionBase = ConnectionError
13
14
15 def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
16 """Save images to the disk.
17
18 Parameters:
19 webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
20 visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
21 image_path (str) -- the string is used to create image paths
22 aspect_ratio (float) -- the aspect ratio of saved images
23 width (int)              -- the images will be displayed at this width in the HTML file
24
25 This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
26 """
27 image_dir = webpage.get_image_dir()
28 short_path = ntpath.basename(image_path[0])
29 name = os.path.splitext(short_path)[0]
30
31 webpage.add_header(name)
32 ims, txts, links = [], [], []
33
34 for label, im_data in visuals.items():
35 im = util.tensor2im(im_data)
36 image_name = '%s_%s.png' % (name, label)
37 save_path = os.path.join(image_dir, image_name)
38 util.save_image(im, save_path, aspect_ratio=aspect_ratio)
39 ims.append(image_name)
40 txts.append(label)
41 links.append(image_name)
42 webpage.add_images(ims, txts, links, width=width)
43
44
45 class Visualizer():
46 """This class includes several functions that can display/save images and print/save logging information.
47
48 It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
49 """
50
51 def __init__(self, opt):
52 """Initialize the Visualizer class
53
54 Parameters:
55 opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
56 Step 1: Cache the training/test options
57 Step 2: connect to a visdom server
58 Step 3: create an HTML object for saving HTML files
59 Step 4: create a logging file to store training losses
60 """
61 self.opt = opt # cache the option
62 self.display_id = opt.display_id
63 self.use_html = opt.isTrain and not opt.no_html
64 self.win_size = opt.display_winsize
65 self.name = opt.name
66 self.port = opt.display_port
67 self.saved = False
68 if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
69 import visdom
70 self.ncols = opt.display_ncols
71 self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
72 if not self.vis.check_connection():
73 self.create_visdom_connections()
74
75 if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
76 self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
77 self.img_dir = os.path.join(self.web_dir, 'images')
78 print('create web directory %s...' % self.web_dir)
79 util.mkdirs([self.web_dir, self.img_dir])
80 # create a logging file to store training losses
81 self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
82 with open(self.log_name, "a") as log_file:
83 now = time.strftime("%c")
84 log_file.write('================ Training Loss (%s) ================\n' % now)
85
86 def reset(self):
87 """Reset the self.saved status"""
88 self.saved = False
89
90 def create_visdom_connections(self):
91 """If the program could not connect to a Visdom server, start a new server at port <self.port>"""
92 cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
93 print('\n\nCould not connect to Visdom server.\nTrying to start a server....')
94 print('Command: %s' % cmd)
95 Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
96
97 def display_current_results(self, visuals, epoch, save_result):
98 """Display current results on visdom; save current results to an HTML file.
99
100 Parameters:
101 visuals (OrderedDict) - - dictionary of images to display or save
102 epoch (int) - - the current epoch
103 save_result (bool) - - if save the current results to an HTML file
104 """
105 if self.display_id > 0: # show images in the browser using visdom
106 ncols = self.ncols
107 if ncols > 0: # show all the images in one visdom panel
108 ncols = min(ncols, len(visuals))
109 h, w = next(iter(visuals.values())).shape[:2]
110 table_css = """<style>
111 table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
112 table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
113 </style>""" % (w, h) # create a table css
114 # create a table of images.
115 title = self.name
116 label_html = ''
117 label_html_row = ''
118 images = []
119 idx = 0
120 for label, image in visuals.items():
121 image_numpy = util.tensor2im(image)
122 label_html_row += '<td>%s</td>' % label
123 images.append(image_numpy.transpose([2, 0, 1]))
124 idx += 1
125 if idx % ncols == 0:
126 label_html += '<tr>%s</tr>' % label_html_row
127 label_html_row = ''
128 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
129 while idx % ncols != 0:
130 images.append(white_image)
131 label_html_row += '<td></td>'
132 idx += 1
133 if label_html_row != '':
134 label_html += '<tr>%s</tr>' % label_html_row
135 try:
136 self.vis.images(images, nrow=ncols, win=self.display_id + 1,
137 padding=2, opts=dict(title=title + ' images'))
138 label_html = '<table>%s</table>' % label_html
139 self.vis.text(table_css + label_html, win=self.display_id + 2,
140 opts=dict(title=title + ' labels'))
141 except VisdomExceptionBase:
142 self.create_visdom_connections()
143
144 else: # show each image in a separate visdom panel;
145 idx = 1
146 try:
147 for label, image in visuals.items():
148 image_numpy = util.tensor2im(image)
149 self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
150 win=self.display_id + idx)
151 idx += 1
152 except VisdomExceptionBase:
153 self.create_visdom_connections()
154
155 if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
156 self.saved = True
157 # save images to the disk
158 for label, image in visuals.items():
159 image_numpy = util.tensor2im(image)
160 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
161 util.save_image(image_numpy, img_path)
162
163 # update website
164 webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
165 for n in range(epoch, 0, -1):
166 webpage.add_header('epoch [%d]' % n)
167 ims, txts, links = [], [], []
168
169 for label, image in visuals.items():
170 image_numpy = util.tensor2im(image)
171 img_path = 'epoch%.3d_%s.png' % (n, label)
172 ims.append(img_path)
173 txts.append(label)
174 links.append(img_path)
175 webpage.add_images(ims, txts, links, width=self.win_size)
176 webpage.save()
177
178 def plot_current_losses(self, epoch, counter_ratio, losses):
179 """Display the current losses on the visdom display.
180
181 Parameters:
182 epoch (int) -- current epoch
183 counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
184 losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
185 """
186 if not hasattr(self, 'plot_data'):
187 self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
188 self.plot_data['X'].append(epoch + counter_ratio)
189 self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
190 try:
191 self.vis.line(
192 X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
193 Y=np.array(self.plot_data['Y']),
194 opts={
195 'title': self.name + ' loss over time',
196 'legend': self.plot_data['legend'],
197 'xlabel': 'epoch',
198 'ylabel': 'loss'},
199 win=self.display_id)
200 except VisdomExceptionBase:
201 self.create_visdom_connections()
202
203 # losses: same format as |losses| of plot_current_losses
204 def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
205 """print current losses on console; also save the losses to the disk
206
207 Parameters:
208 epoch (int) -- current epoch
209 iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
210 losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
211 t_comp (float) -- computational time per data point (normalized by batch_size)
212 t_data (float) -- data loading time per data point (normalized by batch_size)
213 """
214 message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
215 for k, v in losses.items():
216 message += '%s: %.3f ' % (k, v)
217
218 print(message) # print the message
219 with open(self.log_name, "a") as log_file:
220 log_file.write('%s\n' % message) # save the message
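For reference, the log line built by `print_current_losses` can be reproduced standalone; `format_losses` is a hypothetical helper, not part of the class:

```python
from collections import OrderedDict

def format_losses(epoch, iters, losses, t_comp, t_data):
    # same message layout as print_current_losses above
    message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
    for k, v in losses.items():
        message += '%s: %.3f ' % (k, v)
    return message

msg = format_losses(1, 100, OrderedDict([('G_GAN', 0.5), ('D_real', 0.25)]), 0.1, 0.01)
# '(epoch: 1, iters: 100, time: 0.100, data: 0.010) G_GAN: 0.500 D_real: 0.250 '
```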