diff --git a/codes/models/gpt_voice/gpt_asr.py b/codes/models/gpt_voice/gpt_asr.py index c63b2b71..6dc1311f 100644 --- a/codes/models/gpt_voice/gpt_asr.py +++ b/codes/models/gpt_voice/gpt_asr.py @@ -120,7 +120,7 @@ class GptAsr(nn.Module): text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device) probabilities = torch.ones((b,), device=mel_emb.device) - while len(text_seq) < self.max_mel_frames: + while text_seq.shape[-1] < self.MAX_SYMBOLS_PER_PHRASE: text_emb = self.text_embedding(text_seq) text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device)) if text_emb.shape[0] != mel_emb.shape[0]: