diff --git a/codes/models/audio/music/flat_diffusion.py b/codes/models/audio/music/flat_diffusion.py index e90196b7..0d6a343a 100644 --- a/codes/models/audio/music/flat_diffusion.py +++ b/codes/models/audio/music/flat_diffusion.py @@ -271,6 +271,8 @@ class FlatDiffusion(nn.Module): assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you." unused_params = [] + if not return_code_pred: + unused_params.extend(list(self.mel_head.parameters())) if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))