Fix gpt_asr inference bug

James Betker 2021-08-15 20:53:42 -06:00
parent 3580c52eac
commit b8bec22f1a


@@ -120,7 +120,7 @@ class GptAsr(nn.Module):
         text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
         probabilities = torch.ones((b,), device=mel_emb.device)
-        while len(text_seq) < self.max_mel_frames:
+        while text_seq.shape[-1] < self.MAX_SYMBOLS_PER_PHRASE:
             text_emb = self.text_embedding(text_seq)
             text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
             if text_emb.shape[0] != mel_emb.shape[0]:
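
The old condition compared len(text_seq), which on a 2-D tensor returns its first (batch) dimension, against self.max_mel_frames. The fix bounds the loop by the actual text sequence length, text_seq.shape[-1], and by the correct limit for text symbols, self.MAX_SYMBOLS_PER_PHRASE. Below is a minimal sketch of a greedy decoding loop with the corrected bound; the step_fn forward pass and the constant values are hypothetical stand-ins, not the module's real API.

    import torch

    MAX_SYMBOLS_PER_PHRASE = 200   # assumed value, stands in for self.MAX_SYMBOLS_PER_PHRASE
    START_SYMBOL = 0               # assumed stand-in for self.NUMBER_SYMBOLS

    def greedy_decode(step_fn, mel_emb):
        b = mel_emb.shape[0]
        # Seed every sequence in the batch with the start symbol, shape (b, 1).
        text_seq = torch.full((b, 1), fill_value=START_SYMBOL, dtype=torch.long,
                              device=mel_emb.device)
        # Bound the loop by the sequence length (shape[-1]); len(text_seq)
        # would return the batch dimension b instead.
        while text_seq.shape[-1] < MAX_SYMBOLS_PER_PHRASE:
            logits = step_fn(mel_emb, text_seq)               # (b, seq, vocab), hypothetical signature
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            text_seq = torch.cat([text_seq, next_token], dim=-1)
        return text_seq

With any step_fn that returns per-position logits, the loop grows text_seq one token at a time until it reaches MAX_SYMBOLS_PER_PHRASE columns, which is the behavior the changed line restores.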