# DL-Art-School/recipes/esrgan/train_div2k_esrgan.yml
# 2020-12-20 11:50:31 -07:00

# 179 lines
# 3.9 KiB
# YAML

# Top-level run settings.
name: train_div2k_esrgan
model: extensibletrainer
scale: 4
gpu_ids:
  - 0
fp16: false
start_step: -1
# Gradient checkpointing. Enable for huge GPU memory savings. Disable for
# distributed training.
checkpointing_enabled: true
use_tb_logger: true
wandb: false
# Dataset definitions: one nested mapping per split (`train`, `val`).
# Each split carries its own `name`, `mode` and `scale`, so these keys must
# stay nested under the split rather than sit at the document root.
datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k  # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14
    scale: 4
# Network definitions, keyed by the name that the `steps` section refers to
# (`generator`, `feature_discriminator`).
networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3

  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128_gn
    # NOTE(review): this `scale` is 2 while every other `scale` in the file is
    # 4 - presumably a discriminator-specific setting; confirm against the
    # model definition.
    scale: 2
    nf: 64
    in_nc: 3
    image_size: 96
#### path
# Model/state loading options. The commented-out keys are optional; uncomment
# and fill them in to use them.
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state  # <-- Set this to resume from a previous training state.
# Training steps, one per trained network. Each step names the network it
# trains (`training`), its optimizer settings, the `injectors` that produce
# named intermediate values, and the `losses` computed from those values.
steps:

  feature_discriminator:
    training: feature_discriminator
    after: 100000  # Discriminator doesn't "turn-on" until step 100k to allow generator to anneal on PSNR loss.

    # Optimizer params
    # NOTE(review): these sit directly on the step, whereas the `generator`
    # step nests the same keys under `optimizer_params` - confirm the trainer
    # accepts both layouts.
    # The explicit !!float tag forces a float; some YAML 1.1 loaders would
    # otherwise read `2e-4` (no decimal point) as a string.
    lr: !!float 2e-4
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99

    injectors:
      # "image_patch" injectors support the translational loss below. You can remove them if you remove that loss.
      plq:
        type: image_patch
        patch_size: 24
        in: lq
        out: plq
      phq:
        type: image_patch
        patch_size: 96
        in: hq
        out: phq
      # Runs the generator on the low-quality patch with `grad: false`; its
      # output `dgen` is the `fake` input of the discriminator loss below.
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: plq
        out: dgen

    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: gan
        weight: 1
        #min_loss: .4
        noise: .004
        gradient_penalty: true
        real: phq
        fake: dgen

  generator:
    training: generator

    optimizer_params:
      # The explicit !!float tag forces a float; some YAML 1.1 loaders would
      # otherwise read `2e-4` (no decimal point) as a string.
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      pglq:
        type: image_patch
        patch_size: 24
        in: lq
        out: pglq
      pghq:
        type: image_patch
        patch_size: 96
        in: hq
        out: pghq
      gen_inj:
        type: generator
        generator: generator
        in: pglq
        out: gen

    losses:
      pix:
        type: pix
        weight: .05
        criterion: l1
        real: pghq
        fake: gen
      feature:
        type: feature
        after: 80000  # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: pghq
        fake: gen
      gan_gen_img:
        # Same step count as the `after` of the feature_discriminator step.
        after: 100000
        type: generator_gan
        gan_type: gan
        weight: .02
        noise: .004
        discriminator: feature_discriminator
        fake: gen
        real: pghq
      # Translational loss <- not present in the original ESRGAN paper, but I find it reduces artifacts from the GAN.
      # Feel free to remove. The network will still train well.
      translational:
        type: translational
        after: 80000
        weight: 2
        criterion: l1
        generator: generator
        generator_output_index: 0
        detach_fake: false
        patch_size: 96
        overlap: 64
        real: gen
        fake: ['pglq']
# Global training-loop and learning-rate schedule settings.
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000

  # LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [140000, 180000, 200000, 240000]  # LR is halved at these steps. Don't do it until GAN is online.
  lr_gamma: 0.5
# Evaluation settings. `gen` is the name of the generator step's `gen_inj`
# output.
eval:
  output_state: gen
# Logging, checkpointing and debug-visual settings. The names in `visuals`
# are `hq` plus outputs of the generator step's injectors (`gen`, `pglq`,
# `pghq`).
logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [gen, hq, pglq, pghq]
  visual_debug_rate: 100