# DL-Art-School/recipes/ddpm/train_ddpm_rrdb.yml

#### general settings
name: train_imgset_rrdb_diffusion
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true
fp16: false
use_tb_logger: true
wandb: false
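# The dataset below (presumably) yields paired high/low-resolution tensors: 128px crops exposed as `hq`
# and 4x-downsampled (32px) counterparts exposed as `lq`, which the injectors further down consume.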
datasets:
  train:
    n_workers: 4
    batch_size: 32
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k  # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    num_corrupts_per_image: 0
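# rrdb_diffusion takes 6 input channels, presumably the noised HQ image (3) concatenated with the LQ
# conditioning (3), and emits 6 output channels: the predicted noise (epsilon) plus the per-channel
# variance values required by model_var_type: learned_range.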
networks:
  generator:
    type: generator
    which_model_G: rrdb_diffusion
    args:
      in_channels: 6
      out_channels: 6
      num_blocks: 10
#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_imgset_rrdb_diffusion/training_state/0.state  # <-- Set this to resume from a previous training state.
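# A single ExtensibleTrainer step that optimizes the generator; each step bundles its own optimizer,
# injectors (forward passes), and losses.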
steps:
  generator:
    training: generator
    optimizer_params:
      lr: !!float 3e-4
      weight_decay: !!float 1e-2
      beta1: 0.9
      beta2: 0.9999
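    # Injectors run in sequence each iteration. The diffusion injector below (presumably) draws a uniform
    # timestep per image, noises `hq` on the 4000-step linear schedule, has the generator predict the noise
    # conditioned on `lq`, and writes the resulting loss (MSE on epsilon plus a variational-bound term from
    # the learned variance) to the `loss` key.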
    injectors:
      # "Do it all" injector: produces a reverse prediction and calculates losses on it.
      diffusion:
        type: gaussian_diffusion
        in: hq
        generator: generator
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        sampler_type: uniform
        model_input_keys:
          low_res: lq
        out: loss
      # Injector for visualizing what your network is doing (every 500 steps).
      visual_debug:
        every: 500
        type: gaussian_diffusion_inference
        generator: generator
        output_shape: [8,3,128,128]  # Change "8" to your desired output batch size.
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 500  # Raise (up to the 4000 used in training) for better quality; lower for faster sampling.
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        model_input_keys:
          low_res: lq
        out: sample
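    # The `direct` loss (presumably) just reads the precomputed `loss` value emitted by the diffusion
    # injector and uses it as the training objective.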
    losses:
      diffusion_loss:
        type: direct
        weight: 1
        key: loss
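# LR follows cosine annealing with restarts: 200k-iteration periods, restarts at 200k and 400k where the
# LR is (presumably) scaled by the 0.5 restart weights, decaying toward eta_min within each period.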
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1  # <-- Gradient accumulation factor. If you are running out of memory, increase this to 2, 4, or 8.
  val_freq: 4000
  # Default LR scheduler options
  default_lr_scheme: CosineAnnealingLR_Restart
  T_period: [ 200000, 200000 ]
  warmup: 0
  eta_min: !!float 1e-7
  restarts: [ 200000, 400000 ]
  restart_weights: [ .5, .5 ]
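# Logs every 30 steps, checkpoints every 2000, and saves the `sample`, `hq`, and `lq` visuals every 500
# steps; reverse_n1_to_1 presumably maps them from the model's [-1, 1] range back to [0, 1] before writing images.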
logger:
  print_freq: 30
  save_checkpoint_freq: 2000
  visuals: [sample, hq, lq]
  visual_debug_rate: 500
  reverse_n1_to_1: true