I am going to scream

This commit is contained in:
mrq 2023-03-09 22:47:46 +00:00
parent 84c8196da5
commit bf94744514

View File

@ -423,11 +423,11 @@ class UnifiedVoice(nn.Module):
if text_first:
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
if return_latent:
return mel_logits[:, :-sub] # Despite the name, these are not logits.
return mel_logits[:, :sub] # Despite the name, these are not logits.
else:
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
if return_latent:
return text_logits[:, :-sub] # Despite the name, these are not logits
return text_logits[:, :sub] # Despite the name, these are not logits
if return_attentions:
return mel_logits