gpt_tts_hf: pad mel tokens with an <end_of_sequence> token.
This commit is contained in:
parent
76f86c0e47
commit
4f8c4d130c
|
@ -22,7 +22,8 @@ class GptTtsHf(nn.Module):
|
||||||
START_MEL_TOKEN = 8192
|
START_MEL_TOKEN = 8192
|
||||||
STOP_MEL_TOKEN = 8193
|
STOP_MEL_TOKEN = 8193
|
||||||
|
|
||||||
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3, checkpointing=True):
|
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3,
|
||||||
|
checkpointing=True, mel_length_compression=256):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_mel_tokens = max_mel_tokens
|
self.max_mel_tokens = max_mel_tokens
|
||||||
self.max_symbols_per_phrase = max_symbols_per_phrase
|
self.max_symbols_per_phrase = max_symbols_per_phrase
|
||||||
|
@ -30,6 +31,7 @@ class GptTtsHf(nn.Module):
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.max_mel_tokens = max_mel_tokens
|
self.max_mel_tokens = max_mel_tokens
|
||||||
self.max_conditioning_inputs = max_conditioning_inputs
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
|
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
|
||||||
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
|
||||||
self.conditioning_embedding = nn.Embedding(self.max_conditioning_inputs, model_dim)
|
self.conditioning_embedding = nn.Embedding(self.max_conditioning_inputs, model_dim)
|
||||||
|
@ -87,13 +89,20 @@ class GptTtsHf(nn.Module):
|
||||||
|
|
||||||
return text_logits, mel_logits
|
return text_logits, mel_logits
|
||||||
|
|
||||||
def forward(self, text_inputs, cond_inputs, mel_targets, return_attentions=False):
|
def forward(self, text_inputs, cond_inputs, mel_targets, wav_lengths, return_attentions=False):
|
||||||
"""
|
"""
|
||||||
Forward pass
|
Forward pass
|
||||||
text_inputs: long tensor, (b,t)
|
text_inputs: long tensor, (b,t)
|
||||||
cond_inputs: MEL float tensor, (b,c,80,s)
|
cond_inputs: MEL float tensor, (b,c,80,s)
|
||||||
mel_targets: long tensor, (b,m)
|
mel_targets: long tensor, (b,m)
|
||||||
|
mel_lengths: long tensor, (b,)
|
||||||
"""
|
"""
|
||||||
|
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>)
|
||||||
|
mel_lengths = wav_lengths // self.mel_length_compression
|
||||||
|
for b in range(len(mel_lengths)):
|
||||||
|
if mel_lengths[b] < mel_targets.shape[-1]:
|
||||||
|
mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
|
||||||
|
|
||||||
text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_targets, get_attns=return_attentions)
|
text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_targets, get_attns=return_attentions)
|
||||||
if return_attentions:
|
if return_attentions:
|
||||||
return mel_logits
|
return mel_logits
|
||||||
|
@ -127,7 +136,7 @@ class GptTtsHf(nn.Module):
|
||||||
fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
|
||||||
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
||||||
|
|
||||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0,
|
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
||||||
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
|
max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
|
||||||
return gen[:, self.max_mel_frames:]
|
return gen[:, self.max_mel_frames:]
|
||||||
|
|
||||||
|
@ -141,4 +150,5 @@ if __name__ == '__main__':
|
||||||
gpt = GptTtsHf()
|
gpt = GptTtsHf()
|
||||||
l = gpt(torch.randint(high=len(symbols), size=(2,100)),
|
l = gpt(torch.randint(high=len(symbols), size=(2,100)),
|
||||||
torch.randn(2,2,80,800),
|
torch.randn(2,2,80,800),
|
||||||
torch.randint(high=8192, size=(2,200)))
|
torch.randint(high=8192, size=(2,200)),
|
||||||
|
torch.tensor([150*256,195*256]))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user