From b8bec22f1a1947e646928e1064b1a17a7a9804ca Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 15 Aug 2021 20:53:42 -0600 Subject: [PATCH] Fix gpt_asr inference bug --- codes/models/gpt_voice/gpt_asr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codes/models/gpt_voice/gpt_asr.py b/codes/models/gpt_voice/gpt_asr.py index c63b2b71..6dc1311f 100644 --- a/codes/models/gpt_voice/gpt_asr.py +++ b/codes/models/gpt_voice/gpt_asr.py @@ -120,7 +120,7 @@ class GptAsr(nn.Module): text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device) probabilities = torch.ones((b,), device=mel_emb.device) - while len(text_seq) < self.max_mel_frames: + while text_seq.shape[-1] < self.MAX_SYMBOLS_PER_PHRASE: text_emb = self.text_embedding(text_seq) text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device)) if text_emb.shape[0] != mel_emb.shape[0]: