diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py index c57d492e..57818788 100644 --- a/codes/models/gpt_voice/gpt_tts_hf.py +++ b/codes/models/gpt_voice/gpt_tts_hf.py @@ -22,7 +22,8 @@ class GptTtsHf(nn.Module): START_MEL_TOKEN = 8192 STOP_MEL_TOKEN = 8193 - def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3, checkpointing=True): + def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3, + checkpointing=True, mel_length_compression=256): super().__init__() self.max_mel_tokens = max_mel_tokens self.max_symbols_per_phrase = max_symbols_per_phrase @@ -30,6 +31,7 @@ class GptTtsHf(nn.Module): self.model_dim = model_dim self.max_mel_tokens = max_mel_tokens self.max_conditioning_inputs = max_conditioning_inputs + self.mel_length_compression = mel_length_compression self.conditioning_encoder = AudioMiniEncoder(80, model_dim) self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim) self.conditioning_embedding = nn.Embedding(self.max_conditioning_inputs, model_dim) @@ -87,13 +89,20 @@ class GptTtsHf(nn.Module): return text_logits, mel_logits - def forward(self, text_inputs, cond_inputs, mel_targets, return_attentions=False): + def forward(self, text_inputs, cond_inputs, mel_targets, wav_lengths, return_attentions=False): """ Forward pass text_inputs: long tensor, (b,t) cond_inputs: MEL float tensor, (b,c,80,s) mel_targets: long tensor, (b,m) + mel_lengths: long tensor, (b,) """ + # Set padding areas within MEL (currently it is coded with the MEL code for ) + mel_lengths = wav_lengths // self.mel_length_compression + for b in range(len(mel_lengths)): + if mel_lengths[b] < mel_targets.shape[-1]: + mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN + text_logits, mel_logits = self.get_logits(text_inputs, cond_inputs, mel_targets, get_attns=return_attentions) if return_attentions: return mel_logits @@ -127,7 +136,7 @@ class GptTtsHf(nn.Module): fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device) fake_inputs[:,-1] = self.START_MEL_TOKEN - gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.NUMBER_SYMBOLS, pad_token_id=0, eos_token_id=0, + gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN, max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True) return gen[:, self.max_mel_frames:] @@ -141,4 +150,5 @@ if __name__ == '__main__': gpt = GptTtsHf() l = gpt(torch.randint(high=len(symbols), size=(2,100)), torch.randn(2,2,80,800), - torch.randint(high=8192, size=(2,200))) + torch.randint(high=8192, size=(2,200)), + torch.tensor([150*256,195*256]))