Increase baseline codes recording across all dvae models
This commit is contained in:
parent
f84ccbdfb2
commit
0396a9d2ca
|
@ -112,7 +112,7 @@ class DiffusionDVAE(nn.Module):
|
|||
#self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True)
|
||||
self.quantizer = GumbelQuantizer(quantize_dim, quantize_dim, num_discrete_codes)
|
||||
# For recording codebook usage.
|
||||
self.codes = torch.zeros((131072,), dtype=torch.long)
|
||||
self.codes = torch.zeros((1228800,), dtype=torch.long)
|
||||
self.code_ind = 0
|
||||
self.internal_step = 0
|
||||
decoder_channels = [model_channels * channel_mult[s-1] for s in spectrogram_conditioning_levels]
|
||||
|
|
|
@ -147,7 +147,7 @@ class AttentionDVAE(nn.Module):
|
|||
self.normalization = normalization
|
||||
self.record_codes = record_codes
|
||||
if record_codes:
|
||||
self.codes = torch.zeros((32768,), dtype=torch.long)
|
||||
self.codes = torch.zeros((1228800,), dtype=torch.long)
|
||||
self.code_ind = 0
|
||||
self.internal_step = 0
|
||||
|
||||
|
|
|
@ -142,7 +142,7 @@ class DiscreteVAE(nn.Module):
|
|||
self.normalization = normalization
|
||||
self.record_codes = record_codes
|
||||
if record_codes:
|
||||
self.codes = torch.zeros((32768,), dtype=torch.long)
|
||||
self.codes = torch.zeros((1228800,), dtype=torch.long)
|
||||
self.code_ind = 0
|
||||
self.internal_step = 0
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user