Merge remote-tracking branch 'origin/gan_lab' into gan_lab

This commit is contained in:
James Betker 2020-10-24 11:57:02 -06:00
commit ee6216966c

@@ -1,61 +1,71 @@
# MMSR
# Deep Learning Art School
MMSR is an open source image and video super-resolution toolbox based on PyTorch. It is a part of the [open-mmlab](https://github.com/open-mmlab) project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk). MMSR is based on our previous projects: [BasicSR](https://github.com/xinntao/BasicSR), [ESRGAN](https://github.com/xinntao/ESRGAN), and [EDVR](https://github.com/xinntao/EDVR).
Send your PyTorch model to art class!
## My (@neonbjb) Modifications
After tinkering with MMSR, I really began to like a lot about how the codebase was laid out and the general practices being used. I have since worked to extend it to more
general use cases, as well as implement several GAN training features. The additions are too many to list, but I'll give it a shot:
This repository is both a framework and a set of tools for training deep neural networks that create images. It started as a branch of the [open-mmlab](https://github.com/open-mmlab) project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk) but has been almost completely re-written at every level.
- FP16 support.
- Alternative dataset support (notably a disjoint dataset for training a generator to style-transfer between imagesets).
- Addition of several new architectures, including a ResNet-based discriminator, a downsampling generator (for training image corruptors), and a fix-and-upsample generator.
- Fixup resblock support, which resists the exploding gradients that otherwise necessitate batch norms. Most of the fixup architectures can be trained with BN turned off, though they
take longer to train and are occasionally divergent in FP16 mode.
- Batch testing for performing generator augmentation on large sets of images.
- Model swapout during training - randomly select a past D or G and substitute it in for a short time to increase variance on the respective model.
- Adding random noise on both the inputs of the discriminator and generator. The discriminator variety has a decay.
- Decaying the influence of the feature loss.
- "Corruption" generators which can alter an input before it is fed through the SRGAN pipeline.
- Outputting "state" images which are very useful in debugging what is actually going on in the pipeline.
- Skip layers between the generator and discriminator.
- Support for any number of image resolutions into the discriminators. The original MMSR only accepted 128x128 images.
- "Megabatches" - gradient accumulation across multiple batches before performing an optimizer step.
- Image cropping can be disabled. I prefer to do this in preprocessing.
- Tensorboard logs for an experiment are cleared out when the experiment is restarted anew.
- A LOT more data is logged to tensorboard.
## Why do we need another training framework?
Note that this codebase is far from clean. I've notably broken LMDB support in a couple of places. Likely everything other than SRGAN doesn't work too well anymore either.
I will get around to documenting all this in the near future once the repo stabilizes a bit. For now, you're on your own!
These are a dime a dozen, no doubt. DL Art School (*DLAS*) differentiates itself by being configuration driven. You write the model code (specifically, a `torch.nn.Module`) and (possibly) some losses, then you cobble together a config file written in YAML that tells DLAS how to train it. Swapping model architectures is simple and often requires no changes to actual code. This lets you run multiple concurrent experiments from the same codebase and retain backwards compatibility with past experiments.
Training effective generators often means juggling multiple loss functions. As a result, DLAS' configuration language is specifically designed to make it easy to support a large number of losses and networks that interact with each other. As an example: some GANs I have trained in this framework consist of more than 15 losses and use 2 separate discriminators, yet require no bespoke code.
Generators are also notorious GPU memory hogs. I have spent substantial time streamlining the training framework to support gradient checkpointing and FP16. DLAS also supports "mega batching", where multiple forward passes contribute to a single backward pass. Most models can be trained on midrange GPUs with 8-11GB of memory.
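
One common way to get the effect of "mega batching" in PyTorch is gradient accumulation: run several small forward/backward passes, let the gradients add up, and take a single optimizer step. The sketch below illustrates that general technique only. The function name `train_mega_batch` and its parameters are hypothetical, and this is not necessarily how DLAS implements it internally.

```python
import torch
from torch import nn


def train_mega_batch(model, optimizer, loss_fn, inputs, targets, chunks):
    """Accumulate gradients over `chunks` slices, then take one optimizer step."""
    optimizer.zero_grad()
    total_loss = 0.0
    for x, y in zip(inputs.chunk(chunks), targets.chunk(chunks)):
        # Weight each slice by its share of the mega-batch so the summed
        # gradients match a mean-reduced loss over the whole mega-batch.
        loss = loss_fn(model(x), y) * (x.shape[0] / inputs.shape[0])
        loss.backward()  # gradients accumulate in each parameter's .grad
        total_loss += loss.item()
    optimizer.step()  # one parameter update for the whole mega-batch
    return total_loss


torch.manual_seed(0)
model = nn.Linear(4, 1)  # toy stand-in for a generator
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
loss = train_mega_batch(model, optimizer, nn.MSELoss(), inputs, targets, chunks=4)
```

Only one slice's activations are held in memory at a time, which is what makes this attractive on midrange GPUs. Note that layers which depend on batch statistics (such as batch norm) still see only the small slices.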
The final value-added feature is interpretability. Tensorboard logging operates out of the box with no custom code. Intermediate images from within the training pipeline can be intermittently surfaced as normal PNG files so you can see what your network is up to. Validation passes are also cached as images so you can view how your network improves over time.
## Modeling Capabilities
DLAS was built with extensibility in mind. One of the reasons I'm putting in the effort to better document this code is the incredible ease with which I have been able to train entirely new model types with no changes to the core training code.
I intend to fill out the sections below with sample configurations which can be used to train different architectures. You will need to bring your own data.
### Super-resolution
TBC..
- Pixel-based SR (SRCNN, RCAN, PANet, etc)
- GAN-based SR (ESRGAN)
- Multi-GAN SR (SPSR)
- Video SR (TecoGAN)
### Optical Flow & 3D
### Style Transfer
## Dependencies and Installation
- Python 3 (Recommend to use [Anaconda](https://www.anaconda.com/download))
- [PyTorch >= 1.1](https://pytorch.org)
- Python 3
- [PyTorch >= 1.6](https://pytorch.org)
- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
- [Deformable Convolution](https://arxiv.org/abs/1703.06211). We use [mmdetection](https://github.com/open-mmlab/mmdetection)'s dcn implementation. Please first compile it.
```
cd ./codes/models/archs/dcn
python setup.py develop
```
- Python packages: `pip install -r requirements.txt`
- Some video utilities require [FFMPEG](https://ffmpeg.org/)
## User Guide
TBC
## Dataset Preparation
We use datasets in LMDB format for faster IO speed. Please refer to [DATASETS.md](datasets/DATASETS.md) for more details.
### Dataset Preparation
DLAS comes with some Dataset instances that I have created for my own use. Unless you want to use one of the recipes above, you'll need to provide your own. Here is how to add your own Dataset:
## Training and Testing
Please see [wiki- Training and Testing](https://github.com/open-mmlab/mmsr/wiki/Training-and-Testing) for the basic usage, *i.e.,* training and testing.
1. Create a Dataset in `codes/data/` whose constructor takes a single Python dict and extracts its options from that dict.
2. Register your Dataset in `codes/data/__init__.py`.
3. Your Dataset should return a dict of tensors. The keys of the dict are injected directly into the training state, which you can interact with in your configuration file.
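
The three steps above can be sketched roughly as follows. Everything named here is hypothetical and for illustration only: the class `RandomPairDataset`, the option names (`hq_size`, `scale`, `length`), the output keys (`lq`, `hq`), and the `DATASETS` lookup, which merely stands in for whatever registration `codes/data/__init__.py` actually performs.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class RandomPairDataset(Dataset):
    """Hypothetical dataset yielding random low-/high-resolution image pairs."""

    def __init__(self, opt):
        # Step 1: every option is read from a single dict.
        self.length = opt.get('length', 16)
        self.hq_size = opt['hq_size']
        self.scale = opt.get('scale', 4)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        hq = torch.rand(3, self.hq_size, self.hq_size)
        # Downsample by average pooling to fake a low-resolution input.
        lq = F.avg_pool2d(hq.unsqueeze(0), self.scale).squeeze(0)
        # Step 3: return a dict of tensors; the keys are assumed names.
        return {'lq': lq, 'hq': hq}


# Step 2 (stand-in): a name-to-class lookup playing the role of registration.
DATASETS = {'random_pair': RandomPairDataset}


def create_dataset(opt):
    return DATASETS[opt['name']](opt)


dataset = create_dataset({'name': 'random_pair', 'hq_size': 32, 'scale': 4})
sample = dataset[0]
```

Because each item is a dict of tensors, PyTorch's default `DataLoader` collation batches it into a dict of batched tensors with the same keys.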
## Model Zoo and Baselines
Results and pre-trained models are available in the [wiki-Model Zoo](https://github.com/open-mmlab/mmsr/wiki/Model-Zoo).
### Training and Testing
There are currently 3 base scripts for interacting with models. They all take a single parameter, `-opt`, which specifies the configuration file that controls how they work. Configs will be documented in the user guide above.
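
As a rough idea of what a shared `-opt` entry point can look like, here is a minimal sketch using the standard library's `argparse`. The function name `parse_options` is hypothetical, and the real scripts may parse their arguments differently. Loading the YAML file itself is left out.

```python
import argparse


def parse_options(argv=None):
    """Parse the single `-opt` argument that points at a config file."""
    parser = argparse.ArgumentParser(description='Hypothetical DLAS-style entry point')
    parser.add_argument('-opt', type=str, required=True,
                        help='Path to the YAML configuration file')
    return parser.parse_args(argv)


# Passing an explicit list mimics: python train.py -opt my_config.yml
args = parse_options(['-opt', 'my_config.yml'])
print(f'Would load configuration from {args.opt}')
```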
#### train.py
Starts (or continues) a training session.
`python train.py -opt <your_config.yml>`
#### test.py
Runs a model against a validation or test set of data and reports metrics (for now, just PSNR and a custom perceptual metric).
`python test.py -opt <your_config.yml>`
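
For reference, PSNR is defined as `10 * log10(MAX^2 / MSE)`, where `MAX` is the largest possible pixel value. A minimal sketch of that textbook formula follows. The function name and the assumption that images are float tensors in `[0, 1]` are mine, and `test.py` may compute it differently (for example on a single channel or after cropping borders).

```python
import torch


def psnr(prediction, target, max_value=1.0):
    """Peak signal-to-noise ratio in dB for images scaled to [0, max_value]."""
    mse = torch.mean((prediction - target) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return (10 * torch.log10(max_value ** 2 / mse)).item()


target = torch.zeros(3, 16, 16)
prediction = torch.full((3, 16, 16), 0.1)
score = psnr(prediction, target)  # MSE is 0.01, so roughly 20 dB
```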
#### process_video.py
Breaks a video into individual frames and uses a network to do processing on it, then reassembles the output back into video form.
`python process_video.py -opt <your_config.yml>`
## Contributing
We appreciate all contributions. Please refer to [mmdetection](https://github.com/open-mmlab/mmdetection/blob/master/CONTRIBUTING.md) for contributing guidelines.
**Python code style**<br/>
We adopt [PEP8](https://python.org/dev/peps/pep-0008) as the preferred code style. We use [flake8](http://flake8.pycqa.org/en/latest) as the linter and [yapf](https://github.com/google/yapf) as the formatter. Please upgrade to the latest yapf (>=0.27.0) and refer to the [yapf configuration](.style.yapf) and [flake8 configuration](.flake8).
> Before you create a PR, make sure that your code lints and is formatted by yapf.
At this time I am not taking feature requests or bug reports, but I appreciate all contributions.
## License
This project is released under the Apache 2.0 license.