diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 1295a50b..888753c0 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -176,7 +176,7 @@ class DiscreteVAE(nn.Module): return images def get_debug_values(self, step, __): - if self.record_codes: + if self.record_codes and self.total_codes > 0: # Report annealing schedule return {'histogram_codes': self.codes[:self.total_codes]} else: