forked from mrq/DL-Art-School
Add ESRGAN docs
parent 7938f9f50b
commit bbc677dc7b
74
recipes/esrgan/README.md
Normal file
@ -0,0 +1,74 @@
# Training super-resolution networks with ESRGAN

[SRGAN](https://arxiv.org/abs/1609.04802) is a landmark SR technique. It is quickly approaching "seminal" status because of how many papers
use some or all of the techniques originally introduced in this paper. [ESRGAN](https://arxiv.org/abs/1809.00219) is a follow-up
paper which improves on the results of SRGAN.

After considerable trial and error, I recommend an additional set of modifications to ESRGAN to
improve training performance and reduce artifacts:

* Gradient penalty loss on the discriminator keeps the discriminator gradients to the generator in check.
* Adding noise of 1/255 to the discriminator inputs prevents it from using the fixed input range of HR images for discrimination. (Natural HR images can only have values in increments of 1/255, while the generator has continuous outputs. The discriminator can cheat by using this fact.)
* Adding GroupNorm to the discriminator layers. This further stabilizes gradients without the downsides of BatchNorm.
* Adding a translational loss to the generator term. This loss works by using the generator to compute two HQ images
  during each training pass from random sub-patches of the original image. An L1 loss is then computed across the shared
  region of the two outputs with a very high gain. I found this to be tremendously helpful in reducing GAN artifacts
  as it forces the generator to be self-consistent.
* Use a vanilla GAN. The ESRGAN paper promotes the use of RAGAN but I found its effect on result quality to be minimal
  with the above modifications. In some cases, it can actually be harmful because it drives strange training
  dynamics on the discriminator. For example, I've observed the output of the discriminator to sometimes
  "explode" when using RAGAN because it does not force a fixed output value. It is also more expensive
  to compute.
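
The translational loss in the list above can be illustrated with a one-dimensional toy. This is a minimal sketch, not the DLAS implementation: `toy_generator`, `translational_l1`, the fixed 2x scale and the blur-plus-upsample "network" are hypothetical stand-ins, chosen so that the output near a patch border depends on where the patch was cut.

```python
import random


def toy_generator(lq):
    """Hypothetical stand-in for the SR network: a 3-tap blur with edge
    replication followed by 2x nearest-neighbour upsampling of a 1-D signal.
    The edge replication makes outputs near a border depend on the crop."""
    n = len(lq)
    blurred = [
        (lq[max(i - 1, 0)] + lq[i] + lq[min(i + 1, n - 1)]) / 3.0
        for i in range(n)
    ]
    out = []
    for value in blurred:
        out.extend([value, value])
    return out


def translational_l1(lq, patch_size, offset, scale=2, gain=1.0,
                     generator=toy_generator):
    """L1 distance between two generator outputs over their shared region.

    Patch A covers lq[0:patch_size], patch B covers
    lq[offset:offset + patch_size]; `scale` must match the generator.
    """
    assert 0 < offset < patch_size and offset + patch_size <= len(lq)
    out_a = generator(lq[:patch_size])
    out_b = generator(lq[offset:offset + patch_size])
    shared = (patch_size - offset) * scale  # overlap length in HQ samples
    region_a = out_a[offset * scale:]       # tail of A's output
    region_b = out_b[:shared]               # head of B's output
    diffs = [abs(a - b) for a, b in zip(region_a, region_b)]
    return gain * sum(diffs) / len(diffs)


random.seed(0)
signal = [random.random() for _ in range(32)]
# A random shift between the two patches, as in "random sub-patches".
print(translational_l1(signal, patch_size=16, offset=random.randint(1, 15)))
```

With a generator whose output does not depend on patch position (plain nearest-neighbour upsampling, for instance) the same function returns exactly zero, which is the self-consistency this loss rewards.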

The examples below have all of these modifications added. I've also provided a reference file that
should be closer to the original ESRGAN implementation, `train_div2k_esrgan_reference.yml`.
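
One way to read the remark that RAGAN "does not force a fixed output value" is that a relativistic loss only sees differences between real and fake discriminator outputs. The sketch below compares the two discriminator losses in plain Python. The function names are illustrative, and summing the two terms (rather than matching any particular library's scaling) is an assumption.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean(values):
    return sum(values) / len(values)


def vanilla_d_loss(real_logits, fake_logits):
    """Vanilla GAN: real logits are pushed towards label 1, fake towards 0."""
    real_term = mean([bce_with_logits(r, 1.0) for r in real_logits])
    fake_term = mean([bce_with_logits(f, 0.0) for f in fake_logits])
    return real_term + fake_term


def ragan_d_loss(real_logits, fake_logits):
    """Relativistic average GAN: each logit is judged relative to the mean
    logit of the opposite class, so only differences matter."""
    fake_mean = mean(fake_logits)
    real_mean = mean(real_logits)
    real_term = mean([bce_with_logits(r - fake_mean, 1.0) for r in real_logits])
    fake_term = mean([bce_with_logits(f - real_mean, 0.0) for f in fake_logits])
    return real_term + fake_term


real = [2.0, 1.5, 3.0]
fake = [-1.0, 0.5, -2.0]
drifted_real = [r + 50.0 for r in real]  # every output shifted by a constant
drifted_fake = [f + 50.0 for f in fake]

print("vanilla:", vanilla_d_loss(real, fake), vanilla_d_loss(drifted_real, drifted_fake))
print("ragan:  ", ragan_d_loss(real, fake), ragan_d_loss(drifted_real, drifted_fake))
```

Shifting every discriminator output by the same constant leaves the relativistic loss essentially unchanged, while the vanilla loss penalizes the drift heavily. Nothing in the relativistic objective anchors the absolute scale of the outputs, which is consistent with the behaviour described above.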

## Training ESRGAN

DLAS can train and use ESRGAN models end-to-end. These docs will show you how.

### Dataset Preparation

Start by assembling your dataset. The ESRGAN paper uses the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and
[Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) datasets. These include a small set of high-resolution
images. ESRGAN is trained on small sub-patches of those images. Generate these patches using the instructions found
in 'Generating a chunked dataset' [here](https://github.com/neonbjb/DL-Art-School/blob/gan_lab/codes/data/README.md).

Consider creating a validation set at the same time. These can just be a few medium-resolution, high-quality
images. DLAS will downsample them for you and send them through your network for validation.
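
To give a feel for what "small sub-patches" means in practice, here is a rough sketch of cutting an image into fixed-size tiles. It is not the DLAS chunking script linked above: `tile_boxes`, the 256-pixel tile size and the stride are assumptions made for illustration only.

```python
def tile_boxes(width, height, tile=256, stride=256):
    """Return (left, top, right, bottom) crop boxes that tile an image.

    Partial tiles at the right and bottom edges are dropped. A stride
    smaller than `tile` produces overlapping tiles.
    """
    boxes = []
    for top in range(0, height - tile + 1, stride):
        for left in range(0, width - tile + 1, stride):
            boxes.append((left, top, left + tile, top + tile))
    return boxes


# Example: a 2040x1356 image cut into non-overlapping 256x256 tiles.
boxes = tile_boxes(2040, 1356)
print(len(boxes), boxes[0], boxes[-1])
```

The boxes follow the usual (left, top, right, bottom) convention used by image libraries for cropping, so each tuple can be handed to a crop call directly.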

### Training the model

Use the `train_div2k_esrgan.yml` configuration file in this directory as a template to train your
ESRGAN. Search the file for `<--` to find options that will need to be adjusted for your installation.

Train with:

`python train.py -opt train_div2k_esrgan.yml`
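
Before launching the command above, it can help to list every line carrying the `<--` marker so that none is missed. The helper below is a small convenience sketch. `find_markers` is a hypothetical name, and the commented-out file path assumes you run it from this directory.

```python
def find_markers(config_text, marker="<--"):
    """Return (line_number, stripped_line) for every line containing `marker`."""
    return [
        (number, line.strip())
        for number, line in enumerate(config_text.splitlines(), start=1)
        if marker in line
    ]


sample = """\
datasets:
  train:
    paths: /content/div2k # <-- Put your path here.
    scale: 4
"""
for number, line in find_markers(sample):
    print(f"{number}: {line}")

# To scan the real file instead (path assumed relative to this directory):
# with open("train_div2k_esrgan.yml") as handle:
#     print(find_markers(handle.read()))
```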

Note that this configuration trains an RRDB network without any GAN loss for the first 100k
steps: only the L1 pixel loss is active until step 80k, when the perceptual and translational losses turn on,
and the discriminator joins at step 100k. I recommend you save the model at step 100k (this is done by default, just copy the file
out of the `experiments/train_div2k_esrgan/models` directory once it hits step 100k) so that you
do not need to repeat this training in future experiments.
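
The staging described above follows from the `after:` keys in the generator step of `train_div2k_esrgan.yml`. The sketch below tabulates which generator losses contribute at a given step. `active_losses` is a hypothetical helper, and treating a loss as active once `step >= after` is an assumption about the boundary behaviour.

```python
# `after:` thresholds copied from the generator step of train_div2k_esrgan.yml.
# A start of 0 stands for "no `after:` key", i.e. active from the first step.
LOSS_START = {
    "pix": 0,
    "feature": 80000,
    "translational": 80000,
    "gan_gen_img": 100000,
}


def active_losses(step, schedule=LOSS_START):
    """Names of the generator losses contributing at `step`.

    Assumption: a loss with `after: N` counts as active once step >= N.
    """
    return [name for name, start in schedule.items() if step >= start]


for step in (50000, 90000, 150000):
    print(step, active_losses(step))
```

The discriminator step carries its own `after: 100000`, so adversarial training as a whole only begins once the checkpoint recommended above has been written.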

## Using an ESRGAN model

### Image SR

You can apply a pre-trained ESRGAN model to a set of images using the code in `test.py`.
Documentation for this script is forthcoming, but in short: feed it your training configuration
file with the `pretrain_model_generator` option set to your trained generator, and point the
datasets section at your folder of test images in lieu of the validation set.

### Video SR

I've put together a script that strips a video into its constituent frames, applies an ESRGAN
model to each frame one at a time, then recombines the frames back into videos (without sound).
You will need to use ffmpeg to stitch the videos back together and add sound, but this is
trivial.
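
The stitching step can be sketched as two ffmpeg invocations: one that joins the video segments with the concat demuxer, and one that copies the audio track from the source video. All file names below are hypothetical (the actual names written by the script may differ), and the commands are only built here, not executed.

```python
def concat_list(segment_paths):
    """Text for ffmpeg's concat demuxer: one `file '<path>'` line per segment.

    Assumes the paths contain no single quotes.
    """
    return "".join(f"file '{path}'\n" for path in segment_paths)


def stitch_commands(list_file, source_video, silent_out, final_out):
    """Build the two ffmpeg command lines as argument lists."""
    # 1) Join the segments without re-encoding.
    concat = ["ffmpeg", "-f", "concat", "-safe", "0", "-i", list_file,
              "-c", "copy", silent_out]
    # 2) Take video from the joined file and audio from the original source.
    mux = ["ffmpeg", "-i", silent_out, "-i", source_video,
           "-map", "0:v:0", "-map", "1:a:0", "-c", "copy", "-shortest",
           final_out]
    return concat, mux


segments = ["segment_000.mp4", "segment_001.mp4"]  # hypothetical names
print(concat_list(segments), end="")
for command in stitch_commands("segments.txt", "source.mp4",
                               "upscaled_silent.mp4", "upscaled.mp4"):
    print(" ".join(command))
```

Each list can be passed to `subprocess.run(command, check=True)` once the segment list has been written to `segments.txt`. If `start_at_seconds` was not 0, the audio would also need to be offset to match, which this sketch does not handle.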

This script is called `process_video.py` and it takes a special configuration file. A sample
config is provided in `rrdb_process_video.yml` in this directory. Further documentation on this
procedure is forthcoming.
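
Until that documentation exists, the segmenting options in the sample config can be reasoned about with simple arithmetic. The sketch below uses the `frame_rate` and `frames_per_mini_vid` values from `rrdb_process_video.yml`. `mini_video_plan` is a hypothetical helper, and it assumes a final partial segment is still written out.

```python
import math


def mini_video_plan(duration_seconds, frame_rate=30, frames_per_mini_vid=360):
    """Return (total_frames, segment_count, seconds_per_full_segment)."""
    total_frames = int(duration_seconds * frame_rate)
    # Assumption: leftover frames form one last, shorter segment.
    segment_count = math.ceil(total_frames / frames_per_mini_vid)
    seconds_per_segment = frames_per_mini_vid / frame_rate
    return total_frames, segment_count, seconds_per_segment


print(mini_video_plan(60))  # -> (1800, 5, 12.0)
```

With the sample values, each mini video covers 12 seconds of footage, so at most 360 upscaled frames need to sit on disk at any one time.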

Fun fact: the foundations of DLAS lie in the (now defunct) MMSR GitHub repo, which was
primarily an implementation of ESRGAN.
69
recipes/esrgan/rrdb_process_video.yml
Normal file
@ -0,0 +1,69 @@
name: video_process
suffix: ~ # add suffix to saved images
model: extensibletrainer
scale: 4
gpu_ids: [0]
fp16: true
minivid_crf: 12 # Defines the 'crf' output video quality parameter fed to FFMPEG
frames_per_mini_vid: 360 # How many frames to process before generating a small video segment. Used to reduce number of images you must store to convert an entire video.
minivid_start_no: 360
recurrent_mode: false

dataset:
  n_workers: 1
  name: myvideo
  video_file: <your path> # <-- Path to your video file here. any format supported by ffmpeg works.
  frame_rate: 30 # Set to the frame rate of your video.
  start_at_seconds: 0 # Set this if you want to start somewhere other than the beginning of the video.
  end_at_seconds: 5000 # Set to the time you want to stop at.
  batch_size: 1 # Set to the number of frames to convert at once. Larger batches provide a modest performance increase.
  vertical_splits: 1 # Used for 3d binocular videos. Leave at 1.
  force_multiple: 1

#### network structures
networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3

#### path
path:
  pretrain_model_generator: <your path> # <-- Set your generator path here.

steps:
  generator:
    training: generator
    generator: generator

    # Optimizer params. Not used, but currently required to initialize ExtensibleTrainer, even in eval mode.
    lr: !!float 5e-6
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99

    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: lq
        out: gen

# Train section is required, even though we are just evaluating.
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 500
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [20000, 40000, 80000, 100000, 140000, 180000]
  lr_gamma: 0.5

eval:
  output_state: gen
179
recipes/esrgan/train_div2k_esrgan.yml
Normal file
@ -0,0 +1,179 @@
name: train_div2k_esrgan
model: extensibletrainer
scale: 4
gpu_ids: [0]
fp16: false
start_step: -1
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
use_tb_logger: true
wandb: false

datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14
    scale: 4

networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3

  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128_gn
    scale: 2
    nf: 64
    in_nc: 3
    image_size: 96

#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state # <-- Set this to resume from a previous training state.

steps:

  feature_discriminator:
    training: feature_discriminator
    after: 100000 # Discriminator doesn't "turn-on" until step 100k to allow generator to anneal on PSNR loss.

    # Optimizer params
    lr: !!float 2e-4
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99

    injectors:
      # "image_patch" injectors support the translational loss below. You can remove them if you remove that loss.
      plq:
        type: image_patch
        patch_size: 24
        in: lq
        out: plq
      phq:
        type: image_patch
        patch_size: 96
        in: hq
        out: phq
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: plq
        out: dgen

    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: gan
        weight: 1
        #min_loss: .4
        noise: .004
        gradient_penalty: true
        real: phq
        fake: dgen

  generator:
    training: generator

    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      pglq:
        type: image_patch
        patch_size: 24
        in: lq
        out: pglq
      pghq:
        type: image_patch
        patch_size: 96
        in: hq
        out: pghq
      gen_inj:
        type: generator
        generator: generator
        in: pglq
        out: gen

    losses:
      pix:
        type: pix
        weight: .05
        criterion: l1
        real: pghq
        fake: gen
      feature:
        type: feature
        after: 80000 # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: pghq
        fake: gen
      gan_gen_img:
        after: 100000
        type: generator_gan
        gan_type: gan
        weight: .02
        noise: .004
        discriminator: feature_discriminator
        fake: gen
        real: pghq
      # Translational loss <- not present in the original ESRGAN paper, but I find it reduces artifacts from the GAN.
      # Feel free to remove. The network will still train well.
      translational:
        type: translational
        after: 80000
        weight: 2
        criterion: l1
        generator: generator
        generator_output_index: 0
        detach_fake: false
        patch_size: 96
        overlap: 64
        real: gen
        fake: ['pglq']

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000

  # LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [140000, 180000, 200000, 240000] # LR is halved at these steps. Don't do it until GAN is online.
  lr_gamma: 0.5

eval:
  output_state: gen

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [gen, hq, pglq, pghq]
  visual_debug_rate: 100
142
recipes/esrgan/train_div2k_esrgan_reference.yml
Normal file
@ -0,0 +1,142 @@
# This is a config file that trains ESRGAN using the dynamics spelled out in the paper with no modifications.
# This has not been trained to completion in some time. I make no guarantees that it will work well.

name: train_div2k_esrgan_reference
model: extensibletrainer
scale: 4
gpu_ids: [0]
fp16: false
start_step: -1
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
use_tb_logger: true
wandb: false

datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14
    scale: 4

networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3

  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128
    scale: 2
    nf: 64
    in_nc: 3

#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state # <-- Set this to resume from a previous training state.

steps:

  feature_discriminator:
    training: feature_discriminator
    after: 100000 # Discriminator doesn't "turn-on" until step 100k to allow generator to anneal on PSNR loss.

    # Optimizer params
    lr: !!float 2e-4
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99

    injectors:
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: lq
        out: dgen

    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: ragan
        weight: 1
        real: hq
        fake: dgen

  generator:
    training: generator

    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: lq
        out: gen

    losses:
      pix:
        type: pix
        weight: .05
        criterion: l1
        real: hq
        fake: gen
      feature:
        type: feature
        after: 80000 # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: hq
        fake: gen
      gan_gen_img:
        after: 100000
        type: generator_gan
        gan_type: ragan
        weight: .02
        discriminator: feature_discriminator
        fake: gen
        real: hq

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000

  # LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [140000, 180000, 200000, 240000] # LR is halved at these steps. Don't do it until GAN is online.
  lr_gamma: 0.5

eval:
  output_state: gen

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [gen, hq, lq]
  visual_debug_rate: 100
@ -1,12 +1,11 @@
-#### general settings
 name: train_div2k_srflow
-use_tb_logger: true
 model: extensibletrainer
 scale: 4
 gpu_ids: [0]
 fp16: false
 start_step: -1
-checkpointing_enabled: true
+checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
+use_tb_logger: true
 wandb: false
 
 datasets: