#### general settings
name: train_imgset_rrdb_diffusion
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true
fp16: false
use_tb_logger: true
wandb: false

datasets:
  train:
    n_workers: 4
    batch_size: 32
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k   # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    num_corrupts_per_image: 0

networks:
  generator:
    type: generator
    which_model_G: rrdb_diffusion
    args:
      in_channels: 6
      out_channels: 6
      num_blocks: 10

#### path
path:
  #pretrain_model_generator:
  strict_load: true
  #resume_state: ../experiments/train_imgset_rrdb_diffusion/training_state/0.state   # <-- Set this to resume from a previous training state.

steps:
  generator:
    training: generator

    optimizer_params:
      lr: !!float 3e-4
      weight_decay: !!float 1e-2
      beta1: 0.9
      beta2: 0.9999

    injectors:
      # The "do it all" injector: produces a reverse prediction and computes losses on it.
      diffusion:
        type: gaussian_diffusion
        in: hq
        generator: generator
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 4000
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        sampler_type: uniform
        model_input_keys:
          low_res: lq
        out: loss

      # Injector for visualizing what your network is doing (runs every 500 steps).
      visual_debug:
        every: 500
        type: gaussian_diffusion_inference
        generator: generator
        output_shape: [8,3,128,128]   # Change "8" to your desired output batch size.
        beta_schedule:
          schedule_name: linear
          num_diffusion_timesteps: 500   # Increase (up to the training timestep count) for better quality; decrease for faster sampling.
        diffusion_args:
          model_mean_type: epsilon
          model_var_type: learned_range
          loss_type: mse
        model_input_keys:
          low_res: lq
        out: sample

    losses:
      diffusion_loss:
        type: direct
        weight: 1
        key: loss

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1   # <-- Gradient accumulation factor. If you run out of memory, increase this to 2, 4, or 8.
  val_freq: 4000

  # Default LR scheduler options
  default_lr_scheme: CosineAnnealingLR_Restart
  T_period: [ 200000, 200000 ]
  warmup: 0
  eta_min: !!float 1e-7
  restarts: [ 200000, 400000 ]
  restart_weights: [ .5, .5 ]

logger:
  print_freq: 30
  save_checkpoint_freq: 2000
  visuals: [sample, hq, lq]
  visual_debug_rate: 500
  reverse_n1_to_1: true
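
# Why in_channels/out_channels are 6 (explanatory note, based on the standard
# conditional-diffusion setup this config appears to use): the denoiser sees
# the 3-channel noisy image concatenated with the 3-channel upsampled low_res
# conditioning (3 + 3 = 6 in), and with model_var_type: learned_range it emits
# 3 channels of predicted epsilon plus 3 channels of variance interpolation
# weights (3 + 3 = 6 out).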
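
# Note on the two num_diffusion_timesteps values: training runs against the
# full 4000-step linear beta schedule, while visual_debug samples with a
# shorter 500-step schedule purely so previews finish quickly. Raising the
# inference value toward 4000 should improve preview quality at a
# proportionally higher sampling cost.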
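
# Worked example for mega_batch_factor (assuming batch_size: 32 as above):
# with mega_batch_factor: 4, each optimizer step processes 4 sequential
# micro-batches of 32 / 4 = 8 images and accumulates their gradients, so the
# effective batch size (and therefore the appropriate lr) is unchanged; only
# peak GPU memory drops, at the cost of slower wall-clock time per step.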
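
# Launching and resuming (a sketch; assumes the repository's standard
# train.py entry point and that this file is saved as
# train_imgset_rrdb_diffusion.yml in a location train.py can read):
#
#   python train.py -opt train_imgset_rrdb_diffusion.yml
#
# To resume an interrupted run, uncomment resume_state above and point it at
# the newest .state file under
# experiments/train_imgset_rrdb_diffusion/training_state/.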