#### general settings
name: test_diffusion_unet
use_tb_logger: true
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true
fp16: false
wandb: false

datasets:
  train:
    name: my_inference_images
    n_workers: 0
    batch_size: 1
    mode: imagefolder
    rgb_n1_to_1: true
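    # rgb_n1_to_1 scales dataset images into the [-1, 1] range the diffusion model expects; the inference injector's
    # undo_n1_to_1 flag below reverses this scaling on the sampled output.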
    disable_flip: true
    force_square: false
    paths: <low resolution images you want to upsample>
    scale: 1
    skip_lq: true
    fixed_parameters:
      # Specify correction factors here. For networks trained with the paired training configuration, the first number
      # is a JPEG correction factor and the second is a deblurring factor. Testing shows that pushing the deblurring
      # factor too far produces extremely distorted images. It's actually pretty cool - the network clearly knows how
      # much deblurring is appropriate.
      corruption_entropy: [.2, .5]
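      # Illustrative alternative (not part of the original recipe): for inputs with heavier JPEG artifacts, a setting
      # like the following could be tried, keeping the deblurring factor conservative:
      # corruption_entropy: [.6, .2]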

networks:
  generator:
    type: generator
    which_model_G: unet_diffusion
    args:
      image_size: 256
      in_channels: 3
      num_corruptions: 2
      model_channels: 192
      out_channels: 6
      num_res_blocks: 2
      attention_resolutions: [8,16]
      dropout: 0
      channel_mult: [1,1,2,2,4,4]
      num_heads: 4
      num_heads_upsample: -1
      use_scale_shift_norm: true
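      # Assuming the standard guided-diffusion-style UNet: six channel_mult entries give six resolution levels
      # (256px down to 8px at the bottleneck), and out_channels is 6 rather than 3 because learned_range variance
      # prediction (set in diffusion_args below) doubles the output channels.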

#### path
path:
  pretrain_model_generator: <Your model (or EMA) path>
  strict_load: true

steps:
  generator:
    training: generator
    injectors:
      visual_debug:
        type: gaussian_diffusion_inference
        generator: generator
        output_batch_size: 1
        output_scale_factor: 2
        respaced_timestep_spacing: 50  # Tweak this to make inference faster or slower; 50-200 seems to be the sweet spot. Running the full 4000 steps often actually produces worse quality.
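        # Illustrative alternative (not part of the original recipe): for higher-quality final renders at the cost of
        # speed, the respacing could be raised toward the upper end of that range, e.g.:
        # respaced_timestep_spacing: 200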
        undo_n1_to_1: true
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
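        # model_input_keys appears to map the model's keyword arguments (left) to state/dataset keys (right): the
        # UNet's low_res conditioning image comes from the loader's 'hq' output, and corruption_factor comes from the
        # corruption_entropy values configured above.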
        model_input_keys:
          low_res: hq
          corruption_factor: corruption_entropy
        out: sample

eval:
  output_state: sample