---
# Training configuration "train_div2k_esrgan" for the `extensibletrainer`
# model: an RRDBNet generator (scale 4) plus a `discriminator_vgg_128_gn`
# discriminator, each trained by its own entry under `steps`.
name: train_div2k_esrgan
model: extensibletrainer
scale: 4
gpu_ids: [0]
fp16: false
start_step: -1
# Gradient checkpointing. Enable for huge GPU memory savings.
# Disable for distributed training.
checkpointing_enabled: true
use_tb_logger: true
wandb: false

datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k  # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14
    scale: 4

networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3
  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128_gn
    scale: 2
    nf: 64
    in_nc: 3
    image_size: 96

#### path
path:
  #pretrain_model_generator:
  strict_load: true
  # Set this to resume from a previous training state:
  #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state

steps:
  feature_discriminator:
    training: feature_discriminator
    # Discriminator doesn't "turn on" until step 100k to allow the generator
    # to anneal on PSNR loss.
    after: 100000

    # Optimizer params
    # NOTE(review): these keys sit directly on the step, whereas the
    # `generator` step nests the same keys under `optimizer_params` --
    # confirm which layout the trainer expects.
    # The explicit !!float tag is kept because some YAML 1.1 loaders
    # (e.g. PyYAML) load a dot-less `2e-4` as a string.
    lr: !!float 2e-4
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99

    injectors:
      # "image_patch" injectors support the translational loss below.
      # You can remove them if you remove that loss.
      plq:
        type: image_patch
        patch_size: 24
        in: lq
        out: plq
      phq:
        type: image_patch
        patch_size: 96  # 24 (LQ patch) * 4 (scale)
        in: hq
        out: phq
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: plq
        out: dgen

    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: gan
        weight: 1
        #min_loss: .4
        noise: .004
        gradient_penalty: true
        real: phq
        fake: dgen

  generator:
    training: generator

    optimizer_params:
      # See the !!float note on the discriminator step's `lr`.
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      pglq:
        type: image_patch
        patch_size: 24
        in: lq
        out: pglq
      pghq:
        type: image_patch
        patch_size: 96  # 24 (LQ patch) * 4 (scale)
        in: hq
        out: pghq
      gen_inj:
        type: generator
        generator: generator
        in: pglq
        out: gen

    losses:
      pix:
        type: pix
        weight: .05
        criterion: l1
        real: pghq
        fake: gen
      feature:
        type: feature
        # Perceptual/"feature" loss doesn't turn on until step 80k.
        after: 80000
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: pghq
        fake: gen
      gan_gen_img:
        # Same step count at which the discriminator step turns on.
        after: 100000
        type: generator_gan
        gan_type: gan
        weight: .02
        noise: .004
        discriminator: feature_discriminator
        fake: gen
        real: pghq
      # Translational loss <- not present in the original ESRGAN paper, but
      # I find it reduces artifacts from the GAN.
      # Feel free to remove. The network will still train well.
      translational:
        type: translational
        after: 80000
        weight: 2
        criterion: l1
        generator: generator
        generator_output_index: 0
        detach_fake: false
        patch_size: 96
        overlap: 64
        real: gen
        fake: ['pglq']

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000

  # LR scheduler options
  default_lr_scheme: MultiStepLR
  # LR is halved (lr_gamma: 0.5) at these steps.
  # Don't do it until the GAN is online.
  gen_lr_steps: [140000, 180000, 200000, 240000]
  lr_gamma: 0.5

eval:
  output_state: gen

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [gen, hq, pglq, pghq]
  visual_debug_rate: 100