#### general settings
name: train_imageset_byol
use_tb_logger: true
model: extensibletrainer
scale: 1
gpu_ids: [0]
fp16: false
start_step: 0
checkpointing_enabled: true  # <-- Highly recommended for single-GPU training. May not work in distributed settings.
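# (Assumption based on the option name: this enables activation/gradient checkpointing, which recomputes
#  intermediate activations during the backward pass, trading some extra compute for a large memory saving.)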
wandb: false

datasets:
  train:
    n_workers: 4
    batch_size: 256    # <-- BYOL trains on very large batch sizes. 256 was the smallest batch size that did not cause
                       #     a severe drop-off in performance. The other parameters here are set so that this trains
                       #     on a single 10GB GPU.
    mode: byol_dataset
    crop_size: 224
    normalize: true
    key1: hq
    key2: hq
    dataset:
      mode: imagefolder
      paths: /content/imagenet   # <-- Put your path here: a directory of square images (example layout below).
      target_size: 224
      scale: 1
      skip_lq: true
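      # Illustrative layout only (the file names below are hypothetical); the imagefolder mode just needs a
      # directory containing square image files, e.g.:
      #   /content/imagenet/
      #     img_000001.jpg
      #     img_000002.jpg
      #     ...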

networks:
  generator:
    type: generator
    which_model_G: byol
    image_size: 224  # Matches the 224px crops produced by the dataset above.
    subnet:
      which_model_G: resnet50  # <-- Specify your own network to pretrain here.
      pretrained: false
    hidden_layer: avgpool  # <-- Specify a hidden layer from your network here.
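    # Note on hidden_layer (assumptions follow): for torchvision-style ResNets, 'avgpool' is the module that outputs
    # the globally pooled feature vector, which is the usual representation for BYOL. If you swap in your own subnet,
    # set this to the name of the layer whose output you want to use. One way to list candidate module names,
    # assuming a torchvision model, is:
    #   python -c "import torchvision.models as m; print([n for n, _ in m.resnet50().named_modules()])"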

#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_imageset_byol/training_state/0.state   # <-- Set this to resume from a previous training state.

steps:
  generator:
    training: generator

    optimizer: lars
    optimizer_params:
      # All parameters below are from Appendix J of the BYOL paper.
      lr: .2   # From BYOL paper: LR=.2*<batch_size>/256
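      # Worked example of the scaling rule: with batch_size: 256 above, LR = 0.2 * 256 / 256 = 0.2.
      # If you raise batch_size to 512, set lr to 0.2 * 512 / 256 = 0.4.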
      weight_decay: !!float 1.5e-6
      lars_coefficient: .001
      momentum: .9

    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: [aug1, aug2]
        out: loss

    losses:
      byol_loss:
        type: direct
        key: loss
        weight: 1
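    # How this step fits together: the byol_dataset yields two augmented views of each image under the keys
    # 'aug1' and 'aug2'; the injector feeds both through the BYOL wrapper, which computes the BYOL loss internally
    # and writes it to the 'loss' key; the 'direct' loss then uses that value as-is, so no extra criterion is needed.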

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 4    # <-- Gradient accumulation factor. If you are running out of memory (OOM), increase this
                          #     to 8. Likewise, if you are running on a 24GB GPU, decrease it to 1 to improve batch
                          #     statistics.
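  # Worked example: with batch_size: 256 and mega_batch_factor: 4, each forward/backward pass sees 256 / 4 = 64
  # images, and gradients are accumulated over 4 passes, so the optimizer still steps on an effective batch of 256.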
  val_freq: 2000

  # Default LR scheduler options
  default_lr_scheme: CosineAnnealingLR_Restart
  T_period: [120000, 120000, 120000]
  warmup: 10000
  eta_min: .01  # Not specified in the paper.
  restarts: [140000, 280000]  # The paper specifies a different, longer schedule that is only practical with
                              # 4+ V100-class GPUs. Modify these parameters if you have that kind of hardware.
  restart_weights: [.5, .25]
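  # Sketch of the resulting schedule (assuming the usual CosineAnnealingLR_Restart behavior): the LR warms up over
  # the first 10k iterations, then cosine-anneals from 0.2 towards eta_min (0.01); at iteration 140000 it restarts
  # at 0.2 * 0.5 = 0.1, and at 280000 it restarts at 0.2 * 0.25 = 0.05.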

eval:
  output_state: loss

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [hq, aug1, aug2]
  visual_debug_rate: 100