diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index 77e716ec..58a11870 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -540,6 +540,7 @@ class TransformerDiffusionWithCheaterLatent(nn.Module): def get_debug_values(self, step, __): self.internal_step = step + return {} def get_grad_norm_parameter_groups(self): attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers]))