# DL-Art-School/recipes/esrgan/train_div2k_esrgan_reference.yml
# 2020-12-20 11:50:31 -07:00
#
# 142 lines
# 3.0 KiB
# YAML

# This is a config file that trains ESRGAN using the dynamics spelled out in the paper with no modifications.
# This has not been trained to completion in some time. I make no guarantees that it will work well.
# --- Experiment identity and trainer selection ---
# The experiment name matches this recipe's file name.
name: train_div2k_esrgan_reference
model: extensibletrainer
scale: 4

# --- Devices and precision ---
gpu_ids:
  - 0
fp16: false
# Gradient checkpointing. Enable for huge GPU memory savings.
# Disable for distributed training.
checkpointing_enabled: true

# --- Resume position ---
# NOTE(review): -1 presumably means "no explicit starting step" - confirm
# against the trainer.
start_step: -1

# --- Logging backends ---
use_tb_logger: true
wandb: false
# Dataset definitions: one training set and one validation set.
# Every key below belongs to exactly one of the two entries; without the
# nesting, `name`, `mode` and `scale` would be duplicate top-level keys.
datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k  # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14  # <-- Put your validation set path here.
    scale: 4
# Network definitions. The keys `generator` and `feature_discriminator` are the
# names the `steps` section refers to (`training:`, `generator:`,
# `discriminator:`), so they must stay spelled exactly as they are here.
networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3
  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128
    # NOTE(review): differs from the generator's `scale: 4` - confirm this is
    # intentional for the discriminator.
    scale: 2
    nf: 64
    in_nc: 3
#### path
# Checkpoint loading options. The two commented-out keys are optional; when
# uncommented they must stay indented under `path`.
path:
  # pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  # Set this to resume from a previous training state:
  # resume_state: ../experiments/train_div2k_esrgan/training_state/0.state
# Training steps. Each step names the network it trains (`training:`), its
# optimizer settings, the injectors that produce intermediate state, and the
# losses computed on that state.
steps:
  feature_discriminator:
    training: feature_discriminator
    # Discriminator doesn't "turn-on" until step 100k to allow generator to
    # anneal on PSNR loss.
    after: 100000

    # Optimizer params
    # NOTE(review): these are inlined in the step here, whereas the generator
    # step nests the same keys under `optimizer_params` - confirm the trainer
    # accepts both layouts.
    # The explicit !!float tag is kept because some YAML 1.1 loaders would
    # otherwise read a bare `2e-4` as a string.
    lr: !!float 2e-4
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99

    injectors:
      # Runs the generator on `lq` with `grad: false` and stores the result as
      # `dgen`, which is the "fake" input of the discriminator loss below.
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: lq
        out: dgen
    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: ragan
        weight: 1
        real: hq
        fake: dgen

  generator:
    training: generator

    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # Runs the generator on `lq` and stores the result as `gen`, which every
      # generator loss below uses as its "fake" input.
      gen_inj:
        type: generator
        generator: generator
        in: lq
        out: gen
    losses:
      pix:
        type: pix
        weight: 0.05
        criterion: l1
        real: hq
        fake: gen
      feature:
        type: feature
        after: 80000  # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: hq
        fake: gen
      gan_gen_img:
        # Same step at which the feature_discriminator step turns on.
        after: 100000
        type: generator_gan
        gan_type: ragan
        weight: 0.02
        discriminator: feature_discriminator
        fake: gen
        real: hq
# Global training schedule.
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000

  # LR scheduler options
  default_lr_scheme: MultiStepLR
  # LR is halved at these steps (see lr_gamma). Don't do it until GAN is
  # online; the GAN losses in `steps` turn on at step 100000.
  gen_lr_steps: [140000, 180000, 200000, 240000]
  lr_gamma: 0.5
# Evaluation settings. `gen` is the state key written by the generator step's
# `gen_inj` injector.
eval:
  output_state: gen
# Logging and checkpoint cadence.
# NOTE(review): the frequencies are presumably measured in training steps -
# confirm against the trainer.
logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  # State keys to dump as debug images: generator output, ground truth, input.
  visuals: [gen, hq, lq]
  visual_debug_rate: 100