From f3db41f125914da7e7ee01872f529b7ed187b7d8 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 18 Nov 2021 00:34:37 -0700 Subject: [PATCH] Fix code logging --- codes/models/gpt_voice/lucidrains_dvae.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index d9c9eac7..eedaa9ce 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -147,6 +147,7 @@ class DiscreteVAE(nn.Module): if record_codes: self.codes = torch.zeros((1228800,), dtype=torch.long) self.code_ind = 0 + self.total_codes = 0 self.internal_step = 0 def norm(self, images): @@ -163,7 +164,7 @@ class DiscreteVAE(nn.Module): def get_debug_values(self, step, __): if self.record_codes: # Report annealing schedule - return {'histogram_codes': self.codes} + return {'histogram_codes': self.codes[:self.total_codes]} else: return {} @@ -243,6 +244,7 @@ class DiscreteVAE(nn.Module): self.code_ind = self.code_ind + l if self.code_ind >= self.codes.shape[0]: self.code_ind = 0 + self.total_codes += 1 self.internal_step += 1