Beam search implementation in one pass? Dayyyum
This commit is contained in:
parent
83ab5e6a00
commit
01cfae28d8
|
@ -71,9 +71,9 @@ class GptTts(nn.Module):
|
||||||
return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets
|
return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets
|
||||||
|
|
||||||
def inference(self, text_inputs):
|
def inference(self, text_inputs):
|
||||||
b, _ = text_inputs.shape
|
b, s = text_inputs.shape
|
||||||
text_emb = self.text_embedding(text_inputs)
|
text_emb = self.text_embedding(text_inputs)
|
||||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device))
|
||||||
|
|
||||||
mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
|
mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
|
||||||
stop_encountered = torch.zeros((b,), device=text_emb.device)
|
stop_encountered = torch.zeros((b,), device=text_emb.device)
|
||||||
|
@ -98,7 +98,46 @@ class GptTts(nn.Module):
|
||||||
return mel_seq
|
return mel_seq
|
||||||
|
|
||||||
def inference_beam(self, text_inputs):
|
def inference_beam(self, text_inputs):
|
||||||
pass
|
beam_width = 16
|
||||||
|
|
||||||
|
b, s = text_inputs.shape
|
||||||
|
assert b == 1 # Beam search only works on batches of one.
|
||||||
|
text_emb = self.text_embedding(text_inputs)
|
||||||
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(s, device=text_inputs.device))
|
||||||
|
mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
|
||||||
|
probabilities = torch.ones((b,), device=text_emb.device)
|
||||||
|
while len(mel_seq) < self.max_mel_frames:
|
||||||
|
mel_emb = self.mel_embedding(mel_seq)
|
||||||
|
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||||
|
if text_emb.shape[0] != mel_emb.shape[0]:
|
||||||
|
text_emb = text_emb.repeat(mel_emb.shape[0], 1, 1)
|
||||||
|
emb = torch.cat([text_emb, mel_emb], dim=1)
|
||||||
|
enc = self.gpt(emb)
|
||||||
|
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
|
||||||
|
mel_logits = self.mel_head(mel_logits)
|
||||||
|
topk = torch.topk(F.softmax(mel_logits[:, -1], dim=-1), dim=-1, k=beam_width)
|
||||||
|
probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
|
||||||
|
probabilities, sort_indices = torch.sort(probabilities, descending=True)
|
||||||
|
probabilities = probabilities[:beam_width]
|
||||||
|
|
||||||
|
mel_seq = mel_seq.repeat_interleave(beam_width, dim=0)
|
||||||
|
codes = topk.indices.flatten()
|
||||||
|
mel_seq = torch.cat([mel_seq, codes.unsqueeze(1)], dim=1)
|
||||||
|
mel_seq = mel_seq[sort_indices]
|
||||||
|
mel_seq = mel_seq[:beam_width]
|
||||||
|
|
||||||
|
if torch.all(torch.any(mel_seq == self.MEL_STOP_TOKEN, dim=1)):
|
||||||
|
break
|
||||||
|
|
||||||
|
if mel_seq.shape[1] >= self.max_mel_frames:
|
||||||
|
print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
|
||||||
|
|
||||||
|
# Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE)
|
||||||
|
mel_seq = mel_seq[0, 1:-1].unsqueeze(0) # Pick most likely outcome, remove first and last tokens, which were artificially added for GPT
|
||||||
|
mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens.
|
||||||
|
|
||||||
|
return mel_seq
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
|
Loading…
Reference in New Issue
Block a user