update vqvae to double codebook size for bottom quantizer

This commit is contained in:
James Betker 2021-01-23 13:47:07 -07:00
parent dac7d768fa
commit b374dcdd46
2 changed files with 6 additions and 3 deletions

View File

@ -199,7 +199,7 @@ class VQVAE(nn.Module):
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
)
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size)
self.quantize_b = Quantize(codebook_dim, codebook_size*2)
self.upsample_t = UpsampleConv(
codebook_dim, codebook_dim, 5, padding=2
)

View File

@ -65,8 +65,11 @@ def sequential_checkpoint(fn, partitions, *args):
return fn(*args)
# A fancy alternative to if <flag> checkpoint() else <call>
def possible_checkpoint(enabled, fn, *args):
opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
def possible_checkpoint(opt_en, fn, *args):
if loaded_options is None:
enabled = False
else:
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
if enabled and opt_en:
return torch.utils.checkpoint.checkpoint(fn, *args)
else: