From b374dcdd462cf483afe4e5923ba1135c8fb28b71 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 23 Jan 2021 13:47:07 -0700 Subject: [PATCH] update vqvae to double codebook size for bottom quantizer --- codes/models/vqvae/vqvae_no_conv_transpose.py | 2 +- codes/utils/util.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/codes/models/vqvae/vqvae_no_conv_transpose.py b/codes/models/vqvae/vqvae_no_conv_transpose.py index 1d28a82d..7415d4ad 100644 --- a/codes/models/vqvae/vqvae_no_conv_transpose.py +++ b/codes/models/vqvae/vqvae_no_conv_transpose.py @@ -199,7 +199,7 @@ class VQVAE(nn.Module): codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2 ) self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) - self.quantize_b = Quantize(codebook_dim, codebook_size) + self.quantize_b = Quantize(codebook_dim, codebook_size*2) self.upsample_t = UpsampleConv( codebook_dim, codebook_dim, 5, padding=2 ) diff --git a/codes/utils/util.py b/codes/utils/util.py index cf42ccb9..5097499d 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -65,8 +65,11 @@ def sequential_checkpoint(fn, partitions, *args): return fn(*args) # A fancy alternative to if checkpoint() else -def possible_checkpoint(enabled, fn, *args): - opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True +def possible_checkpoint(opt_en, fn, *args): + if loaded_options is None: + enabled = False + else: + enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True if enabled and opt_en: return torch.utils.checkpoint.checkpoint(fn, *args) else: