diff --git a/recipes/glean/README.md b/recipes/glean/README.md
new file mode 100644
index 00000000..09799fb3
--- /dev/null
+++ b/recipes/glean/README.md
@@ -0,0 +1,25 @@
+# GLEAN
+
+DLAS contains an attempt at implementing [GLEAN](https://ckkelvinchan.github.io/papers/glean.pdf), which performs image
+super-resolution guided by pretrained StyleGAN networks. Since the official code for this paper has not been released,
+it was implemented entirely from what information I could glean from the paper.
+
+## Training
+
+GLEAN requires a pre-trained StyleGAN network to operate. DLAS currently only supports StyleGAN2 models, so
+you will need to use one of those. The pre-eminent StyleGAN2 model is the one trained on FFHQ faces, so I will use
+that in this training example.
+
+1. Download the FFHQ model from [NVIDIA's Drive](https://drive.google.com/drive/folders/1yanUI9m4b4PWzR0eurKNq6JR1Bbfbh6L).
+   This repo currently only supports the "-f.pkl" files without further modification, so choose one of those.
+1. Download and extract the [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset).
+1. Convert the TF model to a PyTorch one supported by DLAS:
+
+   `python scripts/stylegan2/convert_weights_rosinality.py stylegan2-ffhq-config-f.pkl`
+
+1. The above conversion script outputs a *.pth file as well as a JPG preview of the model's outputs. Check the JPG to
+   ensure the StyleGAN is performing as expected. If so, copy the *.pth file to your experiments/ directory within DLAS.
+1. Edit the provided trainer configuration. Find comments starting with '<--' and make changes as indicated.
+1. Train the model:
+
+   `python train.py -opt train_ffhq_glean.yml`
\ No newline at end of file
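The README's step 4 verifies the conversion via the JPG preview. As an additional quick check, the snippet below (an illustrative sketch, not part of DLAS) confirms that the converted checkpoint at least deserializes. The path matches the `pretrained_stylegan` entry in the config below; the assumption that the file is a plain `torch.save`d dict is mine, so treat the printed keys as eyeball material only:

```python
import torch

# Path taken from the `pretrained_stylegan` setting in train_ffhq_glean.yml;
# adjust it if you copied the converted weights elsewhere.
state = torch.load('../experiments/stylegan2-ffhq-config-f.pth', map_location='cpu')

print(type(state))
if isinstance(state, dict):
    # Print a handful of keys so you can eyeball that StyleGAN2 weights are present.
    print(f'{len(state)} entries; first few keys:')
    for key in list(state)[:5]:
        print(' ', key)
```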
diff --git a/recipes/glean/train_ffhq_glean.yml b/recipes/glean/train_ffhq_glean.yml
new file mode 100644
index 00000000..af4a44ea
--- /dev/null
+++ b/recipes/glean/train_ffhq_glean.yml
@@ -0,0 +1,133 @@
+#### general settings
+name: train_faces_glean
+use_tb_logger: true
+model: extensibletrainer
+scale: 8
+gpu_ids: [0]
+fp16: false
+start_step: -1
+checkpointing_enabled: true
+wandb: false
+
+datasets:
+  train:
+    n_workers: 4
+    batch_size: 32
+    name: ffhq
+    mode: imagefolder
+    paths: /content/flickr_faces_hq # <-- Put your data path here.
+    target_size: 256
+    scale: 8
+
+networks:
+  generator:
+    type: generator
+    which_model_G: glean
+    nf: 64
+    pretrained_stylegan: ../experiments/stylegan2-ffhq-config-f.pth
+
+  feature_discriminator:
+    type: discriminator
+    which_model_D: discriminator_vgg_128_gn
+    extra_conv: true
+    scale: 2
+    nf: 64
+    in_nc: 3
+    image_size: 256
+
+#### path
+path:
+  #pretrain_model_generator:
+  strict_load: true
+  #resume_state: ../experiments/train_faces_glean/training_state/0.state # <-- Uncomment to continue training from a checkpoint.
+
+steps:
+  feature_discriminator:
+    training: feature_discriminator
+    after: 10000 # Delays starting discriminator training until step 10k.
+
+    # Optimizer params
+    lr: !!float 2e-4
+    weight_decay: 0
+    beta1: 0.9
+    beta2: 0.99
+
+    injectors:
+      dgen_inj:
+        type: generator
+        generator: generator
+        grad: false
+        in: lq
+        out: dgen
+
+    losses:
+      gan_disc_img:
+        type: discriminator_gan
+        gan_type: gan
+        weight: 1
+        noise: .004
+        gradient_penalty: true
+        real: hq
+        fake: dgen
+
+  generator:
+    training: generator
+
+    optimizer_params:
+      # Optimizer params
+      lr: !!float 2e-4
+      weight_decay: 0
+      beta1: 0.9
+      beta2: 0.99
+
+    injectors:
+      gen_inj:
+        type: generator
+        generator: generator
+        in: lq
+        out: gen
+
+    losses:
+      pix:
+        type: pix
+        weight: .05
+        criterion: l1
+        real: hq
+        fake: gen
+      feature:
+        type: feature
+        after: 5000
+        which_model_F: vgg
+        criterion: l1
+        weight: 1
+        real: hq
+        fake: gen
+      gan_gen_img:
+        after: 10000
+        type: generator_gan
+        gan_type: gan
+        weight: .02
+        noise: .004
+        discriminator: feature_discriminator
+        fake: gen
+        real: hq
+
+train:
+  niter: 500000
+  warmup_iter: -1
+  mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to 2, 4, or 8.
+  val_freq: 4000 # No validation set is configured in this file, so this value is unused.
+
+  # Default LR scheduler options
+  default_lr_scheme: MultiStepLR
+  gen_lr_steps: [40000, 80000, 100000, 120000]
+  lr_gamma: 0.5
+
+eval:
+  output_state: gen
+
+logger:
+  print_freq: 30
+  save_checkpoint_freq: 2000
+  visuals: [gen, hq, lq]
+  visual_debug_rate: 100
\ No newline at end of file
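A note on `mega_batch_factor` in the `train:` section above, which the config describes as a gradient accumulation factor. As a rough mental model (a generic PyTorch sketch under my own assumptions, not DLAS's actual implementation; all names are illustrative), a factor of N splits each batch into N micro-batches and accumulates gradients across them, cutting peak activation memory roughly N-fold while producing essentially the same parameter update:

```python
import torch

def accumulated_step(model, optimizer, criterion, lq, hq, factor=2):
    """One optimizer step over a batch split into `factor` micro-batches."""
    optimizer.zero_grad()
    for lq_chunk, hq_chunk in zip(torch.chunk(lq, factor), torch.chunk(hq, factor)):
        # Rescale so the summed micro-batch losses match the full-batch mean.
        loss = criterion(model(lq_chunk), hq_chunk) / factor
        loss.backward()  # gradients accumulate in param.grad across chunks
    optimizer.step()     # one update, at roughly 1/factor the peak activation memory
```

This is why raising the factor helps with OOM errors: only one micro-batch's activations are alive at a time, at the cost of extra forward/backward passes per optimizer step.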