name: train_unet_diffusion
use_tb_logger: true
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true  # Gradient checkpointing. With the UNet architecture this is effectively required to fit in GPU memory.
fp16: false
wandb: false  # Set to true to enable wandb logging.
force_start_step: -1

datasets:
  train:
    name: imgset5
    n_workers: 4
    batch_size: 256  # The OpenAI paper uses this batch size for 256px generation. The UNet model uses attention, which benefits from large batch sizes.
    mode: imagefolder
    rgb_n1_to_1: true
    paths:  # Populate with your image folder path(s).
    target_size: 256
    scale: 2
    fixed_corruptions: [ jpeg-broad, gaussian_blur ]  # This model is trained to correct JPEG artifacts and blurring.
    random_corruptions: [ none ]
    num_corrupts_per_image: 1
    corruption_blur_scale: 1
    corrupt_before_downsize: false

networks:
  generator:
    type: generator
    which_model_G: unet_diffusion
    args:
      image_size: 256
      in_channels: 3
      num_corruptions: 2
      model_channels: 192
      out_channels: 6
      num_res_blocks: 2
      attention_resolutions: [8,16]
      dropout: 0
      channel_mult: [1,1,2,2,4,4]  # One multiplier per resolution level; shorten this list if you lower the operating resolution. See the geometry note at the end of this file.
      num_heads: 4
      num_heads_upsample: -1
      use_scale_shift_norm: true

#### path
path:
  #pretrain_model_generator:
  strict_load: true
  #resume_state:

steps:
  generator:
    training: generator

    optimizer: adamw
    optimizer_params:
      lr: !!float 3e-4  # Hyperparameters from the OpenAI paper.
      weight_decay: 0
      beta1: 0.9
      beta2: 0.9999

    injectors:
      diffusion:
        type: gaussian_diffusion
        in: hq
        generator: generator
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        sampler_type: uniform
        model_input_keys:
          low_res: lq
          corruption_factor: corruption_entropy
        out: loss
        out_key_vb_loss: vb_loss
        out_key_x_start: x_start_pred

    losses:
      diffusion_loss:
        type: direct
        weight: 1
        key: loss
      var_loss:
        type: direct
        weight: 1
        key: vb_loss

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 32  # This is massive: each batch of 256 is split into 32 sequentially processed chunks of 8 (see the accumulation note at the end of this file). Expect ~60 sec/step on an RTX 3090 at 90%+ memory utilization. I recommend using multiple GPUs to train this network.
  ema_rate: .999
  val_freq: 500
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [ 50000, 100000, 150000 ]
  lr_gamma: 0.5

eval:
  evaluators:
    # Validation for this network is a special FID computation that compares the full-resolution images from the
    # specified dataset against the same images after being downsampled, corrupted, and fed through the network.
    fid:
      type: sr_diffusion_fid
      for: generator  # Unused by this evaluator.
      batch_size: 8
      dataset:
        name: sr_fid_set
        mode: imagefolder
        rgb_n1_to_1: true
        paths:  # Populate with your image folder path(s).
        target_size: 256
        scale: 2
        fixed_corruptions: [ jpeg-broad, gaussian_blur ]
        random_corruptions: [ none ]
        num_corrupts_per_image: 1
        corruption_blur_scale: 1
        corrupt_before_downsize: false
        random_seed: 1234
      diffusion_params:
        type: gaussian_diffusion_inference
        generator: generator
        use_ema_model: true
        output_batch_size: 8
        output_scale_factor: 2
        respaced_timestep_spacing: 50
        undo_n1_to_1: true
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        model_input_keys:
          low_res: lq
          corruption_factor: corruption_entropy
        out: sample  # Unused by this evaluator.

logger:
  print_freq: 30
  save_checkpoint_freq: 500
  visuals: [x_start_pred, hq, lq]
  visual_debug_rate: 500
  reverse_n1_to_1: true
  reverse_imagenet_norm: false
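
#### notes
# The remaining comments are worked notes for readers of this config; the trainer ignores them.
# Each note states the assumptions it rests on, since the exact behavior lives in the underlying code.

# UNet geometry implied by `networks.generator.args`, assuming the stock guided-diffusion
# UNet semantics this architecture is ported from:
#   channel_mult: [1,1,2,2,4,4] with model_channels: 192 gives six resolution levels,
#     256 / 128 / 64 / 32 / 16 / 8 px with 192 / 192 / 384 / 384 / 768 / 768 channels.
#   attention_resolutions: [8,16] are downsample rates in that convention, so self-attention
#     runs on the 256/8 = 32px and 256/16 = 16px feature maps.
#   out_channels: 6 is 3 channels of predicted noise (model_mean_type: epsilon) plus
#     3 channels of variance interpolation coefficients (model_var_type: learned_range).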
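# Resuming a run: uncomment the keys in the `path:` section above and point them at your
# checkpoint files. The values below are hypothetical examples following the usual DLAS
# experiment layout; substitute your own paths:
#   pretrain_model_generator: ../experiments/train_unet_diffusion/models/5000_generator.pth
#   resume_state: ../experiments/train_unet_diffusion/training_state/5000.state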
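# Beta schedule arithmetic, assuming DLAS vendors OpenAI's get_named_beta_schedule()
# unchanged: the "linear" schedule scales its endpoints by 1000 / num_diffusion_timesteps,
# so with T = 4000:
#   scale      = 1000 / 4000 = 0.25
#   beta_start = 0.25 * 1e-4 = 2.5e-5
#   beta_end   = 0.25 * 0.02 = 5.0e-3
# i.e. the betas rise linearly from 2.5e-5 to 5e-3 across the 4000 timesteps.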
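# Why two losses: model_var_type: learned_range means the network also predicts its output
# variance, which the plain MSE (epsilon) term does not train. The vb_loss term supplies the
# variational-bound gradient for it, so the two `direct` losses together form the hybrid
# objective of Nichol & Dhariwal, "Improved Denoising Diffusion Probabilistic Models" (2021).
# Whether the vb term is rescaled internally is up to the gaussian_diffusion injector.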
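# Accumulation note: assuming mega_batch_factor performs chunked gradient accumulation (as
# ExtensibleTrainer does), each optimizer step still sees the full batch:
#   256 (batch_size) / 32 (mega_batch_factor) = 8 images per forward/backward pass,
# with gradients accumulated over the 32 chunks before each optimizer step. Raising
# mega_batch_factor lowers peak memory at the cost of per-step latency.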
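# Learning-rate decay implied by MultiStepLR with lr_gamma: 0.5 at gen_lr_steps
# [50000, 100000, 150000]:
#   steps      0 -  50k: 3.00e-4
#   steps    50k - 100k: 1.50e-4
#   steps   100k - 150k: 7.50e-5
#   steps   150k - 500k: 3.75e-5 (held until niter)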
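# Validation cost: respaced_timestep_spacing: 50 samples with 50 evenly respaced timesteps
# instead of the full 4000-step schedule (assuming improved-diffusion style respacing), so
# each FID batch of 8 costs 50 model evaluations rather than 4000, roughly an 80x speedup
# at some cost in sample fidelity.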