#### general settings
# Top-level run settings. These comments describe only what this file shows;
# the exact effect of each key is defined by the trainer that consumes it.
name: train_byol_segformer
use_tb_logger: true
model: extensibletrainer
distortion: sr
# scale is 1 here and again under datasets.train.dataset; presumably the two
# values are meant to agree — confirm against the trainer.
scale: 1
gpu_ids: [0]
fp16: false
# NOTE(review): -1 looks like a "not set" sentinel — confirm.
start_step: -1
checkpointing_enabled: false
wandb: false

# Training data. The outer `train` entry (mode: byol_dataset) nests a
# `dataset` entry (mode: imagefolder).
datasets:
  train:
    n_workers: 1
    batch_size: 96
    mode: byol_dataset
    # Same value (224) as dataset.target_size below and
    # networks.generator.image_size.
    crop_size: 224
    # Both keys name the same entry, `hq`.
    key1: hq
    key2: hq
    dataset:
      mode: imagefolder
      # `<>` is presumably a placeholder: replace it with the real training
      # image path(s) before running.
      paths: '<>'
      target_size: 224
      scale: 1
      fetch_alt_image: false
      skip_lq: true
      normalize: imagenet

# Network definitions: a single `generator` entry of model type
# `pixel_local_byol` whose `subnet` is a `segformer`.
networks:
  generator:
    type: generator
    which_model_G: pixel_local_byol
    # Same value (224) as datasets.train.crop_size.
    image_size: 224
    # NOTE(review): presumably selects which layer of the subnet is used;
    # confirm against the pixel_local_byol implementation.
    hidden_layer: tail
    subnet:
      which_model_G: segformer

#### path
path:
  strict_load: true
  # Disabled by default. Presumably used to resume from a saved training
  # state: uncomment and replace <> with its path — confirm with the trainer.
  # resume_state: <>

steps:
  generator:
    training: generator
    optimizer: lars
    # All optimizer parameters come from appendix J of the BYOL paper.
    optimizer_params:
      # BYOL rule: LR = 0.2 * <batch_size> / 256. With batch_size 96 that
      # gives 0.075; the value configured here is 0.08.
      lr: 0.08
      # The explicit !!float tag forces a float regardless of how the loader
      # resolves exponent notation.
      weight_decay: !!float 1.5e-6
      lars_coefficient: 0.001
      momentum: 0.9

    # `gen_inj` takes `aug1` as its input and writes its output under `loss`.
    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: aug1
        out: loss

    # `byol_loss` reads the `loss` key written by the injector above and
    # applies a weight of 1.
    losses:
      byol_loss:
        type: direct
        key: loss
        weight: 1

train:
  warmup_iter: -1
  mega_batch_factor: 2
  val_freq: 1000
  niter: 300000

  # Default LR scheduler options.
  default_lr_scheme: CosineAnnealingLR_Restart
  T_period: [120000, 120000, 120000]
  warmup: 10000
  # eta_min is not specified by the BYOL paper.
  eta_min: 0.01
  # The paper uses no restarts, but this scheduler adds them automatically
  # when they are not set. Training is unlikely to run far enough to reach
  # them.
  # NOTE(review): restarts are 140000 apart while each T_period entry is
  # 120000 — confirm the offset is intended.
  restarts: [140000, 280000]
  restart_weights: [0.5, 0.25]


# Evaluation: one evaluator, attached to the `generator` network.
eval:
  output_state: loss
  evaluators:
    single_point_pair_contrastive_eval:
      for: generator
      type: single_point_pair_contrastive_eval
      batch_size: 16
      quantity: 96
      # `<>` in the two `path` values below is presumably a placeholder:
      # replace each with the real location of that image set before running.
      similar_set_args:
        path: '<>'
        size: 256
      dissimilar_set_args:
        path: '<>'
        size: 256

# Logging settings.
# NOTE(review): the *_freq / *_rate values are presumably counted in training
# steps — confirm against the trainer.
logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  # `hq` is the key named by datasets.train.key1/key2; `aug1` is the input of
  # the `gen_inj` injector under steps.generator.
  visuals: [hq, aug1]
  visual_debug_rate: 100