gpt_tts_hf: pad mel tokens with an <end_of_sequence> token.
parent 76f86c0e47
commit 4f8c4d130c
@@ -22,7 +22,8 @@ class GptTtsHf(nn.Module):
     START_MEL_TOKEN = 8192
     STOP_MEL_TOKEN = 8193
 
-    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3, checkpointing=True):
+    def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3,
+                 checkpointing=True, mel_length_compression=256):
         super().__init__()
         self.max_mel_tokens = max_mel_tokens
         self.max_symbols_per_phrase = max_symbols_per_phrase
@@ -30,6 +31,7 @@ class GptTtsHf(nn.Module):
         self.model_dim = model_dim
         self.max_mel_tokens = max_mel_tokens
         self.max_conditioning_inputs = max_conditioning_inputs
+        self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
         self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
         self.conditioning_embedding = nn.Embedding(self.max_conditioning_inputs, model_dim)
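The new mel_length_compression hyperparameter encodes how many raw waveform samples each MEL token spans (the default of 256 is a typical STFT hop length), so forward() can turn per-utterance sample counts into MEL-token counts with a single integer division.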
@@ -87,13 +89,20 @@ class GptTtsHf(nn.Module):
 
         return text_logits, mel_logits
 
-    def forward(self, text_inputs, cond_inputs, mel_targets, return_attentions=False):
+    def forward(self, text_inputs, cond_inputs, mel_targets, wav_lengths, return_attentions=False):
         """
         Forward pass
         text_inputs: long tensor, (b,t)
         cond_inputs: MEL float tensor, (b,c,80,s)
         mel_targets: long tensor, (b,m)
+        wav_lengths: long tensor, (b,)
         """
+        # Set the padding areas within the MEL targets (currently coded with the MEL code for <zero>) to the stop token.
+        mel_lengths = wav_lengths // self.mel_length_compression
+        for b in range(len(mel_lengths)):
+            if mel_lengths[b] < mel_targets.shape[-1]:
+                mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
+
         text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_targets, get_attns=return_attentions)
         if return_attentions:
             return mel_logits
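A minimal sketch of the padding step above, with toy shapes and the class constants copied in by hand (STOP_MEL_TOKEN and mel_length_compression mirror the values in this file; the tensors are stand-ins, not real data):

    import torch

    STOP_MEL_TOKEN = 8193          # mirrors GptTtsHf.STOP_MEL_TOKEN
    mel_length_compression = 256   # mirrors the new constructor default

    mel_targets = torch.randint(high=8192, size=(2, 6))  # (b, m) MEL codes
    wav_lengths = torch.tensor([4 * 256, 6 * 256])       # (b,) waveform sample counts

    mel_lengths = wav_lengths // mel_length_compression  # tensor([4, 6]) frames
    for b in range(len(mel_lengths)):
        if mel_lengths[b] < mel_targets.shape[-1]:
            mel_targets[b, mel_lengths[b]:] = STOP_MEL_TOKEN
    # Row 0 now ends with two STOP_MEL_TOKEN entries; row 1 is full-length and untouched.

The per-row loop could equally be a broadcast mask, e.g. mel_targets[torch.arange(6) >= mel_lengths[:, None]] = STOP_MEL_TOKEN, but the explicit loop is what the commit uses.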
@@ -127,7 +136,7 @@ class GptTtsHf(nn.Module):
         fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
         fake_inputs[:,-1] = self.START_MEL_TOKEN
 
-        gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
+        gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
                                             max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
         return gen[:, self.max_mel_frames:]
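Setting eos_token_id and pad_token_id to STOP_MEL_TOKEN means HuggingFace generate() now stops a sequence once the model emits the stop code and pads already-finished sequences with that same code, matching the padding applied to the training targets above. For reference, a sketch of the primer that generate() continues from, with the class defaults filled in by hand (illustrative values only):

    import torch

    START_MEL_TOKEN = 8192
    max_symbols_per_phrase, max_conditioning_inputs = 200, 3  # class defaults

    # One slot per text symbol and conditioning input, plus a trailing <start_of_mel> slot.
    fake_inputs = torch.full((2, max_symbols_per_phrase + max_conditioning_inputs + 1),
                             fill_value=1, dtype=torch.long)
    fake_inputs[:, -1] = START_MEL_TOKEN  # generation resumes right after this token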
@@ -141,4 +150,5 @@ if __name__ == '__main__':
     gpt = GptTtsHf()
     l = gpt(torch.randint(high=len(symbols), size=(2,100)),
             torch.randn(2,2,80,800),
-            torch.randint(high=8192, size=(2,200)))
+            torch.randint(high=8192, size=(2,200)),
+            torch.tensor([150*256,195*256]))
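The added argument supplies waveform lengths of 150*256 and 195*256 samples, which divide down to 150 and 195 MEL frames under the default mel_length_compression of 256; both are shorter than the 200-token mel_targets, so the smoke test exercises the new STOP_MEL_TOKEN padding on both batch rows.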