# DL-Art-School/recipes/diffusion/train_ddpm_unet.yml

name: train_unet_diffusion
use_tb_logger: true
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true # If using the UNet architecture, this is pretty much required: gradient checkpointing recomputes activations during the backward pass to keep memory usage manageable.
fp16: false
wandb: false # Set to true to enable wandb logging.
force_start_step: -1
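# Training is launched through DL-Art-School's trainer. A sketch of the invocation, assuming the
# standard codes/train.py entrypoint (adjust the config path to match your checkout):
#   python train.py -opt ../recipes/diffusion/train_ddpm_unet.yml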
datasets:
  train:
    name: imgset5
    n_workers: 4
    batch_size: 256 # The OpenAI paper uses this batch size for 256px generation. The UNet model uses attention, which benefits from large batch sizes.
    mode: imagefolder
    rgb_n1_to_1: true
    paths: <insert path to a folder full of 256x256 tiled images here>
    target_size: 256
    scale: 2
    fixed_corruptions: [ jpeg-broad, gaussian_blur ] # This model is trained to correct JPEG artifacts and blurring.
    random_corruptions: [ none ]
    num_corrupts_per_image: 1
    corruption_blur_scale: 1
    corrupt_before_downsize: false
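# With target_size: 256 and scale: 2, the corrupted low-quality images the model conditions on
# should come out at 128x128, while the uncorrupted 256x256 originals serve as the 'hq' diffusion
# target (derived from the dataset options above, not verified against the loader).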
networks:
  generator:
    type: generator
    which_model_G: unet_diffusion
    args:
      image_size: 256
      in_channels: 3
      num_corruptions: 2
      model_channels: 192
      out_channels: 6
      num_res_blocks: 2
      attention_resolutions: [8,16]
      dropout: 0
      channel_mult: [1,1,2,2,4,4] # These will need to be reduced if you lower the operating resolution.
      num_heads: 4
      num_heads_upsample: -1
      use_scale_shift_norm: true
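# Note: out_channels is 2x in_channels because model_var_type below is learned_range. The UNet
# predicts a per-channel variance-interpolation coefficient alongside the noise estimate, as in
# OpenAI's "Improved Denoising Diffusion Probabilistic Models".
# If lowering the operating resolution as the channel_mult comment suggests, one plausible
# reduction (hypothetical, untested) for 128px would be channel_mult: [1,1,2,2,4], which keeps
# the bottleneck feature map at the same 8px spatial size.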
#### path
path:
  #pretrain_model_generator: <Insert pretrained generator here>
  strict_load: true
  #resume_state: <Insert resume training_state here to resume existing training>
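# To fine-tune from an existing model, uncomment pretrain_model_generator; to resume an interrupted
# run, uncomment resume_state and point it at a saved .state file (assumed to live under
# experiments/<name>/training_state/, following DLAS's usual output layout).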
steps:
  generator:
    training: generator
    optimizer: adamw
    optimizer_params:
      lr: !!float 3e-4 # Hyperparameters from the OpenAI paper.
      weight_decay: 0
      beta1: 0.9
      beta2: 0.9999
    injectors:
      diffusion:
        type: gaussian_diffusion
        in: hq
        generator: generator
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        sampler_type: uniform
        model_input_keys:
          low_res: lq
          corruption_factor: corruption_entropy
        out: loss
        out_key_vb_loss: vb_loss
        out_key_x_start: x_start_pred
    losses:
      diffusion_loss:
        type: direct
        weight: 1
        key: loss
      var_loss:
        type: direct
        weight: 1
        key: vb_loss
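# Together these reproduce the hybrid objective from OpenAI's improved-DDPM formulation: 'loss' is
# the simple MSE between predicted and actual noise (model_mean_type: epsilon), and 'vb_loss' is
# the variational-bound term that trains the learned_range variance output. The 'direct' loss type
# is assumed to pass each keyed value through unmodified, scaled by 'weight'.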
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 32 # This is massive. Expect ~60 sec/step on an RTX 3090 at 90%+ memory utilization. I recommend using multiple GPUs to train this network.
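  # mega_batch_factor appears to act like gradient accumulation: the batch of 256 is split into
  # 32 sequential chunks of 8 images before each optimizer step, which is how this batch size can
  # fit on a single GPU (behavioral assumption based on DLAS's ExtensibleTrainer, not verified here).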
  ema_rate: .999
  val_freq: 500
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [ 50000, 100000, 150000 ]
  lr_gamma: 0.5
eval:
  evaluators:
    # Validation for this network is a special FID computation that compares the full-resolution
    # images from the specified dataset to the same images after they have been downsampled,
    # corrupted, and fed through the network.
    fid:
      type: sr_diffusion_fid
      for: generator # Unused for this evaluator.
      batch_size: 8
      dataset:
        name: sr_fid_set
        mode: imagefolder
        rgb_n1_to_1: true
        paths: <insert path to a folder of 128-512 validation images here, drawn from your dataset>
        target_size: 256
        scale: 2
        fixed_corruptions: [ jpeg-broad, gaussian_blur ]
        random_corruptions: [ none ]
        num_corrupts_per_image: 1
        corruption_blur_scale: 1
        corrupt_before_downsize: false
        random_seed: 1234
      diffusion_params:
        type: gaussian_diffusion_inference
        generator: generator
        use_ema_model: true
        output_batch_size: 8
        output_scale_factor: 2
        respaced_timestep_spacing: 50
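        # The 4000-step training schedule is respaced for validation sampling per the value above,
        # so the reverse process runs in far fewer steps than full sampling would need (the exact
        # interpretation of the spacing value is an assumption from the key name).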
        undo_n1_to_1: true
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        model_input_keys:
          low_res: lq
          corruption_factor: corruption_entropy
        out: sample # Unused
logger:
  print_freq: 30
  save_checkpoint_freq: 500
  visuals: [x_start_pred, hq, lq]
  visual_debug_rate: 500
  reverse_n1_to_1: true
  reverse_imagenet_norm: false
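# x_start_pred in the visuals list is the model's single-step estimate of the clean image, exposed
# via out_key_x_start above; it is logged alongside the ground-truth hq and corrupted lq images so
# restoration quality can be inspected as training progresses.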