vqvae docs (unfinished)
This commit is contained in:
parent
acf1535b14
commit
197d19714f
|
@ -60,7 +60,6 @@ steps:
|
|||
training: generator
|
||||
|
||||
optimizer_params:
|
||||
# Optimizer params
|
||||
lr: !!float 2e-4
|
||||
weight_decay: 0
|
||||
beta1: 0.9
|
||||
|
|
22
recipes/vqvae2/README.md
Normal file
22
recipes/vqvae2/README.md
Normal file
|
@ -0,0 +1,22 @@
|
|||
# VQVAE2 in Pytorch
|
||||
|
||||
[VQVAE2](https://arxiv.org/pdf/1906.00446.pdf) is a generative autoencoder developed by Deepmind. It's unique innovation is
|
||||
discretizing the latent space into a fixed set of "codebook" vectors. This codebook
|
||||
can then be used in downstream tasks to rebuild images from the training set.
|
||||
|
||||
This model is in DLAS thanks to work [@rosinality](https://github.com/rosinality) did
|
||||
[converting the Deepmind model](https://github.com/rosinality/vq-vae-2-pytorch) to Pytorch.
|
||||
|
||||
# Training VQVAE2
|
||||
|
||||
VQVAE2 is trained in two steps:
|
||||
|
||||
## Training the autoencoder
|
||||
|
||||
This first step is to train the autoencoder itself. The config file `train_imgnet_vqvae_stage1.yml` provided shows how to do this
|
||||
for imagenet with the hyperparameters specified by deepmind. You'll need to bring your own imagenet folder for this.
|
||||
|
||||
## Training the PixelCNN encoder
|
||||
|
||||
The second step is to train the PixelCNN model which will create "codebook" vectors given an
|
||||
input image.
|
108
recipes/vqvae2/train_imgnet_vqvae_stage1.yml
Normal file
108
recipes/vqvae2/train_imgnet_vqvae_stage1.yml
Normal file
|
@ -0,0 +1,108 @@
|
|||
name: train_imgnet_vqvae_stage1
|
||||
model: extensibletrainer
|
||||
scale: 1
|
||||
gpu_ids: [0]
|
||||
start_step: -1
|
||||
checkpointing_enabled: true # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
|
||||
fp16: false
|
||||
wandb: false # <-- enable to log to wandb. tensorboard logging is always enabled.
|
||||
|
||||
datasets:
|
||||
train:
|
||||
name: imgnet
|
||||
n_workers: 8
|
||||
batch_size: 128
|
||||
mode: imagefolder
|
||||
paths: /content/imagenet # <-- Put your imagenet path here.
|
||||
target_size: 224
|
||||
scale: 1
|
||||
val:
|
||||
name: val
|
||||
mode: fullimage
|
||||
dataroot_GT: /content/imagenet_val
|
||||
min_tile_size: 32
|
||||
scale: 1
|
||||
force_multiple: 16
|
||||
|
||||
networks:
|
||||
generator:
|
||||
type: generator
|
||||
which_model_G: vqvae
|
||||
kwargs:
|
||||
# Hyperparameters specified from VQVAE2 paper.
|
||||
in_channel: 3
|
||||
channel: 128
|
||||
n_res_block: 2
|
||||
n_res_channel: 32
|
||||
codebook_dim: 64
|
||||
codebook_size: 512
|
||||
|
||||
#### path
|
||||
path:
|
||||
#pretrain_model_generator: <insert pretrained model path if desired>
|
||||
strict_load: true
|
||||
#resume_state: ../experiments/train_imgnet_vqvae_stage1/training_state/0.state # <-- Set this to resume from a previous training state.
|
||||
|
||||
steps:
|
||||
generator:
|
||||
training: generator
|
||||
|
||||
optimizer_params:
|
||||
lr: !!float 3e-4
|
||||
weight_decay: 0
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
|
||||
injectors:
|
||||
# Cool hack for more training diversity:
|
||||
# Make sure to change below references to `hq` to `cropped`.
|
||||
#random_crop:
|
||||
# train: true
|
||||
# type: random_crop
|
||||
# dim_in: 224
|
||||
# dim_out: 192
|
||||
# in: hq
|
||||
# out: cropped
|
||||
gen_inj_train:
|
||||
train: true
|
||||
type: generator
|
||||
generator: generator
|
||||
in: hq
|
||||
out: [gen, codebook_commitment_loss]
|
||||
losses:
|
||||
pixel_mse_loss:
|
||||
type: pix
|
||||
criterion: l2
|
||||
weight: 1
|
||||
fake: gen
|
||||
real: hq
|
||||
commitment_loss:
|
||||
type: direct
|
||||
weight: .25
|
||||
key: codebook_commitment_loss
|
||||
|
||||
train:
|
||||
niter: 500000
|
||||
warmup_iter: -1
|
||||
mega_batch_factor: 1 # <-- Gradient accumulation factor. If you are running OOM, increase this to [2,4,8].
|
||||
val_freq: 4000
|
||||
|
||||
# Optimizer/LR schedule was not specified in the paper. Using an arbitrary default one.
|
||||
default_lr_scheme: MultiStepLR
|
||||
gen_lr_steps: [50000, 100000, 140000, 180000]
|
||||
lr_gamma: 0.5
|
||||
|
||||
eval:
|
||||
output_state: gen
|
||||
injectors:
|
||||
gen_inj_eval:
|
||||
type: generator
|
||||
generator: generator
|
||||
in: hq
|
||||
out: [gen, codebook_commitment_loss]
|
||||
|
||||
logger:
|
||||
print_freq: 30
|
||||
save_checkpoint_freq: 2000
|
||||
visuals: [gen, hq, cropped]
|
||||
visual_debug_rate: 100
|
Loading…
Reference in New Issue
Block a user