From 197d19714f412d470bba22d54f08c79702ec92a3 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 7 Jan 2021 16:31:57 -0700
Subject: [PATCH] vqvae docs (unfinished)

---
 recipes/srflow/train_div2k_srflow.yml        |   1 -
 recipes/vqvae2/README.md                     |  22 ++++
 recipes/vqvae2/train_imgnet_vqvae_stage1.yml | 108 +++++++++++++++++++
 3 files changed, 130 insertions(+), 1 deletion(-)
 create mode 100644 recipes/vqvae2/README.md
 create mode 100644 recipes/vqvae2/train_imgnet_vqvae_stage1.yml

diff --git a/recipes/srflow/train_div2k_srflow.yml b/recipes/srflow/train_div2k_srflow.yml
index e95e2d3d..2b3683fe 100644
--- a/recipes/srflow/train_div2k_srflow.yml
+++ b/recipes/srflow/train_div2k_srflow.yml
@@ -60,7 +60,6 @@ steps:
    training: generator

    optimizer_params:
-     # Optimizer params
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9

diff --git a/recipes/vqvae2/README.md b/recipes/vqvae2/README.md
new file mode 100644
index 00000000..e016d9de
--- /dev/null
+++ b/recipes/vqvae2/README.md
@@ -0,0 +1,22 @@
# VQVAE2 in PyTorch

[VQVAE2](https://arxiv.org/pdf/1906.00446.pdf) is a generative autoencoder developed by DeepMind. Its key innovation is
discretizing the latent space into a fixed set of "codebook" vectors. This codebook
can then be used in downstream tasks to rebuild images from the training set.

This model is in DLAS thanks to the work [@rosinality](https://github.com/rosinality) did
[converting the DeepMind model](https://github.com/rosinality/vq-vae-2-pytorch) to PyTorch.

# Training VQVAE2

VQVAE2 is trained in two steps:

## Training the autoencoder

The first step is to train the autoencoder itself. The provided config file `train_imgnet_vqvae_stage1.yml` shows how to do this
for ImageNet with the hyperparameters specified by DeepMind. You'll need to bring your own ImageNet folder for this.

## Training the PixelCNN prior

The second step is to train the PixelCNN prior, which models the distribution of the "codebook" indices the
encoder assigns to input images; sampling from it and decoding yields new images.
\ No newline at end of file
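For intuition about what the stage-1 config below trains, here is a minimal PyTorch sketch of the vector-quantization bottleneck, the piece that produces the `codebook_commitment_loss` referenced in the recipe. This is illustrative only, not the actual implementation from rosinality's repo; the class name `VectorQuantizer` is an assumption, while the codebook shape comes from the config's `codebook_size`/`codebook_dim`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    """Illustrative VQ bottleneck: maps encoder outputs to their nearest
    codebook vectors and computes the VQVAE commitment loss."""
    def __init__(self, codebook_size=512, codebook_dim=64):
        super().__init__()
        self.codebook = nn.Embedding(codebook_size, codebook_dim)
        self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

    def forward(self, z):
        # z: (B, H, W, codebook_dim) encoder output.
        flat = z.reshape(-1, z.shape[-1])
        # Squared L2 distance from every latent vector to every codebook entry.
        dist = (flat.pow(2).sum(1, keepdim=True)
                - 2 * flat @ self.codebook.weight.t()
                + self.codebook.weight.pow(2).sum(1))
        idx = dist.argmin(dim=1)                    # discrete codes
        quantized = self.codebook(idx).view_as(z)   # codebook lookup
        # Commitment loss: pulls encoder outputs toward the codes they chose.
        # The recipe weights this term by 0.25 (see `commitment_loss` below).
        commitment_loss = F.mse_loss(z, quantized.detach())
        # Straight-through estimator: copy gradients through the argmin.
        quantized = z + (quantized - z).detach()
        return quantized, idx, commitment_loss

# Usage sketch:
vq = VectorQuantizer()
z = torch.randn(2, 32, 32, 64)   # fake encoder output
quantized, codes, commit = vq(z)
```

In the actual VQVAE2 implementation the codebook itself is updated with an exponential moving average rather than by gradient; only the commitment term (weighted 0.25 in the recipe) flows back into the encoder.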
diff --git a/recipes/vqvae2/train_imgnet_vqvae_stage1.yml b/recipes/vqvae2/train_imgnet_vqvae_stage1.yml
new file mode 100644
index 00000000..3a7f38d3
--- /dev/null
+++ b/recipes/vqvae2/train_imgnet_vqvae_stage1.yml
@@ -0,0 +1,108 @@
name: train_imgnet_vqvae_stage1
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
fp16: false
wandb: false # <-- Enable to log to wandb. Tensorboard logging is always enabled.

datasets:
  train:
    name: imgnet
    n_workers: 8
    batch_size: 128
    mode: imagefolder
    paths: /content/imagenet # <-- Put your imagenet path here.
    target_size: 224
    scale: 1
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/imagenet_val
    min_tile_size: 32
    scale: 1
    force_multiple: 16

networks:
  generator:
    type: generator
    which_model_G: vqvae
    kwargs:
      # Hyperparameters as specified in the VQVAE2 paper.
      in_channel: 3
      channel: 128
      n_res_block: 2
      n_res_channel: 32
      codebook_dim: 64
      codebook_size: 512

#### path
path:
  #pretrain_model_generator:
  strict_load: true
  #resume_state: ../experiments/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.

steps:
  generator:
    training: generator

    optimizer_params:
      lr: !!float 3e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # Cool hack for more training diversity: randomly crop the input.
      # If you enable it, make sure to change the references to `hq` below to `cropped`.
      #random_crop:
      #  train: true
      #  type: random_crop
      #  dim_in: 224
      #  dim_out: 192
      #  in: hq
      #  out: cropped
      gen_inj_train:
        train: true
        type: generator
        generator: generator
        in: hq
        out: [gen, codebook_commitment_loss]
    losses:
      pixel_mse_loss:
        type: pix
        criterion: l2
        weight: 1
        fake: gen
        real: hq
      commitment_loss:
        type: direct
        weight: .25
        key: codebook_commitment_loss

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to 2, 4 or 8.
  val_freq: 4000

  # The optimizer/LR schedule was not specified in the paper. Using an arbitrary default one.
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [50000, 100000, 140000, 180000]
  lr_gamma: 0.5

eval:
  output_state: gen
  injectors:
    gen_inj_eval:
      type: generator
      generator: generator
      in: hq
      out: [gen, codebook_commitment_loss]

logger:
  print_freq: 30
  save_checkpoint_freq: 2000
  visuals: [gen, hq, cropped]
  visual_debug_rate: 100
\ No newline at end of file
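The `train` and `steps` blocks above map onto standard PyTorch pieces. Here is a minimal sketch of the equivalent optimizer and LR schedule, assuming the `optimizer_params` configure an Adam optimizer (a stand-in `nn.Module` replaces the vqvae generator; ExtensibleTrainer wires all of this up internally from the yml, so this is illustrative only):

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, 3, padding=1)  # stand-in for the vqvae generator

# optimizer_params: lr 3e-4, beta1 0.9, beta2 0.99, no weight decay.
opt = torch.optim.Adam(model.parameters(), lr=3e-4,
                       weight_decay=0, betas=(0.9, 0.99))

# default_lr_scheme MultiStepLR: halve the LR (lr_gamma 0.5) at each
# milestone in gen_lr_steps over the 500k-iteration run.
sched = torch.optim.lr_scheduler.MultiStepLR(
    opt, milestones=[50000, 100000, 140000, 180000], gamma=0.5)

hq = torch.randn(4, 3, 224, 224)  # stand-in for an `hq` batch
for step in range(3):             # niter is 500000 in the recipe
    gen = model(hq)
    # losses: pixel_mse_loss (weight 1) plus 0.25 * codebook_commitment_loss;
    # the commitment term is omitted here since the stand-in model has none.
    loss = nn.functional.mse_loss(gen, hq)
    opt.zero_grad()
    loss.backward()
    opt.step()
    sched.step()
```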