From 0396a9d2ca2b25856554346a955638431df88217 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 30 Sep 2021 08:09:07 -0600
Subject: [PATCH] Increase baseline codes recording across all dvae models

---
 codes/models/diffusion/diffusion_dvae.py  | 2 +-
 codes/models/gpt_voice/attention_dvae.py  | 2 +-
 codes/models/gpt_voice/lucidrains_dvae.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/codes/models/diffusion/diffusion_dvae.py b/codes/models/diffusion/diffusion_dvae.py
index 1f04de0e..305f7879 100644
--- a/codes/models/diffusion/diffusion_dvae.py
+++ b/codes/models/diffusion/diffusion_dvae.py
@@ -112,7 +112,7 @@ class DiffusionDVAE(nn.Module):
         #self.quantizer = Quantize(quantize_dim, num_discrete_codes, balancing_heuristic=True)
         self.quantizer = GumbelQuantizer(quantize_dim, quantize_dim, num_discrete_codes)
         # For recording codebook usage.
-        self.codes = torch.zeros((131072,), dtype=torch.long)
+        self.codes = torch.zeros((1228800,), dtype=torch.long)
         self.code_ind = 0
         self.internal_step = 0
         decoder_channels = [model_channels * channel_mult[s-1] for s in spectrogram_conditioning_levels]
diff --git a/codes/models/gpt_voice/attention_dvae.py b/codes/models/gpt_voice/attention_dvae.py
index b3ce8d5b..05be0d33 100644
--- a/codes/models/gpt_voice/attention_dvae.py
+++ b/codes/models/gpt_voice/attention_dvae.py
@@ -147,7 +147,7 @@ class AttentionDVAE(nn.Module):
         self.normalization = normalization
         self.record_codes = record_codes
         if record_codes:
-            self.codes = torch.zeros((32768,), dtype=torch.long)
+            self.codes = torch.zeros((1228800,), dtype=torch.long)
             self.code_ind = 0
             self.internal_step = 0
 
diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py
index 31583d7e..812c45a6 100644
--- a/codes/models/gpt_voice/lucidrains_dvae.py
+++ b/codes/models/gpt_voice/lucidrains_dvae.py
@@ -142,7 +142,7 @@ class DiscreteVAE(nn.Module):
         self.normalization = normalization
         self.record_codes = record_codes
         if record_codes:
-            self.codes = torch.zeros((32768,), dtype=torch.long)
+            self.codes = torch.zeros((1228800,), dtype=torch.long)
             self.code_ind = 0
             self.internal_step = 0
 