# File-viewer metadata (not part of the config): 103 lines, 2.0 KiB, YAML
#### general settings
# Top-level run options. Stray `|` separator lines from a flattened copy of
# this file have been removed; every key here is a plain top-level scalar.
name: train_byol_segformer
use_tb_logger: true
model: extensibletrainer
distortion: sr
scale: 1
gpu_ids: [0]
fp16: false
start_step: -1
checkpointing_enabled: false
wandb: false

# Training data definition.
# NOTE(review): nesting was reconstructed from a flattened copy of this file.
# The keys following `dataset:` are assumed to belong to the inner
# `imagefolder` dataset rather than to the outer `byol_dataset` - confirm
# against the dataset loader.
datasets:
  train:
    n_workers: 1
    batch_size: 96
    mode: byol_dataset
    crop_size: 224
    key1: hq
    key2: hq
    dataset:
      mode: imagefolder
      paths: <>  # NOTE(review): `<>` looks like a placeholder; set before use.
      target_size: 224
      scale: 1
      fetch_alt_image: false
      skip_lq: true
      normalize: imagenet

# Network definitions. `which_model_G` appears twice in this section, so the
# second occurrence must be nested under `subnet:`; two copies in the same
# mapping would be a duplicate key.
networks:
  generator:
    type: generator
    which_model_G: pixel_local_byol
    image_size: 224
    hidden_layer: tail
    subnet:
      which_model_G: segformer

#### path
path:
  strict_load: true
  # Commented out in the original; `<>` looks like a placeholder value.
  # resume_state: <>

# Training steps. One step, `generator`, with its optimizer settings, its
# injectors and its losses.
# NOTE(review): nesting was reconstructed from a flattened copy of this file;
# `injectors` and `losses` are assumed to be children of the `generator` step
# (siblings of `optimizer_params`) - confirm against the trainer.
steps:
  generator:
    training: generator
    optimizer: lars
    optimizer_params:
      # All parameters from appendix J of BYOL.
      # NOTE(review): with batch_size 96 (datasets.train) the formula below
      # gives .075; .08 is presumably a rounded value - confirm.
      lr: .08  # From BYOL: LR=.2*<batch_size>/256
      weight_decay: !!float 1.5e-6
      lars_coefficient: .001
      momentum: .9

    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: aug1
        out: loss

    losses:
      byol_loss:
        type: direct
        key: loss
        weight: 1

# Training-loop and learning-rate schedule options.
train:
  warmup_iter: -1
  mega_batch_factor: 2
  val_freq: 1000
  niter: 300000

  # Default LR scheduler options
  default_lr_scheme: CosineAnnealingLR_Restart
  # NOTE(review): each T_period entry is 120000 while the restarts below are
  # spaced 140000 apart; confirm this mismatch is intended for the scheduler.
  T_period: [120000, 120000, 120000]
  warmup: 10000
  eta_min: .01  # Unspecified by the paper..
  # Paper says no re-starts, but this scheduler will add them automatically
  # if we don't set them. Likely I won't train this far.
  restarts: [140000, 280000]
  restart_weights: [.5, .25]

# Evaluation settings. `path` and `size` each appear twice, so each pair must
# be nested under its own `*_set_args` mapping to avoid duplicate keys.
eval:
  output_state: loss
  evaluators:
    single_point_pair_contrastive_eval:
      for: generator
      type: single_point_pair_contrastive_eval
      batch_size: 16
      quantity: 96
      similar_set_args:
        path: <>  # NOTE(review): `<>` looks like a placeholder; set before use.
        size: 256
      dissimilar_set_args:
        path: <>  # NOTE(review): `<>` looks like a placeholder; set before use.
        size: 256

# Logging and checkpoint cadence. The trailing `|` that was fused onto the
# last value in the flattened copy has been dropped.
logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [hq, aug1]
  visual_debug_rate: 100