diff --git a/codes/models/audio/music/transformer_diffusion7.py b/codes/models/audio/music/transformer_diffusion7.py index b7e8190b..e1c66852 100644 --- a/codes/models/audio/music/transformer_diffusion7.py +++ b/codes/models/audio/music/transformer_diffusion7.py @@ -208,11 +208,6 @@ class TransformerDiffusionWithQuantizer(nn.Module): self.m2v.quantizer.temperature = self.m2v.min_gumbel_temperature del self.m2v.up - self.codes = torch.zeros((3000000,), dtype=torch.long) - self.internal_step = 0 - self.code_ind = 0 - self.total_codes = 0 - def update_for_step(self, step, *args): self.internal_step = step self.m2v.quantizer.temperature = max( @@ -226,7 +221,7 @@ class TransformerDiffusionWithQuantizer(nn.Module): conditioning_free=conditioning_free) def get_debug_values(self, step, __): - if self.total_codes > 0: + if self.m2v.total_codes > 0: return {'histogram_codes': self.m2v.codes[:self.m2v.total_codes]} else: return {}