#### general settings
# Training configuration for BYOL self-supervised pretraining on a folder of images.
# NOTE(review): the block nesting below was reconstructed from a whitespace-collapsed
# copy of this file -- confirm it against the schema the trainer expects.
name: train_imageset_byol
use_tb_logger: true
model: extensibletrainer
scale: 1
gpu_ids: [0]
fp16: false
start_step: 0
# Highly recommended for single-GPU training. May not work in distributed settings.
checkpointing_enabled: true
wandb: false

datasets:
  train:
    n_workers: 4
    # BYOL trains on very large batch sizes. 256 was the smallest batch size possible
    # before a severe drop-off in performance. Other parameters here are set to enable
    # this to train on a single 10GB GPU.
    batch_size: 256
    mode: byol_dataset
    crop_size: 224
    normalize: true
    key1: hq
    key2: hq
    dataset:
      mode: imagefolder
      # Put your path here. Directory should be filled with square images.
      paths: /content/imagenet
      target_size: 224
      scale: 1
      skip_lq: true

networks:
  generator:
    type: generator
    which_model_G: byol
    image_size: 256
    subnet:
      which_model_G: resnet52  # <-- Specify your own network to pretrain here.
      pretrained: false
    hidden_layer: avgpool  # <-- Specify a hidden layer from your network here.

#### path
path:
  # Uncomment and set to a checkpoint path to start from pretrained weights.
  # pretrain_model_generator: <path to pretrained model>
  strict_load: true
  # Uncomment to resume from a previous training state.
  # resume_state: ../experiments/train_imageset_byol/training_state/0.state

steps:
  generator:
    training: generator

    optimizer: lars
    optimizer_params:
      # All parameters from appendix J of BYOL.
      lr: 0.2  # From BYOL paper: LR = 0.2 * batch_size / 256
      weight_decay: !!float 1.5e-6
      lars_coefficient: 0.001
      momentum: 0.9

    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: [aug1, aug2]
        out: loss

    losses:
      byol_loss:
        type: direct
        key: loss
        weight: 1

train:
  niter: 500000
  warmup_iter: -1
  # Gradient accumulation factor. If you are running OOM, increase this to [8].
  # Likewise, if you are running on a 24GB GPU, decrease this to [1] to improve
  # batch stats.
  mega_batch_factor: 4
  val_freq: 2000

  # Default LR scheduler options
  default_lr_scheme: CosineAnnealingLR_Restart
  T_period: [120000, 120000, 120000]
  warmup: 10000
  eta_min: 0.01  # Unspecified by the paper.
  # Paper specifies a different, longer schedule that is not practical for anyone
  # not using 4x V100s+. Modify these parameters if you are.
  restarts: [140000, 280000]
  restart_weights: [0.5, 0.25]

eval:
  output_state: loss

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [hq, aug1, aug2]
  visual_debug_rate: 100