# Top-level settings for this training run. Dataset, network, optimisation
# step, schedule, evaluation and logging options live in their own sections
# further down the file.
name: train_unet_diffusion
use_tb_logger: true
model: extensibletrainer
# NOTE(review): this global `scale` is 1 while the datasets use `scale: 2`;
# the two appear to be independent settings -- confirm against the trainer.
scale: 1
gpu_ids: [0]
# NOTE(review): -1 looks like a "not set" sentinel for both start_step and
# force_start_step -- confirm against the trainer.
start_step: -1
# If using the UNet architecture, this is pretty much required.
checkpointing_enabled: true
fp16: false
# Set to true to enable wandb logging.
wandb: false
force_start_step: -1

# Training data: an image-folder dataset whose images are corrupted on the fly.
datasets:
  train:
    name: imgset5
    n_workers: 4
    # The OpenAI paper uses this batch size for 256px generation. The UNet
    # model uses attention, which benefits from large batch sizes.
    batch_size: 256
    mode: imagefolder
    rgb_n1_to_1: true
    # Placeholder: replace with a real directory before training.
    paths: <insert path to a folder full of 256x256 tiled images here>
    target_size: 256
    scale: 2
    # This model is trained to correct JPEG artifacts and blurring.
    fixed_corruptions: [jpeg-broad, gaussian_blur]
    random_corruptions: [none]
    num_corrupts_per_image: 1
    corruption_blur_scale: 1
    corrupt_before_downsize: false

# Network definitions. The single `generator` entry is the network referenced
# by name from the `steps` and `eval` sections.
networks:
  generator:
    type: generator
    which_model_G: unet_diffusion
    args:
      image_size: 256
      in_channels: 3
      # NOTE(review): presumably matches the two entries listed in the
      # datasets' `fixed_corruptions` -- confirm.
      num_corruptions: 2
      model_channels: 192
      # NOTE(review): 6 = 2 x in_channels; presumably the extra channels carry
      # the learned variance (`model_var_type: learned_range`) -- confirm.
      out_channels: 6
      num_res_blocks: 2
      attention_resolutions: [8, 16]
      dropout: 0
      # These will need to be reduced if you lower the operating resolution.
      channel_mult: [1, 1, 2, 2, 4, 4]
      num_heads: 4
      num_heads_upsample: -1
      use_scale_shift_norm: true

#### path
# Checkpoint-related settings. The two commented-out keys are optional
# placeholders: uncomment one and fill in the value to start from a pretrained
# generator or to resume an existing training run. They are written as
# `#key` (no space after the hash) so that deleting the `#` leaves the key at
# the correct two-space indentation.
path:
  #pretrain_model_generator: <Insert pretrained generator here>
  strict_load: true
  #resume_state: <Insert resume training_state here to resume existing training>

# Optimisation steps. There is one step, `generator`, which trains the network
# of the same name declared under `networks`.
steps:
  generator:
    training: generator
    optimizer: adamw
    optimizer_params:
      # Hyperparameters from OpenAI paper.
      # The explicit `!!float` tag is kept on purpose: YAML 1.1 loaders such
      # as PyYAML do not recognise `3e-4` (no decimal point) as a float.
      lr: !!float 3e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.9999

    # The injector reads `hq` and passes `lq` / `corruption_entropy` to the
    # model under the names `low_res` / `corruption_factor`. Its outputs are
    # stored under `loss`, `vb_loss` and `x_start_pred`.
    injectors:
      diffusion:
        type: gaussian_diffusion
        in: hq
        generator: generator
        # NOTE(review): the same beta_schedule and diffusion_args are repeated
        # in the `eval` section; presumably they must stay in sync -- confirm.
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        sampler_type: uniform
        model_input_keys:
          low_res: lq
          corruption_factor: corruption_entropy
        out: loss
        out_key_vb_loss: vb_loss
        out_key_x_start: x_start_pred

    # Both losses pick up a value the injector above produced (`loss` and
    # `vb_loss`), each with weight 1.
    losses:
      diffusion_loss:
        type: direct
        weight: 1
        key: loss
      var_loss:
        type: direct
        weight: 1
        key: vb_loss

# Training-loop schedule: iteration count, EMA, validation cadence and
# learning-rate schedule.
train:
  niter: 500000
  warmup_iter: -1
  # This is massive. Expect ~60sec/step on an RTX 3090 at 90%+ memory
  # utilization. I recommend using multiple GPUs to train this network.
  mega_batch_factor: 32
  ema_rate: 0.999
  val_freq: 500

  # Learning-rate schedule: the step list and gamma below belong to the
  # MultiStepLR scheme selected here.
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [50000, 100000, 150000]
  lr_gamma: 0.5

eval:
  evaluators:
    # Validation for this network is a special FID computation that compares
    # the full-resolution images from the specified dataset to the same
    # images, downsampled and corrupted, then fed through the network.
    fid:
      type: sr_diffusion_fid
      # Unused for this evaluator.
      for: generator
      batch_size: 8
      # Validation dataset. The corruption settings repeat those of
      # `datasets.train`, plus a fixed `random_seed`.
      dataset:
        name: sr_fid_set
        mode: imagefolder
        rgb_n1_to_1: true
        # Placeholder: replace with a real directory before running validation.
        paths: <insert path to a folder of 128-512 validation images here, drawn from your dataset>
        target_size: 256
        scale: 2
        fixed_corruptions: [jpeg-broad, gaussian_blur]
        random_corruptions: [none]
        num_corrupts_per_image: 1
        corruption_blur_scale: 1
        corrupt_before_downsize: false
        # NOTE(review): presumably fixes the corruptions so FID scores are
        # comparable between validation runs -- confirm.
        random_seed: 1234
      # Sampling settings. beta_schedule, diffusion_args and model_input_keys
      # repeat the values used by the training injector under `steps`.
      diffusion_params:
        type: gaussian_diffusion_inference
        generator: generator
        use_ema_model: true
        output_batch_size: 8
        output_scale_factor: 2
        respaced_timestep_spacing: 50
        undo_n1_to_1: true
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        model_input_keys:
          low_res: lq
          corruption_factor: corruption_entropy
        # Unused
        out: sample

# Logging, checkpointing and debug-image settings.
logger:
  print_freq: 30
  save_checkpoint_freq: 500
  visuals:
    # Stored by the diffusion injector under `out_key_x_start`.
    - x_start_pred
    # The injector's input (`in: hq`).
    - hq
    # Passed to the model as `low_res`.
    - lq
  visual_debug_rate: 500
  # NOTE(review): presumably undoes the `rgb_n1_to_1` scaling applied by the
  # datasets before images are written out -- confirm.
  reverse_n1_to_1: true
  reverse_imagenet_norm: false