diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py index 96ad2ce4..b9884a3d 100644 --- a/codes/models/gpt_voice/gpt_tts.py +++ b/codes/models/gpt_voice/gpt_tts.py @@ -71,9 +71,9 @@ class GptTts(nn.Module): return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets def inference(self, text_inputs): - b, _ = text_inputs.shape + b, s = text_inputs.shape text_emb = self.text_embedding(text_inputs) - text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device)) mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device) stop_encountered = torch.zeros((b,), device=text_emb.device) @@ -98,7 +98,46 @@ class GptTts(nn.Module): return mel_seq def inference_beam(self, text_inputs): - pass + beam_width = 16 + + b, s = text_inputs.shape + assert b == 1 # Beam search only works on batches of one. + text_emb = self.text_embedding(text_inputs) + text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device)) + mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device) + probabilities = torch.ones((b,), device=text_emb.device) + while len(mel_seq) < self.max_mel_frames: + mel_emb = self.mel_embedding(mel_seq) + mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) + if text_emb.shape[0] != mel_emb.shape[0]: + text_emb = text_emb.repeat(mel_emb.shape[0], 1, 1) + emb = torch.cat([text_emb, mel_emb], dim=1) + enc = self.gpt(emb) + mel_logits = self.final_norm(enc[:, text_emb.shape[1]:]) + mel_logits = self.mel_head(mel_logits) + topk = torch.topk(F.softmax(mel_logits[:, -1], dim=-1), dim=-1, k=beam_width) + probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten()) + probabilities, sort_indices = torch.sort(probabilities, descending=True) + probabilities = probabilities[:beam_width] + + mel_seq = mel_seq.repeat_interleave(beam_width, dim=0) + codes = topk.indices.flatten() + mel_seq = torch.cat([mel_seq, codes.unsqueeze(1)], dim=1) + mel_seq = mel_seq[sort_indices] + mel_seq = mel_seq[:beam_width] + + if torch.all(torch.any(mel_seq == self.MEL_STOP_TOKEN, dim=1)): + break + + if mel_seq.shape[1] >= self.max_mel_frames: + print("Warning! Encountered frame limit before a stop token. Output is likely wrong.") + + # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE) + mel_seq = mel_seq[0, 1:-1].unsqueeze(0) # Pick most likely outcome, remove first and last tokens, which were artificially added for GPT + mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens. + + return mel_seq + @register_model