diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py index 98c9ab21..973c83ca 100644 --- a/codes/models/audio/music/transformer_diffusion8.py +++ b/codes/models/audio/music/transformer_diffusion8.py @@ -237,7 +237,7 @@ class TransformerDiffusionWithQuantizer(nn.Module): def get_debug_values(self, step, __): if self.quantizer.total_codes > 0: - return {'histogram_codes': self.quantizer.codes[:self.quantizer.total_codes], + return {'histogram_quant_codes': self.quantizer.codes[:self.quantizer.total_codes], 'gumbel_temperature': self.quantizer.quantizer.temperature} else: return {}