# Source listing metadata: 85 lines, 1.9 KiB, YAML
#### general settings
# Top-level run options. Inline notes marked "<--" come from the original
# author of this recipe.
name: train_div2k_byol
use_tb_logger: true
model: extensibletrainer
scale: 1
gpu_ids: [0]
fp16: false
start_step: 0
checkpointing_enabled: true  # <-- Highly recommended for single-GPU training. Will not work with DDP.
wandb: false

# Training data definition.
# NOTE(review): nesting was reconstructed from the flattened listing. `mode`
# and `scale` each occur twice in this stanza, so the second pair has to live
# in the nested `dataset` mapping (duplicate keys are invalid within a single
# mapping) - confirm against the trainer's expected schema.
datasets:
  train:
    n_workers: 4
    batch_size: 32
    mode: byol_dataset
    crop_size: 256
    normalize: true
    # Wrapped dataset that `byol_dataset` reads from.
    dataset:
      mode: imagefolder
      paths: /content/div2k  # <-- Put your path here. Note: full images.
      target_size: 256
      scale: 1

# Network definitions: a `byol` model wrapping a user-selected subnet.
# NOTE(review): nesting was reconstructed from the flattened listing.
# `which_model_G` occurs twice, so the second occurrence must belong to
# `subnet`.
networks:
  generator:
    type: generator
    which_model_G: byol
    image_size: 256
    subnet:  # <-- Specify your own network to pretrain here.
      which_model_G: spinenet
      arch: 49
      use_input_norm: true

    # NOTE(review): placed as a sibling of `subnet` because a blank line
    # separates it from the subnet keys in the listing - confirm.
    hidden_layer: endpoint_convs.4.conv  # <-- Specify a hidden layer from your network here.

#### path
# Checkpoint / resume locations. The two commented-out keys are optional;
# uncomment and fill them in to use them.
path:
  # pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  # resume_state: ../experiments/train_div2k_byol/training_state/0.state  # <-- Set this to resume from a previous training state.

# Training steps. A single step named `generator` that trains the `generator`
# network.
# NOTE(review): nesting was reconstructed from the flattened listing; the
# `optimizer_params`, `injectors` and `losses` groups are assumed to be
# children of the `generator` step - confirm against the trainer's schema.
steps:
  generator:
    training: generator

    optimizer_params:
      # Optimizer params
      # The explicit `!!float` tag forces a float; some YAML 1.1 loaders would
      # otherwise read the dot-less `3e-4` as a string.
      lr: !!float 3e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # Feeds `aug1` and `aug2` to the generator and stores the result under
      # the `loss` key.
      gen_inj:
        type: generator
        generator: generator
        in: [aug1, aug2]
        out: loss

    losses:
      # Uses the value stored under `loss` (the injector's `out` above).
      byol_loss:
        type: direct
        key: loss
        weight: 1

# Training-loop schedule.
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1  # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
  val_freq: 2000

  # Default LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [50000, 100000, 150000, 200000]
  lr_gamma: 0.5

# Evaluation settings; `loss` matches the key produced by the `gen_inj`
# injector's `out` field.
eval:
  output_state: loss

# Logging, checkpointing and visual-debug settings.
logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [hq, lq, aug1, aug2]
  visual_debug_rate: 100