diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py
index b9884a3d..e8fe4d9f 100644
--- a/codes/models/gpt_voice/gpt_tts.py
+++ b/codes/models/gpt_voice/gpt_tts.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from munch import munchify
 
 from models.gpt_voice.lucidrains_gpt import Transformer
 from models.gpt_voice.min_gpt import GPT, GPTConfig
@@ -97,8 +98,25 @@ class GptTts(nn.Module):
 
         return mel_seq
 
-    def inference_beam(self, text_inputs):
+    def inference_beam_topk(self, text):
+        def topk_sampler(distribution, k):
+            return torch.topk(distribution, k=k, dim=-1)
+        return self.inference_beam(text, topk_sampler)
+
+    def inference_beam_sampled(self, text):
+        def multinomial_sampler(distribution, k):
+            indices = torch.multinomial(distribution, num_samples=k, replacement=False)
+            values = torch.gather(distribution, dim=1, index=indices)
+            class container:
+                def __init__(self, i, v):
+                    self.indices = i
+                    self.values = v
+            return container(indices, values)
+        return self.inference_beam(text, multinomial_sampler)
+
+    def inference_beam(self, text_inputs, sampler_fn):
         beam_width = 16
+        temperature = .8
 
         b, s = text_inputs.shape
         assert b == 1  # Beam search only works on batches of one.
@@ -115,7 +133,7 @@ class GptTts(nn.Module):
             enc = self.gpt(emb)
             mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
             mel_logits = self.mel_head(mel_logits)
-            topk = torch.topk(F.softmax(mel_logits[:, -1], dim=-1), dim=-1, k=beam_width)
+            topk = sampler_fn(F.softmax(temperature * mel_logits[:, -1], dim=-1), k=beam_width)
             probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
             probabilities, sort_indices = torch.sort(probabilities, descending=True)
             probabilities = probabilities[:beam_width]
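
For context, a minimal usage sketch of the two new entry points follows. It is not part of the diff: `model` and `text_codes` are made-up names, the vocabulary size and sequence length are arbitrary, and it assumes `GptTts()` can be constructed with its default arguments. Both methods delegate to the refactored `inference_beam()`, differing only in the `sampler_fn` that expands each step's softmax distribution into `beam_width` candidates; the multinomial sampler's `container` class exists only to mimic the `.values`/`.indices` interface of `torch.topk`'s return value.

    import torch

    # Hypothetical setup -- constructor defaults and random text codes are
    # illustration-only; real inputs would come from the text tokenizer.
    model = GptTts().eval()
    text_codes = torch.randint(0, 128, (1, 40))  # inference_beam asserts batch size == 1

    with torch.no_grad():
        # Deterministic beam expansion: keep the beam_width most probable tokens.
        mel_topk = model.inference_beam_topk(text_codes)
        # Stochastic beam expansion: draw beam_width tokens without replacement
        # from the temperature-scaled softmax, so repeated calls can differ.
        mel_sampled = model.inference_beam_sampled(text_codes)

One detail worth noting: the logits are multiplied by `temperature = .8` rather than divided by it, which is equivalent to a conventional temperature of 1.25 and therefore slightly flattens the distribution before sampling.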