# Top-level run settings for the DIV2K ESRGAN training recipe.
name: train_div2k_esrgan
model: extensibletrainer
# Same factor as the `scale` entries under `datasets` and `networks.generator`.
scale: 4
gpu_ids:
  - 0
fp16: false
start_step: -1
# Gradient checkpointing. Enable for large GPU memory savings; disable for
# distributed training.
checkpointing_enabled: true
# Logging toggles (presumably TensorBoard and Weights & Biases - confirm).
use_tb_logger: true
wandb: false

# Training and validation data sources. Both use the same `scale: 4` as the
# top-level setting.
datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    # <-- Put your own training image path here.
    paths: '/content/div2k'
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    # Validation images; replace with your own path as needed.
    dataroot_GT: '/content/set14'
    scale: 4

# Network definitions. The keys below (`generator`, `feature_discriminator`)
# are the names the `steps` section refers to.
networks:
  # Referenced by `generator: generator` in the step injectors and by
  # `training: generator` in the generator step.
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3

  # Referenced by `training: feature_discriminator` in the discriminator step
  # and by `discriminator: feature_discriminator` in the `gan_gen_img` loss.
  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128_gn
    scale: 2
    nf: 64
    in_nc: 3
    # Same value as the `patch_size: 96` HQ patches cut in the steps section.
    image_size: 96

#### path
# Checkpoint loading options. The two path entries ship commented out;
# uncomment and fill them in to use them.
path:
  # Optional pretrained generator weights.
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  # Optional saved training state to resume from.
  #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state   # <-- Set this to resume from a previous training state.

# Training steps. Each step names the network it trains (`training`), the
# injectors that build its inputs, and the losses computed on them.
steps:

  # Discriminator step: compares real HQ patches (`phq`) against generator
  # output produced by `dgen_inj` with `grad: false`.
  feature_discriminator:
    training: feature_discriminator
    after: 100000  # Discriminator doesn't "turn on" until step 100k to allow generator to anneal on PSNR loss.

    # Optimizer settings are nested under `optimizer_params`, the same layout
    # the generator step uses. The explicit !!float tag is needed because
    # YAML 1.1 loaders such as PyYAML read a bare `2e-4` (no decimal point)
    # as a string.
    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # "image_patch" injectors cut patches from `lq` and `hq`; the patch sizes
      # (24 vs 96) differ by the 4x scale factor. They mirror the patches the
      # generator step uses for its translational loss. If you remove them,
      # repoint `dgen_inj.in` and the loss's `real` input accordingly.
      plq:
        type: image_patch
        patch_size: 24
        in: lq
        out: plq
      phq:
        type: image_patch
        patch_size: 96
        in: hq
        out: phq
      # Generator pass on the LQ patch; its output `dgen` is the `fake` input
      # of the loss below.
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: plq
        out: dgen

    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: gan
        weight: 1
        #min_loss: .4
        noise: 0.004
        gradient_penalty: true
        real: phq
        fake: dgen

  # Generator step: pixel loss from the start, feature and translational
  # losses from step 80k, GAN loss from step 100k (the same step at which the
  # discriminator step above starts).
  generator:
    training: generator

    optimizer_params:
      # Explicit !!float tag: see the note in the discriminator step above.
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # LQ / HQ patches for this step (24 vs 96, i.e. the 4x scale factor).
      pglq:
        type: image_patch
        patch_size: 24
        in: lq
        out: pglq
      pghq:
        type: image_patch
        patch_size: 96
        in: hq
        out: pghq
      # Generator pass on the LQ patch; `gen` is the `fake` input of the losses.
      gen_inj:
        type: generator
        generator: generator
        in: pglq
        out: gen

    losses:
      pix:
        type: pix
        weight: 0.05
        criterion: l1
        real: pghq
        fake: gen
      feature:
        type: feature
        after: 80000  # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: pghq
        fake: gen
      gan_gen_img:
        after: 100000
        type: generator_gan
        gan_type: gan
        weight: 0.02
        noise: 0.004
        discriminator: feature_discriminator
        fake: gen
        real: pghq
      # Translational loss <- not present in the original ESRGAN paper, but I find it reduces artifacts from the GAN.
      # Feel free to remove. The network will still train well.
      translational:
        type: translational
        after: 80000
        weight: 2
        criterion: l1
        generator: generator
        generator_output_index: 0
        detach_fake: false
        patch_size: 96
        overlap: 64
        real: gen
        fake: [pglq]

# Global training schedule.
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000

  # LR scheduler options.
  default_lr_scheme: MultiStepLR
  # LR is halved (see `lr_gamma`) at each of these steps. Keep them after the
  # GAN comes online (step 100000 in the `steps` section).
  gen_lr_steps:
    - 140000
    - 180000
    - 200000
    - 240000
  lr_gamma: 0.5

# Evaluation settings. `gen` is the name the generator step's `gen_inj`
# injector writes its output to.
eval:
  output_state: gen

# Logging and checkpoint cadence.
logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  # `gen`, `pglq` and `pghq` are outputs of the generator step's injectors;
  # `hq` is an input those injectors read. Presumably these are dumped as
  # debug images at `visual_debug_rate` - confirm against the trainer.
  visuals:
    - gen
    - hq
    - pglq
    - pghq
  visual_debug_rate: 100