From bf94744514e0628c6e6ba21eda76d1fd71fb1252 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 9 Mar 2023 22:47:46 +0000 Subject: [PATCH] I am going to scream --- codes/models/audio/tts/unified_voice2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py index 3c72332c..f03f69e2 100644 --- a/codes/models/audio/tts/unified_voice2.py +++ b/codes/models/audio/tts/unified_voice2.py @@ -423,11 +423,11 @@ class UnifiedVoice(nn.Module): if text_first: text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) if return_latent: - return mel_logits[:, :-sub] # Despite the name, these are not logits. + return mel_logits[:, :sub] # Despite the name, these are not logits. else: mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent) if return_latent: - return text_logits[:, :-sub] # Despite the name, these are not logits + return text_logits[:, :sub] # Despite the name, these are not logits if return_attentions: return mel_logits