forked from mrq/DL-Art-School
update vqvae to double codebook size for bottom quantizer
This commit is contained in:
parent
dac7d768fa
commit
b374dcdd46
|
@ -199,7 +199,7 @@ class VQVAE(nn.Module):
|
|||
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
|
||||
)
|
||||
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
|
||||
self.quantize_b = Quantize(codebook_dim, codebook_size)
|
||||
self.quantize_b = Quantize(codebook_dim, codebook_size*2)
|
||||
self.upsample_t = UpsampleConv(
|
||||
codebook_dim, codebook_dim, 5, padding=2
|
||||
)
|
||||
|
|
|
@ -65,8 +65,11 @@ def sequential_checkpoint(fn, partitions, *args):
|
|||
return fn(*args)
|
||||
|
||||
# A fancy alternative to if <flag> checkpoint() else <call>
|
||||
def possible_checkpoint(enabled, fn, *args):
|
||||
opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||
def possible_checkpoint(opt_en, fn, *args):
|
||||
if loaded_options is None:
|
||||
enabled = False
|
||||
else:
|
||||
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||
if enabled and opt_en:
|
||||
return torch.utils.checkpoint.checkpoint(fn, *args)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue
Block a user