diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py index 714f5a03..699585e0 100644 --- a/codes/models/audio/tts/unet_diffusion_tts_flat.py +++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py @@ -158,7 +158,7 @@ class DiffusionTtsFlat(nn.Module): # Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent. unused_params = [] if conditioning_free: - code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1) + code_emb = self.unconditioned_embedding.repeat(conditioning_input.shape[0], 1, 1) else: unused_params.append(self.unconditioned_embedding) cond_emb = self.contextual_embedder(conditioning_input)