#### general settings
# Run name. The commented-out `resume_state` example under `path:` points at
# ../experiments/train_div2k_byol/, so the experiment directory presumably
# follows this name -- confirm against the trainer before renaming.
name: train_div2k_byol
use_tb_logger: true
model: extensibletrainer
distortion: sr
# Same value as datasets.train.dataset.scale.
scale: 1
# GPU id list; only id 0 is listed here.
gpu_ids:
  - 0
fp16: false
start_step: 0
# Highly recommended for single-GPU training. Will not work with DDP.
checkpointing_enabled: true
wandb: false

# Training data. The outer `byol_dataset` entry nests an inner `dataset`
# mapping (an `imagefolder` reading from `paths`).
datasets:
  train:
    n_workers: 4
    batch_size: 32
    mode: byol_dataset
    # NOTE(review): crop_size, dataset.target_size below and
    # networks.generator.image_size are all 256 -- presumably they have to
    # agree; confirm before changing only one of them.
    crop_size: 256
    normalize: true
    # Inner dataset definition nested under the byol_dataset entry.
    dataset:
      mode: imagefolder
      paths: /content/div2k   # <-- Put your path here. Note: full images.
      target_size: 256
      # Same value as the top-level `scale`.
      scale: 1

# Network definitions. `generator` is the name referenced by
# steps.generator.training and by the gen_inj injector.
networks:
  generator:
    type: generator
    which_model_G: byol
    # NOTE(review): same value (256) as datasets.train.crop_size and
    # datasets.train.dataset.target_size -- presumably these must agree.
    image_size: 256
    # The network being pretrained, nested inside the byol model entry.
    subnet:  # <-- Specify your own network to pretrain here.
      which_model_G: spinenet
      arch: 49
      use_input_norm: true

    # Names a layer inside the subnet above; the value given here is
    # presumably specific to the spinenet subnet, so update it whenever the
    # subnet is swapped.
    hidden_layer: endpoint_convs.4.conv  # <-- Specify a hidden layer from your network here.

#### path
# Both path-valued keys below are commented out, so as written `strict_load`
# is the only active setting in this section.
path:
  # Optional: uncomment and fill in to start from pretrained generator weights.
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  # The example path below contains the run name from the top-level `name`
  # key (train_div2k_byol); adjust it if the run is renamed.
  #resume_state: ../experiments/train_div2k_byol/training_state/0.state   # <-- Set this to resume from a previous training state.

steps:
  generator:
    # Same name as the `generator` entry under `networks:`.
    training: generator

    # Optimizer params
    optimizer_params:
      # The explicit !!float tag forces a float: some YAML 1.1 loaders
      # (e.g. PyYAML) read a bare `3e-4` -- no decimal point -- as a string.
      lr: !!float 3e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      gen_inj:
        type: generator
        # Same name as the `generator` entry under `networks:`.
        generator: generator
        # aug1 / aug2 also appear in logger.visuals.
        in:
          - aug1
          - aug2
        # `loss` is the key read by byol_loss below and by eval.output_state.
        out: loss

    losses:
      byol_loss:
        type: direct
        # Reads the `loss` value written by gen_inj above.
        key: loss
        weight: 1

train:
  niter: 500000
  warmup_iter: -1
  # Gradient accumulation factor. If you are running OOM, increase this to
  # [2,4,8].
  mega_batch_factor: 1
  val_freq: 2000

  # Default LR scheduler options
  default_lr_scheme: MultiStepLR
  # All four milestones fall before niter (500000).
  # NOTE(review): presumably the LR is multiplied by lr_gamma at each
  # milestone, as in torch's MultiStepLR -- confirm against the trainer.
  gen_lr_steps:
    - 50000
    - 100000
    - 150000
    - 200000
  lr_gamma: 0.5

# `loss` is the same name the gen_inj injector writes (`out: loss`) and the
# byol_loss entry reads (`key: loss`) under `steps:`.
eval:
  output_state: loss

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  # aug1 and aug2 are the same names fed to the gen_inj injector under
  # `steps:`.
  visuals:
    - hq
    - lq
    - aug1
    - aug2
  visual_debug_rate: 100