forked from mrq/DL-Art-School
Add a sampling beam search
This commit is contained in:
parent d4e33bf15f
commit 1068f53b78
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from munch import munchify
 
 from models.gpt_voice.lucidrains_gpt import Transformer
 from models.gpt_voice.min_gpt import GPT, GPTConfig
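munchify (from the munch package) converts a plain dict into an object with attribute access; it is not exercised in the hunks shown, but it could replace the hand-rolled container class added below. A sketch, not part of the commit:

import torch
from munch import munchify

# munchify turns a dict into an object whose keys are attributes,
# matching the .values/.indices interface of torch.topk's return value.
result = munchify({'indices': [3, 1], 'values': [0.5, 0.2]})
print(result.indices, result.values)  # [3, 1] [0.5, 0.2]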
@@ -97,8 +98,25 @@ class GptTts(nn.Module):
 
         return mel_seq
 
-    def inference_beam(self, text_inputs):
+    def inference_beam_topk(self, text):
+        def topk_sampler(distribution, k):
+            return torch.topk(distribution, k=k, dim=-1)
+        return self.inference_beam(text, topk_sampler)
+
+    def inference_beam_sampled(self, text):
+        def multinomial_sampler(distribution, k):
+            indices = torch.multinomial(distribution, num_samples=k, replacement=False)
+            values = torch.gather(distribution, dim=1, index=indices)
+            class container:
+                def __init__(self, i, v):
+                    self.indices = i
+                    self.values = v
+            return container(indices, values)
+        return self.inference_beam(text, multinomial_sampler)
+
+    def inference_beam(self, text_inputs, sampler_fn):
         beam_width = 16
+        temperature = .8
 
         b, s = text_inputs.shape
         assert b == 1  # Beam search only works on batches of one.
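Both samplers expose the same interface: given a probability distribution and k, return an object with .values and .indices fields, mirroring the named tuple that torch.topk returns; multinomial_sampler wraps its torch.multinomial draws in the small container class to match that shape. A minimal standalone sketch of the two calls, with toy inputs:

import torch
import torch.nn.functional as F

# Toy distribution over an 8-token vocabulary.
distribution = F.softmax(torch.randn(1, 8), dim=-1)
k = 4

# torch.topk returns a named tuple with .values and .indices.
topk = torch.topk(distribution, k=k, dim=-1)

# torch.multinomial draws k distinct indices, each in proportion to its
# probability; torch.gather then looks up the drawn probabilities.
indices = torch.multinomial(distribution, num_samples=k, replacement=False)
values = torch.gather(distribution, dim=1, index=indices)

print(topk.values.shape, values.shape)  # both torch.Size([1, 4])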
@@ -115,7 +133,7 @@ class GptTts(nn.Module):
             enc = self.gpt(emb)
             mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
             mel_logits = self.mel_head(mel_logits)
-            topk = torch.topk(F.softmax(mel_logits[:, -1], dim=-1), dim=-1, k=beam_width)
+            topk = sampler_fn(F.softmax(temperature * mel_logits[:, -1], dim=-1), k=beam_width)
             probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
             probabilities, sort_indices = torch.sort(probabilities, descending=True)
             probabilities = probabilities[:beam_width]
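The unchanged lines around the new sampler call show the beam bookkeeping: each live beam is expanded by beam_width candidates, a candidate's score is its beam's running probability times the sampled token's probability, and a sort keeps only the top beam_width expansions. The new line also scales the logits by temperature = .8 before the softmax, which flattens the distribution somewhat and gives the multinomial sampler more diversity. A self-contained sketch of the expand-and-prune step, with made-up sizes and values:

import torch

beam_width = 4

# Running probability of each live beam (hypothetical values).
probabilities = torch.tensor([0.5, 0.3, 0.15, 0.05])

# One sampled continuation set per beam: token probabilities with
# shape (beam_width, beam_width), standing in for topk.values.
token_probs = torch.rand(beam_width, beam_width)

# Expand: repeat every beam beam_width times, then score each copy by
# the probability of the token that would extend it.
scores = probabilities.repeat_interleave(beam_width, dim=0) * token_probs.flatten()

# Prune: sort all beam_width * beam_width expansions, keep the best few.
scores, sort_indices = torch.sort(scores, descending=True)
probabilities = scores[:beam_width]

# sort_indices // beam_width recovers each survivor's parent beam;
# sort_indices % beam_width recovers which sampled token extended it.
parents = sort_indices[:beam_width] // beam_width
print(probabilities, parents)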