Fix gpt_asr inference bug
This commit is contained in:
parent
3580c52eac
commit
b8bec22f1a
|
@ -120,7 +120,7 @@ class GptAsr(nn.Module):
|
||||||
|
|
||||||
text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
|
text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
|
||||||
probabilities = torch.ones((b,), device=mel_emb.device)
|
probabilities = torch.ones((b,), device=mel_emb.device)
|
||||||
while len(text_seq) < self.max_mel_frames:
|
while text_seq.shape[-1] < self.MAX_SYMBOLS_PER_PHRASE:
|
||||||
text_emb = self.text_embedding(text_seq)
|
text_emb = self.text_embedding(text_seq)
|
||||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
|
||||||
if text_emb.shape[0] != mel_emb.shape[0]:
|
if text_emb.shape[0] != mel_emb.shape[0]:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user