diff --git a/codes/models/audio/music/transformer_diffusion_with_point_conditioning.py b/codes/models/audio/music/transformer_diffusion_with_point_conditioning.py index 68b09b4a..444163fd 100644 --- a/codes/models/audio/music/transformer_diffusion_with_point_conditioning.py +++ b/codes/models/audio/music/transformer_diffusion_with_point_conditioning.py @@ -233,7 +233,7 @@ class TransformerDiffusionWithConditioningEncoder(nn.Module): def get_grad_norm_parameter_groups(self): groups = self.diff.get_grad_norm_parameter_groups() groups['conditioning_encoder'] = list(self.conditioning_encoder.parameters()) - return + return groups def before_step(self, step): scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \