# Forked from mrq/DL-Art-School (150 lines, 4.3 KiB, YAML).
# Top-level run settings for the `train_unet_diffusion` configuration.
name: train_unet_diffusion
use_tb_logger: true
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true  # If using the UNet architecture, this is pretty much required.
fp16: false
wandb: false  # Set to true to enable wandb logging.
force_start_step: -1

# Dataset definitions. Only a `train` dataset is declared here.
datasets:
  train:
    name: imgset5
    n_workers: 4
    # The OpenAI paper uses this batch size for 256px generation. The UNet
    # model uses attention, which benefits from large batch sizes.
    batch_size: 256
    mode: imagefolder
    rgb_n1_to_1: true
    paths: <insert path to a folder full of 256x256 tiled images here>
    target_size: 256
    scale: 2
    # This model is trained to correct JPEG artifacts and blurring.
    fixed_corruptions: [jpeg-broad, gaussian_blur]
    random_corruptions: [none]
    num_corrupts_per_image: 1
    corruption_blur_scale: 1
    corrupt_before_downsize: false

# Network definitions. A single `generator` network using `unet_diffusion`.
networks:
  generator:
    type: generator
    which_model_G: unet_diffusion
    args:
      image_size: 256
      in_channels: 3
      num_corruptions: 2
      model_channels: 192
      out_channels: 6
      num_res_blocks: 2
      attention_resolutions: [8, 16]
      dropout: 0
      # These will need to be reduced if you lower the operating resolution.
      channel_mult: [1, 1, 2, 2, 4, 4]
      num_heads: 4
      num_heads_upsample: -1
      use_scale_shift_norm: true

#### path
# Checkpoint / resume locations. Uncomment and fill in the placeholder entries
# to start from a pretrained generator or to resume an earlier run.
path:
  #pretrain_model_generator: <Insert pretrained generator here>
  strict_load: true
  #resume_state: <Insert resume training_state here to resume existing training>

# Training steps. One step, `generator`, which trains the `generator` network.
# NOTE(review): nesting below was reconstructed from a flattened copy of this
# file — confirm `injectors`/`losses` placement against the upstream config.
steps:
  generator:
    training: generator

    optimizer: adamw
    optimizer_params:
      # Hyperparameters from OpenAI paper.
      # The explicit `!!float` tag is kept: some YAML 1.1 loaders would
      # otherwise read a bare `3e-4` as a string.
      lr: !!float 3e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.9999

    injectors:
      diffusion:
        type: gaussian_diffusion
        in: hq
        generator: generator
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        # NOTE(review): assumed to be a sibling of `diffusion_args`, not a
        # member of it — verify.
        sampler_type: uniform
        model_input_keys:
          low_res: lq
          corruption_factor: corruption_entropy
        out: loss
        out_key_vb_loss: vb_loss
        out_key_x_start: x_start_pred
    losses:
      # Both losses read keys that the `diffusion` injector above writes
      # (`loss` via `out`, `vb_loss` via `out_key_vb_loss`).
      diffusion_loss:
        type: direct
        weight: 1
        key: loss
      var_loss:
        type: direct
        weight: 1
        key: vb_loss

# Training-loop schedule: iteration count, EMA, validation cadence, LR schedule.
train:
  niter: 500000
  warmup_iter: -1
  # This is massive. Expect ~60sec/step on a RTX3090 at 90%+ memory
  # utilization. I recommend using multiple GPUs to train this network.
  mega_batch_factor: 32
  ema_rate: .999
  val_freq: 500

  default_lr_scheme: MultiStepLR
  gen_lr_steps: [50000, 100000, 150000]
  lr_gamma: 0.5

# Evaluation configuration.
# NOTE(review): nesting below was reconstructed from a flattened copy of this
# file — confirm against the upstream config.
eval:
  evaluators:
    # Validation for this network is a special FID computation that compares
    # the full resolution images from the specified dataset to the same
    # images, downsampled and corrupted then fed through the network.
    fid:
      type: sr_diffusion_fid
      for: generator  # Unused for this evaluator.
      batch_size: 8
      dataset:
        name: sr_fid_set
        mode: imagefolder
        rgb_n1_to_1: true
        paths: <insert path to a folder of 128-512 validation images here, drawn from your dataset>
        target_size: 256
        scale: 2
        fixed_corruptions: [jpeg-broad, gaussian_blur]
        random_corruptions: [none]
        num_corrupts_per_image: 1
        corruption_blur_scale: 1
        corrupt_before_downsize: false
        # NOTE(review): assumed to belong to `dataset` — verify.
        random_seed: 1234
      diffusion_params:
        type: gaussian_diffusion_inference
        generator: generator
        use_ema_model: true
        output_batch_size: 8
        output_scale_factor: 2
        respaced_timestep_spacing: 50
        undo_n1_to_1: true
        # Schedule and diffusion args mirror the values used by the training
        # injector under `steps.generator.injectors.diffusion`.
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        model_input_keys:
          low_res: lq
          corruption_factor: corruption_entropy
        out: sample  # Unused

# Logging, checkpointing and visual-debug output settings.
logger:
  print_freq: 30
  save_checkpoint_freq: 500
  # `x_start_pred` is the key written by the diffusion injector's
  # `out_key_x_start`.
  visuals: [x_start_pred, hq, lq]
  visual_debug_rate: 500
  reverse_n1_to_1: true
  reverse_imagenet_norm: false