diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py index d46c4549..57f8a960 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py @@ -264,12 +264,12 @@ class DiffusionTtsFlat(nn.Module): code_emb = precomputed_aligned_embeddings else: code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, x.shape[-1], True) + if is_latent(aligned_conditioning): + unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + else: + unused_params.extend(list(self.latent_converter.parameters())) unused_params.append(self.unconditioned_embedding) - if is_latent(aligned_conditioning): - unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) - else: - unused_params.extend(list(self.autoregressive_latent_converter.parameters())) time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)