From bbc677dc7bb5d3ad32e3754198eecb156093051a Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 20 Dec 2020 11:50:31 -0700 Subject: [PATCH] Add ESRGAN docs --- recipes/esrgan/README.md | 74 ++++++++ recipes/esrgan/rrdb_process_video.yml | 69 +++++++ recipes/esrgan/train_div2k_esrgan.yml | 179 ++++++++++++++++++ .../esrgan/train_div2k_esrgan_reference.yml | 142 ++++++++++++++ recipes/srflow/train_div2k_srflow.yml | 5 +- 5 files changed, 466 insertions(+), 3 deletions(-) create mode 100644 recipes/esrgan/README.md create mode 100644 recipes/esrgan/rrdb_process_video.yml create mode 100644 recipes/esrgan/train_div2k_esrgan.yml create mode 100644 recipes/esrgan/train_div2k_esrgan_reference.yml diff --git a/recipes/esrgan/README.md b/recipes/esrgan/README.md new file mode 100644 index 00000000..6bb80fa4 --- /dev/null +++ b/recipes/esrgan/README.md @@ -0,0 +1,74 @@ +# Training super-resolution networks with ESRGAN + +[SRGAN](https://arxiv.org/abs/1609.04802) is a landmark SR technique. It is quickly approaching "seminal" status because of how many papers +use some or all of the techniques originally introduced in this paper. [ESRGAN](https://arxiv.org/abs/1809.00219) is a followup +paper by the same authors which strictly improves the results of SRGAN. + +After considerable trial and error, I recommend an additional set of modifications to ESRGAN to +improve training performance and reduce artifacts: + +* Gradient penalty loss on the discriminator keeps the discriminator gradients to the generator in check. +* Adding noise of 1/255 to the discriminator prevents it from using the fixed input range of HR images for discrimination. (e.g. - natural HR images can only have values in increments of 1/255, while the generator has continuous outputs. The discriminator can cheat by using this fact.) +* Adding GroupNorm to the discriminator layers. This further stabilizes gradients without the downsides of BatchNorm. +* Adding a translational loss to the generator term. This loss works by computing using the generator to compute two HQ images + during each training pass from random sub-patches of the original image. A L1 loss is then computed across the shared + region of the two outputs with a very high gain. I found this to be tremendously helpful in reducing GAN artifacts + as it forces the generator to be self-consistent. +* Use a vanilla GAN. The ESRGAN paper promotes the use of RAGAN but I found its effect on result qualit to be minimal + with the above modifications. In some cases, it can actually be harmful because it drives strange training + dynamics on the discriminator. For example, I've observed the output of the discriminator to sometimes + "explode" when using RAGAN because it does not force a fixed output value. It is also more computationally expensive + to compute. + +The examples below have all of these modifications added. I've also provided a reference file that +should be closer to the original ESRGAN implementation, `train_div2k_esrgan_reference.yml`. + +## Training ESRGAN + +DLAS can train and use ESRGAN models end-to-end. These docs will show you how. + +### Dataset Preparation + +Start by assembling your dataset. The ESRGAN paper uses the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and +[Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) datasets. These include a small set of high-resolution +images. ESRGAN is trained on small sub-patches of those images. Generate these patches using the instructions found +in 'Generating a chunked dataset' [here](https://github.com/neonbjb/DL-Art-School/blob/gan_lab/codes/data/README.md). + +Consider creating a validation set at the same time. These can just be a few medium-resolution, high-quality +images. DLAS will downsample them for you and send them through your network for validation. + +### Training the model + +Use the train_div2k_esrgan.yml configuration file in this directory as a template to train your +ESRGAN. Search the file for `<--` to find options that will need to be adjusted for your installation. + +Train with: +`python train.py -opt train_div2k_esrgan.yml` + +Note that this configuration trains an RRDB network with an L1 pixel loss only for the first 100k +steps. I recommend you save the model at step 100k (this is done by default, just copy the file +out of the experiments/train_div2k_esrgan/models directory once it hits step 100k) so that you +do not need to repeat this training in future experiments. + +## Using an ESRGAN model + +### Image SR + +You can apply a pre-trained ESRGAN model against a set of images using the code in `test.py`. +Documentation for this script is forthcoming but basically you feed it your training configuration +file with the `pretrain_model_generator` option set properly and your folder with test images +pointed to in the datasets section in lieu of the validation set. + +### Video SR + +I've put together a script that strips a video into its constituent frames, applies an ESRGAN +model to each frame one a time, then recombines the frames back into videos (without sound). +You will need to use ffmpeg to stitch the videos back together and add sound, but this is +trivial. + +This script is called `process_video.py` and it takes a special configuration file. A sample +config is provided in `rrdb_process_video.yml` in this directory. Further documentation on this +procedure is forthcoming. + +Fun fact: the foundations of DLAS lie in the (now defunct) MMSR github repo, which was +primarily an implementation of ESRGAN. diff --git a/recipes/esrgan/rrdb_process_video.yml b/recipes/esrgan/rrdb_process_video.yml new file mode 100644 index 00000000..fac52777 --- /dev/null +++ b/recipes/esrgan/rrdb_process_video.yml @@ -0,0 +1,69 @@ +name: video_process +suffix: ~ # add suffix to saved images +model: extensibletrainer +scale: 4 +gpu_ids: [0] +fp16: true +minivid_crf: 12 # Defines the 'crf' output video quality parameter fed to FFMPEG +frames_per_mini_vid: 360 # How many frames to process before generating a small video segment. Used to reduce number of images you must store to convert an entire video. +minivid_start_no: 360 +recurrent_mode: false + +dataset: + n_workers: 1 + name: myvideo + video_file: # <-- Path to your video file here. any format supported by ffmpeg works. + frame_rate: 30 # Set to the frame rate of your video. + start_at_seconds: 0 # Set this if you want to start somewhere other than the beginning of the video. + end_at_seconds: 5000 # Set to the time you want to stop at. + batch_size: 1 # Set to the number of frames to convert at once. Larger batches provide a modest performance increase. + vertical_splits: 1 # Used for 3d binocular videos. Leave at 1. + force_multiple: 1 + +#### network structures +networks: + generator: + type: generator + which_model_G: RRDBNet + in_nc: 3 + out_nc: 3 + initial_stride: 1 + nf: 64 + nb: 23 + scale: 4 + blocks_per_checkpoint: 3 + +#### path +path: + pretrain_model_generator: # <-- Set your generator path here. + +steps: + generator: + training: generator + generator: generator + + # Optimizer params. Not used, but currently required to initialize ExtensibleTrainer, even in eval mode. + lr: !!float 5e-6 + weight_decay: 0 + beta1: 0.9 + beta2: 0.99 + + injectors: + gen_inj: + type: generator + generator: generator + in: lq + out: gen + +# Train section is required, even though we are just evaluating. +train: + niter: 500000 + warmup_iter: -1 + mega_batch_factor: 1 + val_freq: 500 + default_lr_scheme: MultiStepLR + gen_lr_steps: [20000, 40000, 80000, 100000, 140000, 180000] + lr_gamma: 0.5 + +eval: + output_state: gen \ No newline at end of file diff --git a/recipes/esrgan/train_div2k_esrgan.yml b/recipes/esrgan/train_div2k_esrgan.yml new file mode 100644 index 00000000..bb1e4bc1 --- /dev/null +++ b/recipes/esrgan/train_div2k_esrgan.yml @@ -0,0 +1,179 @@ +name: train_div2k_esrgan +model: extensibletrainer +scale: 4 +gpu_ids: [0] +fp16: false +start_step: -1 +checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training. +use_tb_logger: true +wandb: false + +datasets: + train: + n_workers: 2 + batch_size: 16 + name: div2k + mode: single_image_extensible + paths: /content/div2k # <-- Put your path here. + target_size: 128 + force_multiple: 1 + scale: 4 + strict: false + val: + name: val + mode: fullimage + dataroot_GT: /content/set14 + scale: 4 + +networks: + generator: + type: generator + which_model_G: RRDBNet + in_nc: 3 + out_nc: 3 + initial_stride: 1 + nf: 64 + nb: 23 + scale: 4 + blocks_per_checkpoint: 3 + + feature_discriminator: + type: discriminator + which_model_D: discriminator_vgg_128_gn + scale: 2 + nf: 64 + in_nc: 3 + image_size: 96 + +#### path +path: + #pretrain_model_generator: + strict_load: true + #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state # <-- Set this to resume from a previous training state. + +steps: + + feature_discriminator: + training: feature_discriminator + after: 100000 # Discriminator doesn't "turn-on" until step 100k to allow generator to anneal on PSNR loss. + + # Optimizer params + lr: !!float 2e-4 + weight_decay: 0 + beta1: 0.9 + beta2: 0.99 + + injectors: + # "image_patch" injectors support the translational loss below. You can remove them if you remove that loss. + plq: + type: image_patch + patch_size: 24 + in: lq + out: plq + phq: + type: image_patch + patch_size: 96 + in: hq + out: phq + dgen_inj: + type: generator + generator: generator + grad: false + in: plq + out: dgen + + losses: + gan_disc_img: + type: discriminator_gan + gan_type: gan + weight: 1 + #min_loss: .4 + noise: .004 + gradient_penalty: true + real: phq + fake: dgen + + generator: + training: generator + + optimizer_params: + lr: !!float 2e-4 + weight_decay: 0 + beta1: 0.9 + beta2: 0.99 + + injectors: + pglq: + type: image_patch + patch_size: 24 + in: lq + out: pglq + pghq: + type: image_patch + patch_size: 96 + in: hq + out: pghq + gen_inj: + type: generator + generator: generator + in: pglq + out: gen + + losses: + pix: + type: pix + weight: .05 + criterion: l1 + real: pghq + fake: gen + feature: + type: feature + after: 80000 # Perceptual/"feature" loss doesn't turn on until step 80k. + which_model_F: vgg + criterion: l1 + weight: 1 + real: pghq + fake: gen + gan_gen_img: + after: 100000 + type: generator_gan + gan_type: gan + weight: .02 + noise: .004 + discriminator: feature_discriminator + fake: gen + real: pghq + # Translational loss <- not present in the original ESRGAN paper, but I find it reduces artifacts from the GAN. + # Feel free to remove. The network will still train well. + translational: + type: translational + after: 80000 + weight: 2 + criterion: l1 + generator: generator + generator_output_index: 0 + detach_fake: false + patch_size: 96 + overlap: 64 + real: gen + fake: ['pglq'] + +train: + niter: 500000 + warmup_iter: -1 + mega_batch_factor: 1 + val_freq: 2000 + + # LR scheduler options + default_lr_scheme: MultiStepLR + gen_lr_steps: [140000, 180000, 200000, 240000] # LR is halved at these steps. Don't do it until GAN is online. + lr_gamma: 0.5 + +eval: + output_state: gen + +logger: + print_freq: 30 + save_checkpoint_freq: 1000 + visuals: [gen, hq, pglq, pghq] + visual_debug_rate: 100 \ No newline at end of file diff --git a/recipes/esrgan/train_div2k_esrgan_reference.yml b/recipes/esrgan/train_div2k_esrgan_reference.yml new file mode 100644 index 00000000..0e769f80 --- /dev/null +++ b/recipes/esrgan/train_div2k_esrgan_reference.yml @@ -0,0 +1,142 @@ +# This is a config file that trains ESRGAN using the dynamics spelled out in the paper with no modifications. +# This has not been trained to completion in some time. I make no guarantees that it will work well. + +name: train_div2k_esrgan_reference +model: extensibletrainer +scale: 4 +gpu_ids: [0] +fp16: false +start_step: -1 +checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training. +use_tb_logger: true +wandb: false + +datasets: + train: + n_workers: 2 + batch_size: 16 + name: div2k + mode: single_image_extensible + paths: /content/div2k # <-- Put your path here. + target_size: 128 + force_multiple: 1 + scale: 4 + strict: false + val: + name: val + mode: fullimage + dataroot_GT: /content/set14 + scale: 4 + +networks: + generator: + type: generator + which_model_G: RRDBNet + in_nc: 3 + out_nc: 3 + initial_stride: 1 + nf: 64 + nb: 23 + scale: 4 + blocks_per_checkpoint: 3 + + feature_discriminator: + type: discriminator + which_model_D: discriminator_vgg_128 + scale: 2 + nf: 64 + in_nc: 3 + +#### path +path: + #pretrain_model_generator: + strict_load: true + #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state # <-- Set this to resume from a previous training state. + +steps: + + feature_discriminator: + training: feature_discriminator + after: 100000 # Discriminator doesn't "turn-on" until step 100k to allow generator to anneal on PSNR loss. + + # Optimizer params + lr: !!float 2e-4 + weight_decay: 0 + beta1: 0.9 + beta2: 0.99 + + injectors: + dgen_inj: + type: generator + generator: generator + grad: false + in: lq + out: dgen + + losses: + gan_disc_img: + type: discriminator_gan + gan_type: ragan + weight: 1 + real: hq + fake: dgen + + generator: + training: generator + + optimizer_params: + lr: !!float 2e-4 + weight_decay: 0 + beta1: 0.9 + beta2: 0.99 + + injectors: + gen_inj: + type: generator + generator: generator + in: lq + out: gen + + losses: + pix: + type: pix + weight: .05 + criterion: l1 + real: hq + fake: gen + feature: + type: feature + after: 80000 # Perceptual/"feature" loss doesn't turn on until step 80k. + which_model_F: vgg + criterion: l1 + weight: 1 + real: hq + fake: gen + gan_gen_img: + after: 100000 + type: generator_gan + gan_type: ragan + weight: .02 + discriminator: feature_discriminator + fake: gen + real: hq + +train: + niter: 500000 + warmup_iter: -1 + mega_batch_factor: 1 + val_freq: 2000 + + # LR scheduler options + default_lr_scheme: MultiStepLR + gen_lr_steps: [140000, 180000, 200000, 240000] # LR is halved at these steps. Don't do it until GAN is online. + lr_gamma: 0.5 + +eval: + output_state: gen + +logger: + print_freq: 30 + save_checkpoint_freq: 1000 + visuals: [gen, hq, lq] + visual_debug_rate: 100 \ No newline at end of file diff --git a/recipes/srflow/train_div2k_srflow.yml b/recipes/srflow/train_div2k_srflow.yml index 122576a6..e95e2d3d 100644 --- a/recipes/srflow/train_div2k_srflow.yml +++ b/recipes/srflow/train_div2k_srflow.yml @@ -1,12 +1,11 @@ -#### general settings name: train_div2k_srflow -use_tb_logger: true model: extensibletrainer scale: 4 gpu_ids: [0] fp16: false start_step: -1 -checkpointing_enabled: true +checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training. +use_tb_logger: true wandb: false datasets: