vqvae docs (unfinished)

James Betker 2021-01-07 16:31:57 -07:00
parent acf1535b14
commit 197d19714f
3 changed files with 130 additions and 1 deletion


@@ -60,7 +60,6 @@ steps:
    training: generator
    optimizer_params:
      # Optimizer params
      lr: !!float 2e-4
      weight_decay: 0
      beta1: 0.9

recipes/vqvae2/README.md Normal file (22 lines)

@@ -0,0 +1,22 @@
# VQVAE2 in PyTorch

[VQVAE2](https://arxiv.org/pdf/1906.00446.pdf) is a generative autoencoder developed by DeepMind. Its unique innovation is
discretizing the latent space into a fixed set of "codebook" vectors. The discrete codes this produces
can then be used by downstream models to rebuild images like those found in the training set.
This model is in DLAS thanks to the work [@rosinality](https://github.com/rosinality) did
[converting the DeepMind model](https://github.com/rosinality/vq-vae-2-pytorch) to PyTorch.
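
For intuition, here is a minimal, self-contained sketch of the quantization step (illustrative only, not the DLAS implementation; the codebook here is random rather than learned):

```python
import torch

# Shapes mirror the config in this recipe: 512 codes, each 64-dimensional.
codebook_size, codebook_dim = 512, 64
codebook = torch.randn(codebook_size, codebook_dim)  # learned jointly with the model in practice

def quantize(z):
    """Snap each latent vector in z (N x codebook_dim) to its nearest codebook entry."""
    # Squared L2 distance between every latent and every codebook vector.
    dist = (z.pow(2).sum(1, keepdim=True)
            - 2 * z @ codebook.t()
            + codebook.pow(2).sum(1))
    idx = dist.argmin(1)                  # discrete code index per latent
    z_q = codebook[idx]                   # quantized latents fed to the decoder
    # Straight-through estimator: gradients flow to the encoder as if
    # quantization were the identity function.
    z_q = z + (z_q - z).detach()
    return z_q, idx

z_q, idx = quantize(torch.randn(16, codebook_dim))
```
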
# Training VQVAE2
VQVAE2 is trained in two steps:
## Training the autoencoder
The first step is to train the autoencoder itself. The provided config file `train_imgnet_vqvae_stage1.yml` shows how to do this
for ImageNet with the hyperparameters specified by DeepMind. You'll need to bring your own ImageNet folder for this.
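Assuming DLAS's usual MMSR-style launcher (an assumption; check your checkout), training is started by pointing `train.py` at this config, e.g. `python train.py -opt <path to train_imgnet_vqvae_stage1.yml>`.
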
## Training the PixelCNN prior
The second step is to train the PixelCNN model, which learns a prior over the "codebook" codes the
autoencoder assigns to each input image. Sampling from this prior and decoding the result is what
lets the model generate new images.

recipes/vqvae2/train_imgnet_vqvae_stage1.yml Normal file (108 lines)

@@ -0,0 +1,108 @@
name: train_imgnet_vqvae_stage1
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
fp16: false
wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled.
datasets:
  train:
    name: imgnet
    n_workers: 8
    batch_size: 128
    mode: imagefolder
    paths: /content/imagenet # <-- Put your imagenet path here.
    target_size: 224
    scale: 1
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/imagenet_val
    min_tile_size: 32
    scale: 1
    force_multiple: 16
networks:
  generator:
    type: generator
    which_model_G: vqvae
    kwargs:
      # Hyperparameters specified from VQVAE2 paper.
      in_channel: 3
      channel: 128
      n_res_block: 2
      n_res_channel: 32
      codebook_dim: 64
      codebook_size: 512
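      # codebook_size is the number of discrete codes; codebook_dim is the length of each codebook vector.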
#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
steps:
  generator:
    training: generator
    optimizer_params:
      lr: !!float 3e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99
    injectors:
      # Cool hack for more training diversity:
      # Make sure to change below references to `hq` to `cropped`.
      #random_crop:
      #  train: true
      #  type: random_crop
      #  dim_in: 224
      #  dim_out: 192
      #  in: hq
      #  out: cropped
      gen_inj_train:
        train: true
        type: generator
        generator: generator
        in: hq
        out: [gen, codebook_commitment_loss]
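    # The total loss is the pixel reconstruction term plus the codebook commitment term emitted by the
    # generator injector above; the .25 weight below follows the commitment coefficient (beta) from the VQ-VAE paper.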
    losses:
      pixel_mse_loss:
        type: pix
        criterion: l2
        weight: 1
        fake: gen
        real: hq
      commitment_loss:
        type: direct
        weight: .25
        key: codebook_commitment_loss
train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
  val_freq: 4000
  # Optimizer/LR schedule was not specified in the paper. Using an arbitrary default one.
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [50000, 100000, 140000, 180000]
  lr_gamma: 0.5
eval:
  output_state: gen
  injectors:
    gen_inj_eval:
      type: generator
      generator: generator
      in: hq
      out: [gen, codebook_commitment_loss]
logger:
  print_freq: 30
  save_checkpoint_freq: 2000
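  # Note: `cropped` is only populated if the random_crop injector above is enabled.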
  visuals: [gen, hq, cropped]
  visual_debug_rate: 100