Fix up inference for gpt_tts

James Betker 2021-08-05 06:46:30 -06:00
parent 5037220ac7
commit 4017236ba9
2 changed files with 15 additions and 8 deletions


@@ -73,10 +73,14 @@ class GptTtsCollater():
         filenames = [j[2] for j in batch]
+        padded_qmel_gt = torch.stack(qmels)[:, 1:-1]
+        padded_qmel_gt = padded_qmel_gt * (padded_qmel_gt < 512)
         return {
             'padded_text': torch.stack(texts),
             'input_lengths': LongTensor(text_lens),
             'padded_qmel': torch.stack(qmels),
+            'padded_qmel_gt': padded_qmel_gt,
             'output_lengths': LongTensor(mel_lens),
             'filenames': filenames
         }
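
An aside on the new padded_qmel_gt lines: multiplying a code tensor by the boolean mask (codes < 512) zeroes every entry at or above 512, i.e. the special tokens, so only valid DVAE codebook indices survive. A minimal standalone sketch of the idiom (the token values below are made up for illustration):

    import torch

    qmel = torch.tensor([[513, 17, 204, 511, 512]])  # hypothetical codes; 512/513 stand in for special tokens
    masked = qmel * (qmel < 512)                     # bool mask promotes to 0/1, zeroing codes >= 512
    print(masked)                                    # tensor([[  0,  17, 204, 511,   0]])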


@@ -73,24 +73,27 @@ class GptTts(nn.Module):
         text_emb = self.text_embedding(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
-        mel_seq = [self.MEL_START_TOKEN, 0]
-        while mel_seq[-1] != self.MEL_STOP_TOKEN and len(mel_seq) < self.max_mel_frames:
-            mel_seq.append(0)
-            mel_emb = self.mel_embedding(torch.tensor(mel_seq, dtype=torch.long, device=text_inputs.device)).unsqueeze(0)
+        mel_seq = torch.full((text_emb.shape[0],1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
+        stop_encountered = torch.zeros((text_emb.shape[0],), device=text_emb.device)
+        while not torch.all(stop_encountered) and len(mel_seq) < self.max_mel_frames:
+            mel_emb = self.mel_embedding(mel_seq)
             mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
             emb = torch.cat([text_emb, mel_emb], dim=1)
             enc = self.gpt(emb)
             mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
             mel_logits = self.mel_head(mel_logits)
             mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
-            mel_seq[-1] = mel_codes[-1]
+            mel_seq = torch.cat([mel_seq, mel_codes[:, -1].unsqueeze(1)], dim=1)
+            stop_encountered = torch.logical_or(stop_encountered, mel_seq[:,-1] == self.MEL_STOP_TOKEN)

         if len(mel_seq) >= self.max_mel_frames:
             print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")

-        # Prevent sending invalid tokens to the VAE
-        mel_seq = [s if s < 512 else 0 for s in mel_seq]
-        return mel_seq[:-1]
+        # Prevent sending invalid tokens to the VAE. Also pad to a length of 3, which is required by the DVAE.
+        mel_seq = mel_seq * (mel_seq < 512)
+        padding_needed = 3-(mel_seq.shape[1]%3)
+        mel_seq = F.pad(mel_seq, (0,padding_needed))
+        return mel_seq

 @register_model
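
For context, the rewritten loop is batched greedy decoding: every sample in the batch decodes in lockstep, only the newest code is appended each step, and a per-sample flag latches once a stop token appears. Below is a minimal, self-contained sketch of that pattern with a stubbed model; the token ids, vocab size, and frame limit are assumptions, not values from this repo. (Two small observations: the sketch checks mel_seq.shape[1] against the frame limit, whereas len() on the committed 2-D tensor returns the batch dimension; and the F.softmax before argmax does not change the greedy choice, since softmax is monotonic.)

    import torch
    import torch.nn.functional as F

    # Assumed constants; fake_logits stands in for the gpt -> final_norm -> mel_head stack.
    MEL_STOP_TOKEN, MEL_START_TOKEN = 512, 513
    NUM_TOKENS, MAX_MEL_FRAMES, BATCH = 514, 16, 2

    def fake_logits(seq):
        return torch.randn(seq.shape[0], seq.shape[1], NUM_TOKENS)

    mel_seq = torch.full((BATCH, 1), MEL_START_TOKEN, dtype=torch.long)
    stop_encountered = torch.zeros(BATCH, dtype=torch.bool)
    while not torch.all(stop_encountered) and mel_seq.shape[1] < MAX_MEL_FRAMES:
        mel_codes = torch.argmax(fake_logits(mel_seq), dim=-1)    # greedy code at every position
        mel_seq = torch.cat([mel_seq, mel_codes[:, -1:]], dim=1)  # append only the newest code
        stop_encountered |= mel_seq[:, -1] == MEL_STOP_TOKEN      # latch samples that hit a stop

    # Zero special tokens and pad the length to a multiple of 3, mirroring the end of the diff.
    mel_seq = mel_seq * (mel_seq < 512)
    mel_seq = F.pad(mel_seq, (0, 3 - mel_seq.shape[1] % 3))
    print(mel_seq.shape)

One quirk in the committed padding: 3 - (n % 3) evaluates to 3 when n is already a multiple of 3, so at least one extra frame is always appended; (3 - n % 3) % 3 would pad only when needed.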