# This is a config file that trains ESRGAN using the dynamics spelled out in the paper with no modifications.
# This has not been trained to completion in some time. I make no guarantees that it will work well.

# ---- Experiment identity ----
name: train_div2k_esrgan_reference
model: extensibletrainer
scale: 4
start_step: -1

# ---- Hardware / precision ----
gpu_ids:
  - 0
fp16: false
# Gradient checkpointing: enable for large GPU memory savings; disable for distributed training.
checkpointing_enabled: true

# ---- Logging backends ----
use_tb_logger: true
wandb: false

datasets:
  # Training split.
  train:
    name: div2k
    mode: single_image_extensible
    # Replace with the location of your own copy of the training images.
    paths: /content/div2k
    scale: 4
    target_size: 128
    force_multiple: 1
    strict: false
    # Data-loading settings.
    n_workers: 2
    batch_size: 16
  # Validation split.
  val:
    name: val
    mode: fullimage
    scale: 4
    dataroot_GT: /content/set14

networks:
  # Referenced by name ("generator") from the injectors and the generator step under `steps`.
  generator:
    type: generator
    which_model_G: RRDBNet
    scale: 4
    in_nc: 3
    out_nc: 3
    nf: 64
    nb: 23
    initial_stride: 1
    blocks_per_checkpoint: 3

  # Referenced by name from the discriminator step and from the generator's GAN loss under `steps`.
  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128
    in_nc: 3
    nf: 64
    scale: 2

#### path
path:
  # Optional: uncomment the next line and fill in a path to start from a pretrained generator.
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  # Optional: uncomment the next line to resume from a previous training state.
  # NOTE(review): the example path uses the experiment name "train_div2k_esrgan", while this
  # config's `name` is "train_div2k_esrgan_reference" - adjust the directory to match your run.
  #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state

# Training steps. Each step names the network it trains (`training`), the injectors that
# produce intermediate state, and the losses computed from that state.
# The order of the steps, injectors and losses below is kept as originally written.
steps:

  feature_discriminator:
    training: feature_discriminator
    after: 100000  # Discriminator doesn't "turn on" until step 100k to allow generator to anneal on PSNR loss.

    # Optimizer hyperparameters. These are nested under `optimizer_params`, the same key the
    # generator step uses; previously they sat directly on the step, inconsistent with it.
    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # Runs the generator on `lq` with `grad: false` and stores the result as `dgen`,
      # which the discriminator loss below consumes as its `fake` input.
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: lq
        out: dgen

    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: ragan
        weight: 1
        real: hq
        fake: dgen

  generator:
    training: generator

    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # Runs the generator on `lq` and stores the result as `gen`, the `fake` input of every loss below.
      gen_inj:
        type: generator
        generator: generator
        in: lq
        out: gen

    losses:
      # Pixel loss: active from the start (no `after` gate).
      pix:
        type: pix
        weight: 0.05
        criterion: l1
        real: hq
        fake: gen
      feature:
        type: feature
        after: 80000  # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: hq
        fake: gen
      # GAN loss: gated at the same step (100k) as the discriminator step above.
      gan_gen_img:
        type: generator_gan
        after: 100000
        gan_type: ragan
        weight: 0.02
        discriminator: feature_discriminator
        fake: gen
        real: hq

train:
  # Run length and cadence.
  niter: 500000
  val_freq: 2000
  warmup_iter: -1
  mega_batch_factor: 1

  # Learning-rate schedule: the LR is halved (lr_gamma: 0.5) at each step listed in gen_lr_steps.
  # Don't start halving until the GAN is online.
  default_lr_scheme: MultiStepLR
  lr_gamma: 0.5
  gen_lr_steps:
    - 140000
    - 180000
    - 200000
    - 240000

eval:
  # `gen` is the name the generator step's injector writes its output to (see `steps`).
  output_state: gen

logger:
  # Frequencies for console output and checkpoint saving.
  print_freq: 30
  save_checkpoint_freq: 1000
  # Debug visuals: which named states to dump, and how often.
  visual_debug_rate: 100
  visuals:
    - gen
    - hq
    - lq