This commit is contained in:
James Betker 2021-12-03 08:53:09 -07:00
parent 101a01f744
commit 380a5d5475

View File

@ -71,7 +71,7 @@ class GptTtsHf(nn.Module):
conds = torch.stack(conds, dim=1) conds = torch.stack(conds, dim=1)
conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device)) conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
emb = torch.cat([mel_emb, conds, text_emb], dim=1) emb = torch.cat([text_emb, conds, mel_emb], dim=1)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns: if get_attns:
return gpt_out.attentions return gpt_out.attentions