Add a sampling beam search

James Betker 2021-08-09 11:56:06 -06:00
parent d4e33bf15f
commit 1068f53b78


@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from munch import munchify
 from models.gpt_voice.lucidrains_gpt import Transformer
 from models.gpt_voice.min_gpt import GPT, GPTConfig
@@ -97,8 +98,25 @@ class GptTts(nn.Module):
         return mel_seq
 
-    def inference_beam(self, text_inputs):
+    def inference_beam_topk(self, text):
+        def topk_sampler(distribution, k):
+            return torch.topk(distribution, k=k, dim=-1)
+        return self.inference_beam(text, topk_sampler)
+
+    def inference_beam_sampled(self, text):
+        def multinomial_sampler(distribution, k):
+            indices = torch.multinomial(distribution, num_samples=k, replacement=False)
+            values = torch.gather(distribution, dim=1, index=indices)
+            class container:
+                def __init__(self, i, v):
+                    self.indices = i
+                    self.values = v
+            return container(indices, values)
+        return self.inference_beam(text, multinomial_sampler)
+
+    def inference_beam(self, text_inputs, sampler_fn):
         beam_width = 16
+        temperature = .8
         b, s = text_inputs.shape
         assert b == 1  # Beam search only works on batches of one.
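
Both new entry points funnel into the same beam-search body; only the sampler differs. A minimal usage sketch, assuming a trained GptTts instance named model and a batch-of-one tensor of text codes named text_codes (both hypothetical names):

    model.eval()
    with torch.no_grad():
        mel_topk = model.inference_beam_topk(text_codes)        # deterministic: always the k most likely codes
        mel_sampled = model.inference_beam_sampled(text_codes)  # stochastic: k codes drawn without replacement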
@@ -115,7 +133,7 @@ class GptTts(nn.Module):
             enc = self.gpt(emb)
             mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
             mel_logits = self.mel_head(mel_logits)
-            topk = torch.topk(F.softmax(mel_logits[:, -1], dim=-1), dim=-1, k=beam_width)
+            topk = sampler_fn(F.softmax(temperature * mel_logits[:, -1], dim=-1), k=beam_width)
             probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
             probabilities, sort_indices = torch.sort(probabilities, descending=True)
             probabilities = probabilities[:beam_width]
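
Both samplers return an object exposing .values (the probabilities of the chosen mel codes) and .indices (the codes themselves), mirroring the named tuple that torch.topk returns; the multinomial sampler wraps its tensors in a small container so the beam loop can treat the two interchangeably. A standalone sketch of that shared interface, using made-up shapes (512 mel codes, beam width 16):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(1, 512)            # stand-in for mel_logits[:, -1]
    dist = F.softmax(0.8 * logits, dim=-1)  # temperature-scaled distribution, as in the diff

    topk = torch.topk(dist, k=16, dim=-1)   # greedy: the 16 most likely codes
    idx = torch.multinomial(dist, num_samples=16, replacement=False)  # sampled: 16 codes drawn without replacement
    val = torch.gather(dist, dim=1, index=idx)                        # probabilities of the sampled codes

    # Either way the beam loop sees (1, 16) values and (1, 16) indices.
    assert topk.values.shape == val.shape == (1, 16)
    assert topk.indices.shape == idx.shape == (1, 16)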