name: train_imgnet_vqvae_stage1
model: extensibletrainer
scale: 1
gpu_ids: [0]
start_step: -1
checkpointing_enabled: true  # <-- Gradient checkpointing. Enable for huge GPU memory savings. Disable for distributed training.
fp16: false
wandb: false  # <-- enable to log to wandb. tensorboard logging is always enabled.

datasets:
  train:
    name: imgnet
    n_workers: 8
    batch_size: 128
    mode: imagefolder
    paths: /content/imagenet  # <-- Put your imagenet path here.
    target_size: 224
    scale: 1
  val:
    name: val
    mode: fullimage
    dataroot_GT: /content/imagenet_val
    min_tile_size: 32
    scale: 1
    force_multiple: 16

networks:
  generator:
    type: generator
    which_model_G: vqvae
    kwargs:
      # Hyperparameters as specified in the VQ-VAE-2 paper.
      in_channel: 3
      channel: 128
      n_res_block: 2
      n_res_channel: 32
      codebook_dim: 64
      codebook_size: 512
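      # codebook_dim above is the width of each codebook embedding; codebook_size is the
      # number of discrete codes in the VQ codebook.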

#### path
path:
  #pretrain_model_generator: <insert pretrained model path if desired>
  strict_load: true
  #resume_state: ../experiments/train_imgnet_vqvae_stage1/training_state/0.state  # <-- Set this to resume from a previous training state.

steps:
  generator:
    training: generator

    optimizer_params:
      lr: !!float 3e-4
      weight_decay: 0
      beta1: 0.9
      beta2: 0.99

    injectors:
      # Cool hack for more training diversity:
      # If enabled, change the `hq` references below to `cropped`.
      #random_crop:
      #  train: true
      #  type: random_crop
      #  dim_in: 224
      #  dim_out: 192
      #  in: hq
      #  out: cropped
      gen_inj_train:
        train: true
        type: generator
        generator: generator
        in: hq
        out: [gen, codebook_commitment_loss]
    losses:
      pixel_mse_loss:
        type: pix
        criterion: l2
        weight: 1
        fake: gen
        real: hq
      commitment_loss:
        type: direct
        weight: .25
        key: codebook_commitment_loss
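        # The 0.25 weight matches the commitment cost (beta = 0.25) used in the original
        # VQ-VAE paper, which this loss term presumably follows.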

train:
  niter: 500000
  warmup_iter: -1
  mega_batch_factor: 1  # <-- Gradient accumulation factor. If you run out of GPU memory, increase this to 2, 4, or 8.
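  # (Presumably each optimizer step is then split into that many gradient-accumulation chunks,
  # keeping the effective batch size the same while reducing peak memory.)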
  val_freq: 4000

  # Optimizer/LR schedule was not specified in the paper. Using an arbitrary default one.
  default_lr_scheme: MultiStepLR
  gen_lr_steps: [50000, 100000, 140000, 180000]
  lr_gamma: 0.5

eval:
  output_state: gen
  injectors:
    gen_inj_eval:
      type: generator
      generator: generator
      in: hq
      out: [gen, codebook_commitment_loss]

logger:
  print_freq: 30
  save_checkpoint_freq: 2000
  visuals: [gen, hq, cropped]
  visual_debug_rate: 100