# Experiment-level settings for the SRFlow DIV2K training run.
# The experiment name also appears in the commented-out resume_state path under `path:`
# (../experiments/train_div2k_srflow/...), so keep the two consistent if you rename the run.
name: train_div2k_srflow
model: extensibletrainer
# Presumably the super-resolution upscaling factor. The same value (4) is repeated in
# datasets.train.scale, networks.generator.scale and the evaluator dataset.
scale: 4
gpu_ids: [0]
fp16: false
# NOTE(review): -1 looks like a "no explicit start step" sentinel - confirm against the trainer.
start_step: -1
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
# Logging back-end toggles (presumably TensorBoard and Weights & Biases - confirm against the trainer).
use_tb_logger: true
wandb: false

# Training data. Only a `train` entry is defined under `datasets`.
datasets:
  train:
    n_workers: 4
    batch_size: 32
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k   # <-- Put your path here.
    # Same value as networks.generator.flow.patch_size (160).
    # NOTE(review): presumably these two must stay in sync - confirm before changing either.
    target_size: 160
    force_multiple: 1
    # Same value as the top-level `scale` and networks.generator.scale.
    scale: 4
    # NOTE(review): 0 presumably means no synthetic corruptions are applied - confirm against the dataset code.
    num_corrupts_per_image: 0

# Network definitions. A single network named `generator` is declared; the training step,
# both injectors and the evaluator elsewhere in this file refer to it by that name.
networks:
  generator:
    type: generator
    which_model_G: srflow
    nf: 64
    nb: 23
    # NOTE(review): K is also set (to the same value) under `flow:` below - confirm which one the model reads.
    K: 16
    scale: 4
    initial_stride: 1
    flow_scale: 4
    train_RRDB: false    # <-- Start false. After some time, ~20k-50k steps, set to true. TODO: automate this.
    pretrain_rrdb: ../experiments/pretrained_rrdb.pth  # <-- Insert path to your pretrained RRDB here.

    # Options for the flow part of the model. Their meaning is defined by the model code, not shown here.
    flow:
      # Same value as datasets.train.target_size (160).
      patch_size: 160
      K: 16
      L: 3
      noInitialInj: true
      coupling: CondAffineSeparatedAndCond
      additionalFlowNoAffine: 2
      split:
        enable: true
      fea_up0: true
      # The hyphen is part of the key name (`fea_up-1`); it is not a negative value.
      fea_up-1: true
      stackRRDB:
        # NOTE(review): presumably indices into the `nb: 23` RRDB blocks above - confirm indexing base.
        blocks: [ 1, 8, 15, 22 ]
        concat: true
      gaussian_loss_weight: 1

#### path
# Checkpoint / resume locations. Both path entries below are commented out, so as written
# only `strict_load` is active. Uncomment and edit a line to enable it.
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_div2k_srflow/training_state/0.state   # <-- Set this to resume from a previous training state.

# Training steps. One step, `generator`, is defined; it trains the network of the same name
# declared under `networks:`.
steps:
  generator:
    training: generator

    optimizer_params:
      # Keep the explicit !!float tag: YAML 1.1 loaders such as PyYAML only resolve exponent
      # notation as a float when the mantissa contains a dot, so a bare 2e-4 would load as a string.
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    # Each injector calls the `generator` network with the positional inputs listed in `in`
    # and stores the results under the names listed in `out`. The meaning of each position
    # is defined by the generator's call signature, which is not shown in this file.
    injectors:
      # Produces `nll`, which the log_likelihood loss below consumes.
      z_inj:
        type: generator
        generator: generator
        # `None` is NOT YAML null: it loads as the plain string "None". Keep it spelled exactly
        # like this (do not change it to null/~). NOTE(review): presumably the trainer maps the
        # string "None" to "no argument" - confirm against the injector code.
        in: [hq, lq, None, None, false]
        out: [z, nll]
      # This is computed solely for visual_dbg - that is, to see what your model is actually doing.
      # `every: 50` matches logger.visual_debug_rate, and `gen` is listed in logger.visuals.
      gen_inj:
        every: 50
        type: generator
        generator: generator
        in: [None, lq, None, 0.4, true]
        out: [gen]

    losses:
      # Uses the `nll` value emitted by z_inj directly as the loss term.
      log_likelihood:
        type: direct
        key: nll
        weight: 1

# Training-loop schedule.
train:
  niter: 500000
  # NOTE(review): -1 presumably disables LR warmup - confirm against the trainer.
  warmup_iter: -1
  mega_batch_factor: 1    # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
  val_freq: 1000

  # Default LR scheduler options
  # All milestones in gen_lr_steps fall well before niter (500000).
  # NOTE(review): presumably the LR is multiplied by lr_gamma at each listed step - confirm.
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [20000, 40000, 80000, 100000, 140000, 180000]
  lr_gamma: 0.5

# Evaluation settings. One evaluator, `gaussian`, is configured; it targets the `generator` network.
eval:
  evaluators:
    # This is the best metric I have come up with for monitoring the training progress of srflow networks. You should
    # feed this evaluator a random set of images from your target distribution.
    gaussian:
      for: generator
      type: flownet_gaussian
      batch_size: 2
      dataset:
        paths: /content/random_100_images   # <-- Put your path here.
        target_size: 512
        force_multiple: 1
        scale: 4
        eval: false
        num_corrupts_per_image: 0
        corruption_blur_scale: 1
  # NOTE(review): no injector in this file writes a state key named `eval_gen` (gen_inj writes
  # `gen`). Confirm that this key is produced elsewhere or that output_state is unused here.
  output_state: eval_gen

# Logging, checkpointing and debug-visual cadence. NOTE(review): units are presumably training steps - confirm.
logger:
  print_freq: 30
  save_checkpoint_freq: 500
  # `gen` is the output of the gen_inj injector; `hq` and `lq` are the names used as injector inputs.
  visuals: [gen, hq, lq]
  # Same value as steps.generator.injectors.gen_inj.every (50).
  visual_debug_rate: 50