gdi..
This commit is contained in:
parent
101a01f744
commit
380a5d5475
|
@ -71,7 +71,7 @@ class GptTtsHf(nn.Module):
|
||||||
conds = torch.stack(conds, dim=1)
|
conds = torch.stack(conds, dim=1)
|
||||||
conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
|
conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
|
||||||
|
|
||||||
emb = torch.cat([mel_emb, conds, text_emb], dim=1)
|
emb = torch.cat([text_emb, conds, mel_emb], dim=1)
|
||||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||||
if get_attns:
|
if get_attns:
|
||||||
return gpt_out.attentions
|
return gpt_out.attentions
|
||||||
|
|
Loading…
Reference in New Issue
Block a user