gpt_tts_hf: pad mel tokens with an <end_of_sequence> token.

This commit is contained in:
James Betker 2021-12-12 20:04:50 -07:00
parent 76f86c0e47
commit 4f8c4d130c

View File

@ -22,7 +22,8 @@ class GptTtsHf(nn.Module):
START_MEL_TOKEN = 8192
STOP_MEL_TOKEN = 8193
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3, checkpointing=True):
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3,
checkpointing=True, mel_length_compression=256):
super().__init__()
self.max_mel_tokens = max_mel_tokens
self.max_symbols_per_phrase = max_symbols_per_phrase
@ -30,6 +31,7 @@ class GptTtsHf(nn.Module):
self.model_dim = model_dim
self.max_mel_tokens = max_mel_tokens
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
self.conditioning_embedding = nn.Embedding(self.max_conditioning_inputs, model_dim)
@ -87,13 +89,20 @@ class GptTtsHf(nn.Module):
return text_logits, mel_logits
def forward(self, text_inputs, cond_inputs, mel_targets, return_attentions=False):
def forward(self, text_inputs, cond_inputs, mel_targets, wav_lengths, return_attentions=False):
"""
Forward pass
text_inputs: long tensor, (b,t)
cond_inputs: MEL float tensor, (b,c,80,s)
mel_targets: long tensor, (b,m)
mel_lengths: long tensor, (b,)
"""
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>)
mel_lengths = wav_lengths // self.mel_length_compression
for b in range(len(mel_lengths)):
if mel_lengths[b] < mel_targets.shape[-1]:
mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_targets, get_attns=return_attentions)
if return_attentions:
return mel_logits
@ -127,7 +136,7 @@ class GptTtsHf(nn.Module):
fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
fake_inputs[:,-1] = self.START_MEL_TOKEN
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
return gen[:, self.max_mel_frames:]
@ -141,4 +150,5 @@ if __name__ == '__main__':
gpt = GptTtsHf()
l = gpt(torch.randint(high=len(symbols), size=(2,100)),
torch.randn(2,2,80,800),
torch.randint(high=8192, size=(2,200)))
torch.randint(high=8192, size=(2,200)),
torch.tensor([150*256,195*256]))