forked from mrq/DL-Art-School
update vqvae to double codebook size for bottom quantizer
This commit is contained in:
parent
dac7d768fa
commit
b374dcdd46
|
@ -199,7 +199,7 @@ class VQVAE(nn.Module):
|
||||||
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
|
codebook_dim, codebook_dim, channel, n_res_block, n_res_channel, stride=2
|
||||||
)
|
)
|
||||||
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
|
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
|
||||||
self.quantize_b = Quantize(codebook_dim, codebook_size)
|
self.quantize_b = Quantize(codebook_dim, codebook_size*2)
|
||||||
self.upsample_t = UpsampleConv(
|
self.upsample_t = UpsampleConv(
|
||||||
codebook_dim, codebook_dim, 5, padding=2
|
codebook_dim, codebook_dim, 5, padding=2
|
||||||
)
|
)
|
||||||
|
|
|
@ -65,8 +65,11 @@ def sequential_checkpoint(fn, partitions, *args):
|
||||||
return fn(*args)
|
return fn(*args)
|
||||||
|
|
||||||
# A fancy alternative to if <flag> checkpoint() else <call>
|
# A fancy alternative to if <flag> checkpoint() else <call>
|
||||||
def possible_checkpoint(enabled, fn, *args):
|
def possible_checkpoint(opt_en, fn, *args):
|
||||||
opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
if loaded_options is None:
|
||||||
|
enabled = False
|
||||||
|
else:
|
||||||
|
enabled = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||||
if enabled and opt_en:
|
if enabled and opt_en:
|
||||||
return torch.utils.checkpoint.checkpoint(fn, *args)
|
return torch.utils.checkpoint.checkpoint(fn, *args)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user