# This is a config file that trains ESRGAN using the dynamics spelled out in the paper with no modifications.
# This has not been trained to completion in some time. I make no guarantees that it will work well.
name: train_div2k_esrgan_reference
model: extensibletrainer
scale: 4
gpu_ids: [0]
fp16: false
start_step: -1
checkpointing_enabled: true  # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
use_tb_logger: true
wandb: false

datasets:
  train:
    n_workers: 2
    batch_size: 16
    name: div2k
    mode: single_image_extensible
    paths: /content/div2k  # <-- Put your path here.
    target_size: 128
    force_multiple: 1
    scale: 4
    strict: false
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/set14
    scale: 4

networks:
  generator:
    type: generator
    which_model_G: RRDBNet
    in_nc: 3
    out_nc: 3
    initial_stride: 1
    nf: 64
    nb: 23
    scale: 4
    blocks_per_checkpoint: 3
  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128
    scale: 2
    nf: 64
    in_nc: 3

#### path
path:
  #pretrain_model_generator:
  strict_load: true
  #resume_state: ../experiments/train_div2k_esrgan/training_state/0.state  # <-- Set this to resume from a previous training state.

steps:
  feature_discriminator:
    training: feature_discriminator
    after: 100000  # Discriminator doesn't "turn on" until step 100k to allow the generator to anneal on PSNR loss.

    # Optimizer params
    lr: !!float 2e-4
    weight_decay: 0
    beta1: 0.9
    beta2: 0.99

    injectors:
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: lq
        out: dgen

    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: ragan
        weight: 1
        real: hq
        fake: dgen

  generator:
    training: generator

    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      gen_inj:
        type: generator
        generator: generator
        in: lq
        out: gen

    losses:
      pix:
        type: pix
        weight: .05
        criterion: l1
        real: hq
        fake: gen
      feature:
        type: feature
        after: 80000  # Perceptual/"feature" loss doesn't turn on until step 80k.
        which_model_F: vgg
        criterion: l1
        weight: 1
        real: hq
        fake: gen
      gan_gen_img:
        after: 100000
        type: generator_gan
        gan_type: ragan
        weight: .02
        discriminator: feature_discriminator
        fake: gen
        real: hq

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1
  val_freq: 2000

  # LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [140000, 180000, 200000, 240000]  # LR is halved at these steps. Don't do it until the GAN is online.
  lr_gamma: 0.5

eval:
  output_state: gen

logger:
  print_freq: 30
  save_checkpoint_freq: 1000
  visuals: [gen, hq, lq]
  visual_debug_rate: 100
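
# ----------------------------------------------------------------------------
# Usage note (a sketch, not part of the original recipe): configs like this are
# normally fed to the ExtensibleTrainer training script via its -opt flag, in
# the MMSR/BasicSR style this repo follows, e.g.
#
#   python train.py -opt train_div2k_esrgan_reference.yml
#
# The exact script location and config path may differ in your checkout, so
# treat the command above as an illustration rather than a verified invocation.
# ----------------------------------------------------------------------------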