diff --git a/codes/models/audio/music/flat_diffusion.py b/codes/models/audio/music/flat_diffusion.py index c143f6f1..e90196b7 100644 --- a/codes/models/audio/music/flat_diffusion.py +++ b/codes/models/audio/music/flat_diffusion.py @@ -282,8 +282,7 @@ class FlatDiffusion(nn.Module): else: code_emb, cond_emb, mel_pred = self.timestep_independent(codes, conditioning_input, x.shape[-1], prenet_latent, True) if prenet_latent is None: - unused_params.extend(list(self.latent_conditioner.parameters())) - + unused_params.extend(list(self.latent_conditioner.parameters()) + [self.latent_fade]) unused_params.append(self.unconditioned_embedding) blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb