Add ESRGAN docs

James Betker 2020-12-20 11:50:31 -07:00
parent 7938f9f50b
commit bbc677dc7b
5 changed files with 466 additions and 3 deletions

recipes/esrgan/README.md Normal file

@@ -0,0 +1,74 @@
# Training super-resolution networks with ESRGAN
[SRGAN](https://arxiv.org/abs/1609.04802) is a landmark SR technique. It is quickly approaching "seminal" status because of how many papers
use some or all of the techniques it originally introduced. [ESRGAN](https://arxiv.org/abs/1809.00219) is a follow-up
paper by the same authors that strictly improves on the results of SRGAN.
After considerable trial and error, I recommend an additional set of modifications to ESRGAN to
improve training performance and reduce artifacts:
* Gradient penalty loss on the discriminator keeps the discriminator's gradients to the generator in check.
* Adding noise of magnitude 1/255 to the discriminator inputs prevents it from using the fixed input range of HR images
  for discrimination. (Natural HR images can only take values in increments of 1/255, while the generator produces
  continuous outputs; the discriminator can cheat by exploiting this fact.)
* Adding GroupNorm to the discriminator layers. This further stabilizes gradients without the downsides of BatchNorm.
* Adding a translational loss to the generator term. This loss uses the generator to compute two HQ images
  during each training pass from random, overlapping sub-patches of the original image. An L1 loss with a very high
  gain is then computed across the shared region of the two outputs. I found this tremendously helpful in reducing
  GAN artifacts, as it forces the generator to be self-consistent. A minimal sketch is given after this list.
* Using a vanilla GAN. The ESRGAN paper promotes the use of RAGAN, but I found its effect on result quality to be
  minimal with the above modifications. In some cases it can actually be harmful because it drives strange training
  dynamics on the discriminator. For example, I've observed the output of the discriminator to sometimes "explode"
  when using RAGAN because it does not force a fixed output value. RAGAN is also more expensive to compute.
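
To make the translational loss concrete, here is a minimal PyTorch sketch. It is illustrative only, not the DLAS implementation (which draws random sub-patch positions and is wired up through the `translational` loss and `image_patch` injectors in the config below); the fixed crop coordinates and the `netG` argument are simplifications.

```python
import torch.nn.functional as F

def translational_loss(netG, lq, patch=24, overlap=16, scale=4):
    """Illustrative sketch only -- not the DLAS implementation.

    Super-resolves two overlapping LQ crops independently, then
    penalizes (L1) any disagreement on their shared region, forcing
    the generator to be self-consistent.
    """
    shift = patch - overlap
    crop_a = lq[:, :, :patch, :patch]
    crop_b = lq[:, :, shift:shift + patch, shift:shift + patch]
    sr_a, sr_b = netG(crop_a), netG(crop_b)
    # Map the shared band into HR coordinates.
    s, o = shift * scale, overlap * scale
    shared_a = sr_a[:, :, s:, s:]   # bottom-right of output A
    shared_b = sr_b[:, :, :o, :o]   # top-left of output B
    return F.l1_loss(shared_a, shared_b)
```

This term is added to the generator objective with a high weight (the example config below uses 2) alongside the pixel, perceptual, and GAN losses.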
The examples below have all of these modifications added. I've also provided a reference file that
should be closer to the original ESRGAN implementation, `train_div2k_esrgan_reference.yml`.
## Training ESRGAN
DLAS can train and use ESRGAN models end-to-end. These docs will show you how.
### Dataset Preparation
Start by assembling your dataset. The ESRGAN paper uses the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and
[Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) datasets. These include a small set of high-resolution
images. ESRGAN is trained on small sub-patches of those images. Generate these patches using the instructions found
in 'Generating a chunked dataset' [here](https://github.com/neonbjb/DL-Art-School/blob/gan_lab/codes/data/README.md).
Consider creating a validation set at the same time. These can just be a few medium-resolution, high-quality
images. DLAS will downsample them for you and send them through your network for validation.
### Training the model
Use the `train_div2k_esrgan.yml` configuration file in this directory as a template to train your
ESRGAN. Search the file for `<--` to find options that will need to be adjusted for your installation.
Train with:
`python train.py -opt train_div2k_esrgan.yml`
Note that this configuration trains an RRDB network with an L1 pixel loss only for the first 100k
steps. I recommend saving the model at step 100k (this happens by default; just copy the file
out of the `experiments/train_div2k_esrgan/models` directory once training hits step 100k) so that you
do not need to repeat this phase in future experiments.
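For example (the checkpoint filename is an assumption; check the models directory for the exact naming your run produces):

`cp experiments/train_div2k_esrgan/models/100000_generator.pth /path/to/your/saved/models/`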
## Using an ESRGAN model
### Image SR
You can apply a pre-trained ESRGAN model to a set of images using the code in `test.py`.
Documentation for this script is forthcoming, but the gist is: feed it your training configuration
file with the `pretrain_model_generator` option set properly, and point the datasets section at
your folder of test images in lieu of the validation set.
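The invocation likely looks like the following (an assumption based on `train.py`'s interface; I have not verified the exact flags `test.py` accepts):

`python test.py -opt train_div2k_esrgan.yml`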
### Video SR
I've put together a script that strips a video into its constituent frames, applies an ESRGAN
model to each frame one at a time, then recombines the frames back into video segments (without sound).
You will need to use ffmpeg to stitch the segments back together and add sound, but this is
trivial.
This script is called `process_video.py` and it takes a special configuration file. A sample
config is provided in `rrdb_process_video.yml` in this directory. Further documentation on this
procedure is forthcoming.
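For the final stitching step, first write a `concat.txt` listing the segment files in order (lines of the form `file 'segment.mp4'`), then run something like the command below. The segment and source file names here are placeholders; your actual segment names will depend on the config above.

`ffmpeg -f concat -safe 0 -i concat.txt -i original_video.mp4 -map 0:v -map 1:a -c copy upscaled_with_sound.mp4`

This copies the upscaled video stream as-is and muxes in the audio track from the original video.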
Fun fact: the foundations of DLAS lie in the (now defunct) MMSR GitHub repo, which was
primarily an implementation of ESRGAN.

recipes/esrgan/rrdb_process_video.yml Normal file

@@ -0,0 +1,69 @@
name: video_process
suffix: ~  # add suffix to saved images
model: extensibletrainer
scale: 4
gpu_ids: [0]
fp16: true
minivid_crf: 12  # Defines the 'crf' output video quality parameter fed to FFMPEG
frames_per_mini_vid: 360  # How many frames to process before generating a small video segment. Used to reduce number of images you must store to convert an entire video.
minivid_start_no: 360
recurrent_mode: false

dataset:
  n_workers: 1
  name: myvideo
  video_file: <your path>  # <-- Path to your video file here. Any format supported by ffmpeg works.
  frame_rate: 30  # Set to the frame rate of your video.
  start_at_seconds: 0  # Set this if you want to start somewhere other than the beginning of the video.
  end_at_seconds: 5000  # Set to the time you want to stop at.
  batch_size: 1  # Set to the number of frames to convert at once. Larger batches provide a modest performance increase.
  vertical_splits: 1  # Used for 3D binocular videos. Leave at 1.
  force_multiple: 1

#### network structures
networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3

#### path
path:
  pretrain_model_generator: <your path>  # <-- Set your generator path here.

steps:
  generator:
    training: generator
    generator: generator
    # Optimizer params. Not used, but currently required to initialize ExtensibleTrainer, even in eval mode.
    lr: !!float 5e-6
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99
    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: lq
        out: gen

# Train section is required, even though we are just evaluating.
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 500
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [20000, 40000, 80000, 100000, 140000, 180000]
  lr_gamma: 0.5

eval:
  output_state: gen

recipes/esrgan/train_div2k_esrgan.yml Normal file

@@ -0,0 +1,179 @@
name: train_div2k_esrgan
model: extensibletrainer
scale: 4
gpu_ids: [0]
fp16: false
start_step: -1
checkpointing_enabled: true  # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
use_tb_logger: true
wandb: false

datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k  # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14
    scale: 4

networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3
  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128_gn
    scale: 2
    nf: 64
    in_nc: 3
    image_size: 96

#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state  # <-- Set this to resume from a previous training state.

steps:
  feature_discriminator:
    training: feature_discriminator
    after: 100000  # Discriminator doesn't "turn-on" until step 100k to allow generator to anneal on PSNR loss.
    # Optimizer params
    lr: !!float 2e-4
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99
    injectors:
      # "image_patch" injectors support the translational loss below. You can remove them if you remove that loss.
      plq:
        type: image_patch
        patch_size: 24
        in: lq
        out: plq
      phq:
        type: image_patch
        patch_size: 96
        in: hq
        out: phq
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: plq
        out: dgen
    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: gan
        weight: 1
        #min_loss: .4
        noise: .004
        gradient_penalty: true
        real: phq
        fake: dgen
  generator:
    training: generator
    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99
    injectors:
      pglq:
        type: image_patch
        patch_size: 24
        in: lq
        out: pglq
      pghq:
        type: image_patch
        patch_size: 96
        in: hq
        out: pghq
      gen_inj:
        type: generator
        generator: generator
        in: pglq
        out: gen
    losses:
      pix:
        type: pix
        weight: .05
        criterion: l1
        real: pghq
        fake: gen
      feature:
        type: feature
        after: 80000  # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: pghq
        fake: gen
      gan_gen_img:
        after: 100000
        type: generator_gan
        gan_type: gan
        weight: .02
        noise: .004
        discriminator: feature_discriminator
        fake: gen
        real: pghq
      # Translational loss <- not present in the original ESRGAN paper, but I find it reduces artifacts from the GAN.
      # Feel free to remove. The network will still train well.
      translational:
        type: translational
        after: 80000
        weight: 2
        criterion: l1
        generator: generator
        generator_output_index: 0
        detach_fake: false
        patch_size: 96
        overlap: 64
        real: gen
        fake: ['pglq']

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000
  # LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [140000, 180000, 200000, 240000]  # LR is halved at these steps. Don't do it until GAN is online.
  lr_gamma: 0.5

eval:
  output_state: gen

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [gen, hq, pglq, pghq]
  visual_debug_rate: 100

recipes/esrgan/train_div2k_esrgan_reference.yml Normal file

@@ -0,0 +1,142 @@
# This is a config file that trains ESRGAN using the dynamics spelled out in the paper with no modifications.
# This has not been trained to completion in some time. I make no guarantees that it will work well.
name: train_div2k_esrgan_reference
model: extensibletrainer
scale: 4
gpu_ids: [0]
fp16: false
start_step: -1
checkpointing_enabled: true  # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
use_tb_logger: true
wandb: false

datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k  # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14
    scale: 4

networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3
  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128
    scale: 2
    nf: 64
    in_nc: 3

#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state  # <-- Set this to resume from a previous training state.

steps:
  feature_discriminator:
    training: feature_discriminator
    after: 100000  # Discriminator doesn't "turn-on" until step 100k to allow generator to anneal on PSNR loss.
    # Optimizer params
    lr: !!float 2e-4
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99
    injectors:
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: lq
        out: dgen
    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: ragan
        weight: 1
        real: hq
        fake: dgen
  generator:
    training: generator
    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99
    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: lq
        out: gen
    losses:
      pix:
        type: pix
        weight: .05
        criterion: l1
        real: hq
        fake: gen
      feature:
        type: feature
        after: 80000  # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: hq
        fake: gen
      gan_gen_img:
        after: 100000
        type: generator_gan
        gan_type: ragan
        weight: .02
        discriminator: feature_discriminator
        fake: gen
        real: hq

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000
  # LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [140000, 180000, 200000, 240000]  # LR is halved at these steps. Don't do it until GAN is online.
  lr_gamma: 0.5

eval:
  output_state: gen

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [gen, hq, lq]
  visual_debug_rate: 100

train_div2k_srflow.yml

@@ -1,12 +1,11 @@
 #### general settings
 name: train_div2k_srflow
-use_tb_logger: true
 model: extensibletrainer
 scale: 4
 gpu_ids: [0]
 fp16: false
 start_step: -1
-checkpointing_enabled: true
+checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
 use_tb_logger: true
 wandb: false
 datasets: