# General settings.
# `name` also appears as the experiment directory in the commented-out
# `resume_state` path under `path:`.
name: train_faces_glean
use_tb_logger: true
model: extensibletrainer
# Same value as `datasets.train.scale`; presumably the two must agree - confirm.
scale: 8
gpu_ids:
  - 0
fp16: false
start_step: -1
checkpointing_enabled: true
wandb: false

# Training data. Only a `train` split is defined in this config.
datasets:
  train:
    n_workers: 4
    batch_size: 32
    name: ffhq
    mode: imagefolder
    # <-- Put your data path here.
    paths: '/content/flickr_faces_hq'
    # Matches `image_size` of the feature discriminator under `networks`.
    target_size: 256
    # Matches the top-level `scale`.
    scale: 8

# Network definitions. The keys declared here (`generator`, `feature_discriminator`)
# are the names referenced by the `training`, `generator` and `discriminator`
# fields inside `steps`.
networks:
  generator:
    type: generator
    which_model_G: glean
    nf: 64
    # Pretrained StyleGAN2 weights go here. A pretrained GLEAN checkpoint is
    # configured separately via `pretrain_model_generator` under `path:`.
    latent_bank_pretrained_weights: '../experiments/stylegan2-ffhq-config-f.pth'

  feature_discriminator:
    type: discriminator
    which_model_D: discriminator_vgg_128_gn
    extra_conv: true
    scale: 2
    nf: 64
    in_nc: 3
    # Same value as `datasets.train.target_size`.
    image_size: 256

# Paths: optional pretrained / resume checkpoints. Both path keys below are
# commented out; remove the leading `#` to enable one.
path:
  # Pretrained *GLEAN* model, if desired. (The pretrained StyleGAN weights are
  # set via `latent_bank_pretrained_weights` under `networks.generator`.)
  #pretrain_model_generator: <insert pretrained GLEAN model path>
  strict_load: true
  # Uncomment to continue training at a checkpoint.
  #resume_state: ../experiments/train_faces_glean/training_state/0.state

# Training steps. Each step's `training` value matches a key under `networks`,
# and each step carries its own optimizer settings, injectors and losses.
# The discriminator step is listed before the generator step; keep that order.
steps:
  feature_discriminator:
    training: feature_discriminator
    # Delays starting discriminator training until step 10k. The generator's
    # `gan_gen_img` loss uses the same `after` value.
    after: 10000

    # Optimizer params. Nested under `optimizer_params`, the same layout the
    # generator step uses, so both steps expose one consistent schema.
    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # Reads `lq` and writes `dgen` (see `in` / `out`); `dgen` is what the
      # loss below names as its `fake` input.
      # NOTE(review): `grad: false` presumably runs the generator without
      # gradient tracking for the discriminator update - confirm.
      dgen_inj:
        type: generator
        generator: generator
        grad: false
        in: lq
        out: dgen

    losses:
      gan_disc_img:
        type: discriminator_gan
        gan_type: gan
        weight: 1
        noise: 0.004
        gradient_penalty: true
        real: hq
        fake: dgen

  generator:
    training: generator

    # Optimizer params.
    optimizer_params:
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # Reads `lq` and writes `gen`; every loss below uses `gen` as `fake`,
      # and `eval.output_state` / `logger.visuals` refer to it as well.
      gen_inj:
        type: generator
        generator: generator
        in: lq
        out: gen

    losses:
      # Pixel loss: no `after` key, weight 1.
      pix:
        type: pix
        weight: 1
        criterion: l2
        real: hq
        fake: gen
      # Feature loss: carries `after: 5000`, weight 0.01.
      feature:
        type: feature
        after: 5000
        which_model_F: vgg
        criterion: l2
        weight: 0.01
        real: hq
        fake: gen
      # GAN loss against `feature_discriminator`: `after: 10000`, the same
      # step at which the discriminator step above begins.
      gan_gen_img:
        after: 10000
        type: generator_gan
        gan_type: gan
        weight: 0.01
        noise: 0.004
        discriminator: feature_discriminator
        fake: gen
        real: hq

# Global training schedule.
train:
  niter: 500000
  warmup_iter: -1
  # <-- Gradient accumulation factor. If you are running out of memory (OOM),
  # increase this to 2, 4 or 8.
  mega_batch_factor: 1
  # This config defines no validation dataset, so this value is irrelevant.
  val_freq: 4000

  # Default LR scheduler options
  default_lr_scheme: MultiStepLR
  gen_lr_steps:
    - 40000
    - 80000
    - 100000
    - 120000
  lr_gamma: 0.5

# Evaluation settings. `gen` is the name the generator step's `gen_inj`
# injector writes to (its `out` key).
eval:
  output_state: gen

# Logging / checkpoint cadence.
logger:
  print_freq: 30
  save_checkpoint_freq: 2000
  # `gen` is the generator injector's output; `hq` and `lq` are the names
  # used as `real` and `in` inside `steps`.
  visuals:
    - gen
    - hq
    - lq
  visual_debug_rate: 100