gpt_tts: format conditioning inputs more for contextual voice clues and less for prosidy
also support single conditional inputs
This commit is contained in:
parent
c813befd53
commit
712d746e9b
|
@ -1,3 +1,4 @@
|
|||
import random
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
|
@ -23,7 +24,7 @@ class GptTtsHf(nn.Module):
|
|||
STOP_MEL_TOKEN = 8193
|
||||
|
||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3,
|
||||
checkpointing=True, mel_length_compression=1024):
|
||||
checkpointing=True, mel_length_compression=1024, max_conditioning_length=44100//256):
|
||||
super().__init__()
|
||||
self.max_mel_tokens = max_mel_tokens
|
||||
self.max_symbols_per_phrase = max_symbols_per_phrase
|
||||
|
@ -45,6 +46,7 @@ class GptTtsHf(nn.Module):
|
|||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
|
||||
self.mel_head = nn.Linear(model_dim, self.NUMBER_MEL_CODES)
|
||||
self.max_conditioning_length = max_conditioning_length
|
||||
|
||||
|
||||
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
|
||||
|
@ -87,12 +89,21 @@ class GptTtsHf(nn.Module):
|
|||
mel_targets: long tensor, (b,m)
|
||||
mel_lengths: long tensor, (b,)
|
||||
"""
|
||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>)
|
||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
mel_lengths = wav_lengths // self.mel_length_compression
|
||||
for b in range(len(mel_lengths)):
|
||||
if mel_lengths[b] < mel_targets.shape[-1]:
|
||||
mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
|
||||
|
||||
# Format conditioning inputs properly.
|
||||
if len(cond_inputs.shape) == 3:
|
||||
cond_inputs = cond_inputs.unsqueeze(1) # Format a single conditioning input as a set of {1}
|
||||
if cond_inputs.shape[-1] > self.max_conditioning_length:
|
||||
# Remember, that this doesn't necessarily mean that the conditioning inputs aren't mostly zero-padded, so
|
||||
# skew trimming towards the front end of the clip.
|
||||
rand_clip = random.randint(0, min(50, cond_inputs.shape[-1]-self.max_conditioning_length))
|
||||
cond_inputs = cond_inputs[:,:,:,rand_clip:rand_clip+self.max_conditioning_length]
|
||||
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
|
||||
mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
|
||||
text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_inputs, get_attns=return_attentions)
|
||||
|
@ -103,8 +114,6 @@ class GptTtsHf(nn.Module):
|
|||
return loss_text.mean(), loss_mel.mean(), mel_logits
|
||||
|
||||
def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8, repetition_penalty=1):
|
||||
#text_inputs, cond_inputs = torch.load("debug_text_and_cond.pt")
|
||||
|
||||
if not hasattr(self, 'inference_model'):
|
||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
|
||||
|
||||
|
@ -138,6 +147,6 @@ def register_gpt_tts_hf(opt_net, opt):
|
|||
if __name__ == '__main__':
|
||||
gpt = GptTtsHf(model_dim=1024, heads=16)
|
||||
l = gpt(torch.randint(high=len(symbols), size=(2,200)),
|
||||
torch.randn(2,2,80,800),
|
||||
torch.randn(2,80,800),
|
||||
torch.randint(high=8192, size=(2,250)),
|
||||
torch.tensor([150*256,195*256]))
|
||||
|
|
Loading…
Reference in New Issue
Block a user