From 730c0135fd741fae730f466bbeb9d8227070ddf9 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 15 Jun 2021 17:14:37 -0600
Subject: [PATCH] Guided diffusion documentation

---
 recipes/ddpm/train_ddpm_rrdb.yml          | 109 ----------------
 recipes/diffusion/README.md               |  36 ++++++
 recipes/diffusion/test_diffusion_unet.yml |  78 +++++++++++
 recipes/diffusion/train_ddpm_unet.yml     | 150 ++++++++++++++++++++++
 4 files changed, 264 insertions(+), 109 deletions(-)
 delete mode 100644 recipes/ddpm/train_ddpm_rrdb.yml
 create mode 100644 recipes/diffusion/README.md
 create mode 100644 recipes/diffusion/test_diffusion_unet.yml
 create mode 100644 recipes/diffusion/train_ddpm_unet.yml

diff --git a/recipes/ddpm/train_ddpm_rrdb.yml b/recipes/ddpm/train_ddpm_rrdb.yml
deleted file mode 100644
index 2416f8f6..00000000
--- a/recipes/ddpm/train_ddpm_rrdb.yml
+++ /dev/null
@@ -1,109 +0,0 @@
-#### general settings
-name: train_imgset_rrdb_diffusion
-model: extensibletrainer
-scale: 1
-gpu_ids: [0]
-start_step: -1
-checkpointing_enabled: true
-fp16: false
-use_tb_logger: true
-wandb: false
-
-datasets:
-  train:
-    n_workers: 4
-    batch_size: 32
-    name: div2k
-    mode: single_image_extensible
-    paths: /content/div2k  # <-- Put your path here.
-    target_size: 128
-    force_multiple: 1
-    scale: 4
-    num_corrupts_per_image: 0
-
-networks:
-  generator:
-    type: generator
-    which_model_G: rrdb_diffusion
-    args:
-      in_channels: 6
-      out_channels: 6
-      num_blocks: 10
-
-#### path
-path:
-  #pretrain_model_generator:
-  strict_load: true
-  #resume_state: ../experiments/train_imgset_rrdb_diffusion/training_state/0.state  # <-- Set this to resume from a previous training state.
-
-steps:
-  generator:
-    training: generator
-
-    optimizer_params:
-      lr: !!float 3e-4
-      weight_decay: !!float 1e-2
-      beta1: 0.9
-      beta2: 0.9999
-
-    injectors:
-      # "Do it all injector": produces a reverse prediction and calculates losses on it.
-      diffusion:
-        type: gaussian_diffusion
-        in: hq
-        generator: generator
-        beta_schedule:
-          schedule_name: linear
-          num_diffusion_timesteps: 4000
-        diffusion_args:
-          model_mean_type: epsilon
-          model_var_type: learned_range
-          loss_type: mse
-        sampler_type: uniform
-        model_input_keys:
-          low_res: lq
-        out: loss
-
-      # Injector for visualizing what your network is doing (every 500 steps)
-      visual_debug:
-        every: 500
-        type: gaussian_diffusion_inference
-        generator: generator
-        output_shape: [8,3,128,128]  # Change "8" to your desired output batch size.
-        beta_schedule:
-          schedule_name: linear
-          num_diffusion_timesteps: 500  # Change higher (up to training steps) for improved quality. Lower for faster speed.
-        diffusion_args:
-          model_mean_type: epsilon
-          model_var_type: learned_range
-          loss_type: mse
-        model_input_keys:
-          low_res: lq
-        out: sample
-
-    losses:
-      diffusion_loss:
-        type: direct
-        weight: 1
-        key: loss
-
-train:
-  niter: 500000
-  warmup_iter: -1
-  mega_batch_factor: 1  # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
-  val_freq: 4000
-
-  # Default LR scheduler options
-  default_lr_scheme: CosineAnnealingLR_Restart
-  T_period: [ 200000, 200000 ]
-  warmup: 0
-  eta_min: !!float 1e-7
-  restarts: [ 200000, 400000 ]
-  restart_weights: [ .5, .5 ]
-
-logger:
-  print_freq: 30
-  save_checkpoint_freq: 2000
-  visuals: [sample, hq, lq]
-  visual_debug_rate: 500
-  reverse_n1_to_1: true
\ No newline at end of file
diff --git a/recipes/diffusion/README.md b/recipes/diffusion/README.md
new file mode 100644
index 00000000..4267c5d5
--- /dev/null
+++ b/recipes/diffusion/README.md
@@ -0,0 +1,36 @@
+# Working with Gaussian Diffusion models in DLAS
+
+Diffusion models are a method of generating structured data using a gradual de-noising process, which permits a
+simple network training regime.
+
+This implementation of Gaussian Diffusion is largely based on the work done by OpenAI in their papers ["Diffusion Models
+Beat GANs on Image Synthesis"](https://arxiv.org/pdf/2105.05233.pdf) and ["Improved Denoising Diffusion Probabilistic
+Models"](https://arxiv.org/pdf/2102.09672).
+
+OpenAI open-sourced their reference implementation [here](https://github.com/openai/guided-diffusion). The diffusion
+models that DLAS trains use the [gaussian_diffusion.py](https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py)
+script from that repo for both training and inference. We also include the UNet from that repo as a model that can be
+used to train a diffusion network.
+
+Diffusion networks can be re-purposed for pretty much any image generation task, including super-resolution. Even though
+they are trained with MSE losses, they produce remarkably crisp images with FID scores competitive with the best GANs.
+More importantly, training progress is easy to track because diffusion networks use a "normal" loss.
+
+Diffusion networks are unique in that they perform multiple forward passes to generate a single image at inference time.
+During training, these networks learn to de-noise images over 4000 steps. At inference time, this sample rate can be
+adjusted. For super-resolution, I have found that images sampled in 50 steps are of very good quality. Even so, a
+diffusion generator is still ~50x slower than generators trained in other ways.
+
+I have also found that diffusion networks can be trained with the tiled methodology used by ESRGAN: instead of training
+on whole images, you can train on tiles of larger images. At inference time, the network can then be applied to images
+larger than it was trained on. This works well for inference images up to ~3x the training size. I have not tried
+larger, because the size of the UNet model makes inference at ultra-high resolutions impossible for me (I run out of
+GPU memory).
+
+I have provided a reference configuration for training a diffusion model in this manner. The config performs 2x
+upsampling to 256px, de-blurs the image and removes JPEG artifacts. The de-blurring and JPEG repair are applied at
+configurable strengths: values in [0,1] passed to the model as `corruption_entropy`, where `1` represents the maximum
+correction factor. You can try reducing the output resolution to 128px for faster training; it should work fine.
+
+Diffusion models also have a fairly arcane inference procedure. To help you along, I've provided an inference
+configuration that can be used with models trained in DLAS.
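+
+As a supplement, the snippet below sketches what "respacing" the sampling schedule means: inference runs on a small,
+evenly spaced subset of the 4000 training timesteps (e.g. 50 of them), with the noise schedule rescaled so the shortened
+chain still covers the same overall noise range. This is only an illustrative approximation of the respacing idea from
+the guided-diffusion repo, not the code DLAS actually calls; the function name `respace_betas` and the schedule
+endpoints are invented for this example.
+
+```python
+import torch
+
+def respace_betas(betas: torch.Tensor, num_inference_steps: int):
+    """Pick evenly spaced timesteps and derive new betas matching the original noise schedule at those points."""
+    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
+    num_train_steps = betas.shape[0]
+    # Evenly spaced indices into the original schedule (analogous to respaced_timestep_spacing).
+    use_timesteps = torch.linspace(0, num_train_steps - 1, num_inference_steps).long().tolist()
+    new_betas, last_ac = [], 1.0
+    for t in use_timesteps:
+        ac = alphas_cumprod[t].item()
+        # Choose beta so the cumulative product of (1 - beta) matches the original schedule at step t.
+        new_betas.append(1.0 - ac / last_ac)
+        last_ac = ac
+    return torch.tensor(new_betas), use_timesteps
+
+betas = torch.linspace(1e-4, 0.02, 4000)      # illustrative linear schedule over 4000 training steps
+short_betas, used = respace_betas(betas, 50)  # 50 inference steps, as in the provided inference config
+print(len(short_betas), used[:5])
+```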
\ No newline at end of file
diff --git a/recipes/diffusion/test_diffusion_unet.yml b/recipes/diffusion/test_diffusion_unet.yml
new file mode 100644
index 00000000..e6c18902
--- /dev/null
+++ b/recipes/diffusion/test_diffusion_unet.yml
@@ -0,0 +1,78 @@
+#### general settings
+name: test_diffusion_unet
+use_tb_logger: true
+model: extensibletrainer
+scale: 1
+gpu_ids: [0]
+start_step: -1
+checkpointing_enabled: true
+fp16: false
+wandb: false
+
+datasets:
+  train:
+    name: my_inference_images
+    n_workers: 0
+    batch_size: 1
+    mode: imagefolder
+    rgb_n1_to_1: true
+    disable_flip: true
+    force_square: false
+    paths:
+    scale: 1
+    skip_lq: true
+    fixed_parameters:
+      # Specify correction factors here. For networks trained with the paired training configuration, the first number
+      # is a JPEG correction factor and the second is a de-blurring factor. Testing shows that if you attempt to
+      # de-blur too far, you get extremely distorted images. It's actually pretty cool - the network clearly knows how
+      # much de-blurring is appropriate.
+      corruption_entropy: [.2, .5]
+
+networks:
+  generator:
+    type: generator
+    which_model_G: unet_diffusion
+    args:
+      image_size: 256
+      in_channels: 3
+      num_corruptions: 2
+      model_channels: 192
+      out_channels: 6
+      num_res_blocks: 2
+      attention_resolutions: [8,16]
+      dropout: 0
+      channel_mult: [1,1,2,2,4,4]
+      num_heads: 4
+      num_heads_upsample: -1
+      use_scale_shift_norm: true
+
+#### path
+path:
+  pretrain_model_generator:
+  strict_load: true
+
+steps:
+  generator:
+    training: generator
+    injectors:
+      visual_debug:
+        type: gaussian_diffusion_inference
+        generator: generator
+        output_batch_size: 1
+        output_scale_factor: 2
+        respaced_timestep_spacing: 50  # Tweak this to trade inference speed for quality. 50-200 seems to be the sweet spot; at the full 4000 steps, quality is often actually worse.
+        undo_n1_to_1: true
+        beta_schedule:
+          schedule_name: linear
+          num_diffusion_timesteps: 4000
+        diffusion_args:
+          model_mean_type: epsilon
+          model_var_type: learned_range
+          loss_type: mse
+        model_input_keys:
+          low_res: hq
+          corruption_factor: corruption_entropy
+        out: sample
+
+eval:
+  output_state: sample
\ No newline at end of file
diff --git a/recipes/diffusion/train_ddpm_unet.yml b/recipes/diffusion/train_ddpm_unet.yml
new file mode 100644
index 00000000..11936537
--- /dev/null
+++ b/recipes/diffusion/train_ddpm_unet.yml
@@ -0,0 +1,150 @@
+name: train_unet_diffusion
+use_tb_logger: true
+model: extensibletrainer
+scale: 1
+gpu_ids: [0]
+start_step: -1
+checkpointing_enabled: true  # If using the UNet architecture, this is pretty much required.
+fp16: false
+wandb: false  # Set to true to enable wandb logging.
+force_start_step: -1
+
+datasets:
+  train:
+    name: imgset5
+    n_workers: 4
+    batch_size: 256  # The OpenAI paper uses this batch size for 256px generation. The UNet model uses attention, which benefits from large batch sizes.
+    mode: imagefolder
+    rgb_n1_to_1: true
+    paths:
+    target_size: 256
+    scale: 2
+    fixed_corruptions: [ jpeg-broad, gaussian_blur ]  # This model is trained to correct JPEG artifacts and blurring.
+    random_corruptions: [ none ]
+    num_corrupts_per_image: 1
+    corruption_blur_scale: 1
+    corrupt_before_downsize: false
+
+networks:
+  generator:
+    type: generator
+    which_model_G: unet_diffusion
+    args:
+      image_size: 256
+      in_channels: 3
+      num_corruptions: 2
+      model_channels: 192
+      out_channels: 6
+      num_res_blocks: 2
+      attention_resolutions: [8,16]
+      dropout: 0
+      channel_mult: [1,1,2,2,4,4]  # These will need to be reduced if you lower the operating resolution.
+      num_heads: 4
+      num_heads_upsample: -1
+      use_scale_shift_norm: true
+
+#### path
+path:
+  #pretrain_model_generator:
+  strict_load: true
+  #resume_state:
+
+steps:
+  generator:
+    training: generator
+
+    optimizer: adamw
+    optimizer_params:
+      lr: !!float 3e-4  # Hyperparameters from the OpenAI paper.
+      weight_decay: 0
+      beta1: 0.9
+      beta2: 0.9999
+
+    injectors:
+      diffusion:
+        type: gaussian_diffusion
+        in: hq
+        generator: generator
+        beta_schedule:
+          schedule_name: linear
+          num_diffusion_timesteps: 4000
+        diffusion_args:
+          model_mean_type: epsilon
+          model_var_type: learned_range
+          loss_type: mse
+        sampler_type: uniform
+        model_input_keys:
+          low_res: lq
+          corruption_factor: corruption_entropy
+        out: loss
+        out_key_vb_loss: vb_loss
+        out_key_x_start: x_start_pred
+    losses:
+      diffusion_loss:
+        type: direct
+        weight: 1
+        key: loss
+      var_loss:
+        type: direct
+        weight: 1
+        key: vb_loss
+
+train:
+  niter: 500000
+  warmup_iter: -1
+  mega_batch_factor: 32  # This is massive. Expect ~60 sec/step on an RTX 3090 at 90%+ memory utilization. I recommend using multiple GPUs to train this network.
+  ema_rate: .999
+  val_freq: 500
+
+  default_lr_scheme: MultiStepLR
+  gen_lr_steps: [ 50000, 100000, 150000 ]
+  lr_gamma: 0.5
+
+eval:
+  evaluators:
+    # Validation for this network is a special FID computation that compares the full-resolution images from the specified
+    # dataset against the same images after they are downsampled, corrupted and fed through the network.
+    fid:
+      type: sr_diffusion_fid
+      for: generator  # Unused by this evaluator.
+      batch_size: 8
+      dataset:
+        name: sr_fid_set
+        mode: imagefolder
+        rgb_n1_to_1: true
+        paths:
+        target_size: 256
+        scale: 2
+        fixed_corruptions: [ jpeg-broad, gaussian_blur ]
+        random_corruptions: [ none ]
+        num_corrupts_per_image: 1
+        corruption_blur_scale: 1
+        corrupt_before_downsize: false
+        random_seed: 1234
+      diffusion_params:
+        type: gaussian_diffusion_inference
+        generator: generator
+        use_ema_model: true
+        output_batch_size: 8
+        output_scale_factor: 2
+        respaced_timestep_spacing: 50
+        undo_n1_to_1: true
+        beta_schedule:
+          schedule_name: linear
+          num_diffusion_timesteps: 4000
+        diffusion_args:
+          model_mean_type: epsilon
+          model_var_type: learned_range
+          loss_type: mse
+        model_input_keys:
+          low_res: lq
+          corruption_factor: corruption_entropy
+        out: sample  # Unused
+
+logger:
+  print_freq: 30
+  save_checkpoint_freq: 500
+  visuals: [x_start_pred, hq, lq]
+  visual_debug_rate: 500
+  reverse_n1_to_1: true
+  reverse_imagenet_norm: false
\ No newline at end of file
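A note on the training config above: the `gaussian_diffusion` injector (`model_mean_type: epsilon`, `loss_type: mse`,
`sampler_type: uniform`, linear beta schedule) boils down to the epsilon-prediction objective sketched below. This is a
simplified, self-contained illustration, not the DLAS or guided-diffusion code: the helper names, the dummy model, and
the schedule endpoints are invented, and the learned-variance (`vb_loss`) term and `corruption_entropy` conditioning
used by the real config are omitted.

```python
import torch
import torch.nn.functional as F

def make_linear_betas(num_timesteps: int = 4000) -> torch.Tensor:
    # Linear beta schedule (illustrative endpoints), matching "schedule_name: linear".
    return torch.linspace(1e-4, 0.02, num_timesteps)

def diffusion_training_loss(model, x_start: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
    """One training step of the epsilon-prediction MSE objective."""
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    b = x_start.shape[0]
    # Uniformly sample a timestep per image ("sampler_type: uniform").
    t = torch.randint(0, betas.shape[0], (b,), device=x_start.device)
    noise = torch.randn_like(x_start)
    ac = alphas_cumprod[t].view(b, 1, 1, 1)
    # Forward process q(x_t | x_0): blend the clean image with Gaussian noise.
    x_t = ac.sqrt() * x_start + (1.0 - ac).sqrt() * noise
    # The network is trained to recover the injected noise ("model_mean_type: epsilon").
    predicted_noise = model(x_t, t)
    return F.mse_loss(predicted_noise, noise)

if __name__ == "__main__":
    dummy_model = lambda x_t, t: torch.zeros_like(x_t)  # stand-in for the UNet
    x = torch.randn(4, 3, 64, 64)                       # pretend batch of images in [-1, 1]
    print(diffusion_training_loss(dummy_model, x, make_linear_betas()).item())
```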