From a947f064cc307cf92089a3a3c40a0ef9bd86ece6 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 24 Dec 2020 09:35:03 -0700 Subject: [PATCH] Update BYOL docs --- recipes/byol/README.md | 4 +-- recipes/byol/train_div2k_byol.yml | 55 ++++++++++++++++++------------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/recipes/byol/README.md b/recipes/byol/README.md index 4559b7fa..6decc3fb 100644 --- a/recipes/byol/README.md +++ b/recipes/byol/README.md @@ -30,8 +30,8 @@ Run the trainer by: `python train.py -opt train_div2k_byol.yml` -BYOL is data hungry, as most unsupervised training methods are. You'll definitely want to provide -your own dataset - DIV2K is here as an example only. +BYOL is data hungry, as most unsupervised training methods are. If you're providing your own dataset, make sure it is +the hundreds of K-images or more! ## Using your own model diff --git a/recipes/byol/train_div2k_byol.yml b/recipes/byol/train_div2k_byol.yml index bac72712..9d95db2f 100644 --- a/recipes/byol/train_div2k_byol.yml +++ b/recipes/byol/train_div2k_byol.yml @@ -1,55 +1,59 @@ #### general settings -name: train_div2k_byol +name: train_imageset_byol use_tb_logger: true model: extensibletrainer scale: 1 gpu_ids: [0] fp16: false start_step: 0 -checkpointing_enabled: true # <-- Highly recommended for single-GPU training. Will not work with DDP. +checkpointing_enabled: true # <-- Highly recommended for single-GPU training. May not work in distributed settings. wandb: false datasets: train: n_workers: 4 - batch_size: 32 + batch_size: 256 # <-- BYOL trains on very large batch sizes. 256 was the smallest batch size possible before a + # severe drop off in performance. Other parameters here are set to enable this to train on a + # single 10GB GPU. mode: byol_dataset - crop_size: 256 + crop_size: 224 normalize: true + key1: hq + key2: hq dataset: mode: imagefolder - paths: /content/div2k # <-- Put your path here. Note: full images. - target_size: 256 + paths: /content/imagenet # <-- Put your path here. Directory should be filled with square images. + target_size: 224 scale: 1 + skip_lq: true networks: generator: type: generator which_model_G: byol image_size: 256 - subnet: # <-- Specify your own network to pretrain here. - which_model_G: spinenet - arch: 49 - use_input_norm: true - - hidden_layer: endpoint_convs.4.conv # <-- Specify a hidden layer from your network here. + subnet: + which_model_G: resnet52 # <-- Specify your own network to pretrain here. + pretrained: false + hidden_layer: avgpool # <-- Specify a hidden layer from your network here. #### path path: #pretrain_model_generator: strict_load: true - #resume_state: ../experiments/train_div2k_byol/training_state/0.state # <-- Set this to resume from a previous training state. + #resume_state: ../experiments/train_imageset_byol/training_state/0.state # <-- Set this to resume from a previous training state. steps: generator: training: generator + optimizer: lars optimizer_params: - # Optimizer params - lr: !!float 3e-4 - weight_decay: 0 - beta1: 0.9 - beta2: 0.99 + # All parameters from appendix J of BYOL. + lr: .2 # From BYOL paper: LR=.2*/256 + weight_decay: !!float 1.5e-6 + lars_coefficient: .001 + momentum: .9 injectors: gen_inj: @@ -67,13 +71,18 @@ steps: train: niter: 500000 warmup_iter: -1 - mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8]. + mega_batch_factor: 4 # <-- Gradient accumulation factor. If you are running OOM, increase this to [8]. + # Likewise, if you are running on a 24GB GPU, decrease this to [1] to improve batch stats. val_freq: 2000 # Default LR scheduler options - default_lr_scheme: MultiStepLR - gen_lr_steps: [50000, 100000, 150000, 200000] - lr_gamma: 0.5 + default_lr_scheme: CosineAnnealingLR_Restart + T_period: [120000, 120000, 120000] + warmup: 10000 + eta_min: .01 # Unspecified by the paper.. + restarts: [140000, 280000] # Paper specifies a different, longer schedule that is not practical for anyone not using + # 4x V100s+. Modify these parameters if you are. + restart_weights: [.5, .25] eval: output_state: loss @@ -81,5 +90,5 @@ eval: logger: print_freq: 30 save_checkpoint_freq: 1000 - visuals: [hq, lq, aug1, aug2] + visuals: [hq, aug1, aug2] visual_debug_rate: 100 \ No newline at end of file