diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py index 1e6f3611..315f1e27 100644 --- a/codes/models/gpt_voice/gpt_tts_hf.py +++ b/codes/models/gpt_voice/gpt_tts_hf.py @@ -71,7 +71,7 @@ class GptTtsHf(nn.Module): conds = torch.stack(conds, dim=1) conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device)) - emb = torch.cat([mel_emb, conds, text_emb], dim=1) + emb = torch.cat([text_emb, conds, mel_emb], dim=1) gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) if get_attns: return gpt_out.attentions