diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py index 3c72332c..f03f69e2 100644 --- a/codes/models/audio/tts/unified_voice2.py +++ b/codes/models/audio/tts/unified_voice2.py @@ -423,11 +423,11 @@ class UnifiedVoice(nn.Module): if text_first: text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) if return_latent: - return mel_logits[:, :-sub] # Despite the name, these are not logits. + return mel_logits[:, :sub] # Despite the name, these are not logits. else: mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent) if return_latent: - return text_logits[:, :-sub] # Despite the name, these are not logits + return text_logits[:, :sub] # Despite the name, these are not logits if return_attentions: return mel_logits