Fix up inference for gpt_tts
parent 5037220ac7
commit 4017236ba9
@@ -73,10 +73,14 @@ class GptTtsCollater():
         filenames = [j[2] for j in batch]
 
+        padded_qmel_gt = torch.stack(qmels)[:, 1:-1]
+        padded_qmel_gt = padded_qmel_gt * (padded_qmel_gt < 512)
+
         return {
             'padded_text': torch.stack(texts),
             'input_lengths': LongTensor(text_lens),
             'padded_qmel': torch.stack(qmels),
+            'padded_qmel_gt': padded_qmel_gt,
             'output_lengths': LongTensor(mel_lens),
             'filenames': filenames
         }
 
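The collater change builds padded_qmel_gt by stripping the first and last positions (the start/stop markers) and applying a multiplicative mask: (padded_qmel_gt < 512) is a bool tensor that casts to 0/1 under multiplication, so any code id outside the DVAE's 512-entry codebook is zeroed. A minimal standalone sketch of that masking trick; the values here, including 600/601 as stand-ins for the actual start/stop token ids, are made up:

import torch

qmels = torch.tensor([[600,  17, 300, 511, 601],
                      [600,  42, 512,   9, 601]])
gt = qmels[:, 1:-1]           # strip start/stop tokens
gt = gt * (gt < 512)          # bool mask casts to 0/1: ids >= 512 become 0
print(gt)                     # tensor([[ 17, 300, 511], [ 42,   0,   9]])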
@@ -73,24 +73,27 @@ class GptTts(nn.Module):
         text_emb = self.text_embedding(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
 
-        mel_seq = [self.MEL_START_TOKEN, 0]
-        while mel_seq[-1] != self.MEL_STOP_TOKEN and len(mel_seq) < self.max_mel_frames:
-            mel_seq.append(0)
-            mel_emb = self.mel_embedding(torch.tensor(mel_seq, dtype=torch.long, device=text_inputs.device)).unsqueeze(0)
+        mel_seq = torch.full((text_emb.shape[0], 1), fill_value=self.MEL_START_TOKEN, dtype=torch.long, device=text_emb.device)
+        stop_encountered = torch.zeros((text_emb.shape[0],), dtype=torch.bool, device=text_emb.device)
+        while not torch.all(stop_encountered) and mel_seq.shape[1] < self.max_mel_frames:
+            mel_emb = self.mel_embedding(mel_seq)
             mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
             emb = torch.cat([text_emb, mel_emb], dim=1)
             enc = self.gpt(emb)
             mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
             mel_logits = self.mel_head(mel_logits)
             mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
-            mel_seq[-1] = mel_codes[-1]
+            mel_seq = torch.cat([mel_seq, mel_codes[:, -1].unsqueeze(1)], dim=1)
+            stop_encountered = torch.logical_or(stop_encountered, mel_seq[:, -1] == self.MEL_STOP_TOKEN)
 
-        if len(mel_seq) >= self.max_mel_frames:
+        if mel_seq.shape[1] >= self.max_mel_frames:
             print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
 
-        # Prevent sending invalid tokens to the VAE
-        mel_seq = [s if s < 512 else 0 for s in mel_seq]
-        return mel_seq[:-1]
+        # Prevent sending invalid tokens to the VAE. Also pad the sequence length to a multiple of 3, which the DVAE requires.
+        mel_seq = mel_seq * (mel_seq < 512)
+        padding_needed = (3 - (mel_seq.shape[1] % 3)) % 3
+        mel_seq = F.pad(mel_seq, (0, padding_needed))
+        return mel_seq
 
 
 @register_model
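The rewritten loop decodes the whole batch at once: mel_seq grows as a (batch, len) long tensor with one greedy token appended per step, and the per-sample stop_encountered flag lets finished sequences idle without halting the rest. Two details worth noting: argmax(softmax(x)) equals argmax(x), so the softmax is redundant but harmless; and len() on a 2-D tensor returns the batch dimension, not the sequence length, which is why the length checks use shape[1]. A self-contained sketch of the same decoding pattern; the token ids, length limit, and toy_step function are hypothetical stand-ins, not the repo's real model:

import torch

START, STOP, MAX_LEN = 8192, 8193, 12   # hypothetical ids and limit

def greedy_decode(step_logits_fn, batch_size, device='cpu'):
    # Grow a (batch, len) tensor one token per step, with sticky per-sample stop flags.
    seq = torch.full((batch_size, 1), START, dtype=torch.long, device=device)
    stopped = torch.zeros(batch_size, dtype=torch.bool, device=device)
    while not torch.all(stopped) and seq.shape[1] < MAX_LEN:
        logits = step_logits_fn(seq)                # (batch, len, vocab)
        next_tok = logits[:, -1].argmax(dim=-1)     # greedy pick at the last position
        seq = torch.cat([seq, next_tok.unsqueeze(1)], dim=1)
        stopped |= seq[:, -1] == STOP
    return seq

# Deterministic toy step function: always favors token 7, then STOP once length reaches 5.
def toy_step(seq):
    logits = torch.zeros(seq.shape[0], seq.shape[1], STOP + 1)
    logits[:, -1, 7] = 1.0
    if seq.shape[1] >= 5:
        logits[:, -1, STOP] = 2.0
    return logits

print(greedy_decode(toy_step, batch_size=2).shape)  # torch.Size([2, 6])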
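Per the new comment, the DVAE requires sequence lengths divisible by 3, hence the right-padding at the end of the hunk; the outer % 3 in the padding arithmetic keeps an already-aligned length unchanged instead of adding three extra frames. A quick standalone check of that arithmetic, with arbitrary toy lengths:

import torch
import torch.nn.functional as F

for length in (4, 5, 6):
    seq = torch.ones(1, length, dtype=torch.long)
    padding_needed = (3 - (length % 3)) % 3      # 2, 1, 0: aligned lengths need no padding
    padded = F.pad(seq, (0, padding_needed))     # (left, right) zero-padding on the last dim
    print(length, '->', padded.shape[1])         # 4 -> 6, 5 -> 6, 6 -> 6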