Update BYOL docs

James Betker 2020-12-24 09:35:03 -07:00
parent 29db7c7a02
commit a947f064cc
2 changed files with 34 additions and 25 deletions

View File

@@ -30,8 +30,8 @@ Run the trainer by:
 `python train.py -opt train_div2k_byol.yml`

-BYOL is data hungry, as most unsupervised training methods are. You'll definitely want to provide
-your own dataset - DIV2K is here as an example only.
+BYOL is data hungry, as most unsupervised training methods are. If you're providing your own dataset,
+make sure it contains hundreds of thousands of images or more!

 ## Using your own model
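The scale requirement above is worth verifying before committing GPU hours. As a rough, repo-independent sanity check, a few lines of Python can count what a run pointed at an imagefolder-style dataset would actually see; the path and the 100k threshold below are illustrative assumptions, not part of this codebase:

```python
# Rough sanity check before training: count images under the folder you plan to
# point datasets.train.dataset.paths at. The path and threshold are placeholders.
import os

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

def count_images(root: str) -> int:
    total = 0
    for _, _, filenames in os.walk(root):
        total += sum(1 for f in filenames
                     if os.path.splitext(f)[1].lower() in IMAGE_EXTS)
    return total

if __name__ == "__main__":
    n = count_images("/content/imagenet")  # <-- substitute your own dataset path
    print(f"{n} images found")
    if n < 100_000:
        print("Warning: BYOL-style pretraining typically wants far more images than this.")
```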

View File

@@ -1,55 +1,59 @@
 #### general settings
-name: train_div2k_byol
+name: train_imageset_byol
 use_tb_logger: true
 model: extensibletrainer
 scale: 1
 gpu_ids: [0]
 fp16: false
 start_step: 0
-checkpointing_enabled: true  # <-- Highly recommended for single-GPU training. Will not work with DDP.
+checkpointing_enabled: true  # <-- Highly recommended for single-GPU training. May not work in distributed settings.
 wandb: false

 datasets:
   train:
     n_workers: 4
-    batch_size: 32
+    batch_size: 256  # <-- BYOL trains on very large batch sizes. 256 was the smallest batch size possible before a
+                     #     severe drop off in performance. Other parameters here are set to enable this to train on a
+                     #     single 10GB GPU.
     mode: byol_dataset
-    crop_size: 256
+    crop_size: 224
     normalize: true
+    key1: hq
+    key2: hq
     dataset:
       mode: imagefolder
-      paths: /content/div2k  # <-- Put your path here. Note: full images.
-      target_size: 256
+      paths: /content/imagenet  # <-- Put your path here. Directory should be filled with square images.
+      target_size: 224
       scale: 1
+      skip_lq: true

 networks:
   generator:
     type: generator
     which_model_G: byol
     image_size: 256
-    subnet:  # <-- Specify your own network to pretrain here.
-      which_model_G: spinenet
-      arch: 49
-      use_input_norm: true
-      hidden_layer: endpoint_convs.4.conv  # <-- Specify a hidden layer from your network here.
+    subnet:
+      which_model_G: resnet52  # <-- Specify your own network to pretrain here.
+      pretrained: false
+      hidden_layer: avgpool  # <-- Specify a hidden layer from your network here.

 #### path
 path:
   #pretrain_model_generator: <insert pretrained model path if desired>
   strict_load: true
-  #resume_state: ../experiments/train_div2k_byol/training_state/0.state  # <-- Set this to resume from a previous training state.
+  #resume_state: ../experiments/train_imageset_byol/training_state/0.state  # <-- Set this to resume from a previous training state.

 steps:
   generator:
     training: generator
+    optimizer: lars
     optimizer_params:
-      # Optimizer params
-      lr: !!float 3e-4
-      weight_decay: 0
-      beta1: 0.9
-      beta2: 0.99
+      # All parameters from appendix J of BYOL.
+      lr: .2  # From BYOL paper: LR=.2*<batch_size>/256
+      weight_decay: !!float 1.5e-6
+      lars_coefficient: .001
+      momentum: .9
     injectors:
       gen_inj:
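The reworked `subnet` block above is the part to adapt when pretraining your own backbone: `which_model_G` names the wrapped network and `hidden_layer` names the layer whose output BYOL projects and compares across two augmented views of the same image. Purely as an illustration of that wrapper pattern, and not of this repo's trainer, here is a minimal sketch built on lucidrains' byol-pytorch package with torchvision's resnet50; the package, the resnet50 stand-in, and the plain SGD optimizer are assumptions for the sketch, while the config above actually trains through ExtensibleTrainer with LARS.

```python
# Standalone sketch of the subnet/hidden_layer wrapper pattern; not repo code.
# Assumes: pip install byol-pytorch torchvision
import torch
from torchvision import models
from byol_pytorch import BYOL

backbone = models.resnet50(pretrained=False)   # stand-in for the config's resnet52 subnet

learner = BYOL(
    backbone,
    image_size=224,          # matches crop_size/target_size in the config
    hidden_layer='avgpool',  # layer whose output feeds BYOL's projection head
)

# Plain SGD for brevity; the config uses LARS with the BYOL paper's settings.
opt = torch.optim.SGD(learner.parameters(), lr=0.2, momentum=0.9)

images = torch.randn(8, 3, 224, 224)  # stand-in for a batch of training images
loss = learner(images)                # byol-pytorch applies its own pair of augmentations internally
opt.zero_grad()
loss.backward()
opt.step()
learner.update_moving_average()       # EMA update of the target network
```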
@@ -67,13 +71,18 @@ steps:
 train:
   niter: 500000
   warmup_iter: -1
-  mega_batch_factor: 1  # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
+  mega_batch_factor: 4  # <-- Gradient accumulation factor. If you are running OOM, increase this to [8].
+                        #     Likewise, if you are running on a 24GB GPU, decrease this to [1] to improve batch stats.
   val_freq: 2000

   # Default LR scheduler options
-  default_lr_scheme: MultiStepLR
-  gen_lr_steps: [50000, 100000, 150000, 200000]
-  lr_gamma: 0.5
+  default_lr_scheme: CosineAnnealingLR_Restart
+  T_period: [120000, 120000, 120000]
+  warmup: 10000
+  eta_min: .01  # Unspecified by the paper.
+  restarts: [140000, 280000]  # Paper specifies a different, longer schedule that is not practical for anyone not using
+                              # 4x V100s+. Modify these parameters if you are.
+  restart_weights: [.5, .25]

 eval:
   output_state: loss
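To make the new training hyperparameters concrete: with `batch_size: 256`, the rule quoted next to `lr` (LR = .2 * batch_size / 256) comes out to exactly 0.2, and `mega_batch_factor: 4` means each optimizer step accumulates gradients over four micro-batches of 64 images, which is what lets a 256-image batch fit on a single 10GB GPU. A tiny illustrative calculation, not repo code:

```python
# Illustration of how the config values relate; not part of the repo.
batch_size = 256
mega_batch_factor = 4

lr = 0.2 * batch_size / 256                    # BYOL LR scaling rule -> 0.2
micro_batch = batch_size // mega_batch_factor  # images per forward/backward pass -> 64

print(lr, micro_batch)                         # 0.2 64
```

The cosine-annealing scheduler then decays that rate over the `T_period` windows after the 10k-iteration warmup, restarting at 140k and 280k iterations with the peak LR scaled by the corresponding `restart_weights`.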
@@ -81,5 +90,5 @@ eval:
 logger:
   print_freq: 30
   save_checkpoint_freq: 1000
-  visuals: [hq, lq, aug1, aug2]
+  visuals: [hq, aug1, aug2]
   visual_debug_rate: 100