diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 16e6b148..cd5c25cc 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -125,8 +125,11 @@ class DiscreteVAE(nn.Module): return images def get_debug_values(self, step, __): - # Report annealing schedule - return {'histogram_codes': self.codes} + if self.record_codes: + # Report annealing schedule + return {'histogram_codes': self.codes} + else: + return {} @torch.no_grad() @eval_decorator