From 9e8a9bf6cac069c003732ec2cfaf91cf1dd33930 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 16 Dec 2021 23:28:44 -0700
Subject: [PATCH] Various fixes to gpt_tts_hf

---
 codes/data/audio/nv_tacotron_dataset.py |  6 +++--
 codes/models/gpt_voice/gpt_tts_hf.py    | 34 +++++++++++++------------
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py
index e058f8d8..b8adf4b7 100644
--- a/codes/data/audio/nv_tacotron_dataset.py
+++ b/codes/data/audio/nv_tacotron_dataset.py
@@ -219,11 +219,13 @@ if __name__ == '__main__':
         'phase': 'train',
         'n_workers': 0,
         'batch_size': batch_sz,
-        'needs_collate': True,
-        'max_wav_length': 256000,
+        'needs_collate': False,
+        'max_wav_length': 255995,
         'max_text_length': 200,
         'sample_rate': 22050,
         'load_conditioning': True,
+        'num_conditioning_candidates': 3,
+        'conditioning_length': 44100,
     }
     from data import create_dataset, create_dataloader
 
diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py
index 1d38fed4..c56bb0bd 100644
--- a/codes/models/gpt_voice/gpt_tts_hf.py
+++ b/codes/models/gpt_voice/gpt_tts_hf.py
@@ -33,10 +33,11 @@ class GptTtsHf(nn.Module):
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
-        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
-        self.conditioning_embedding = nn.Embedding(self.max_conditioning_inputs, model_dim)
+        self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
+        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 2, model_dim)
+        self.conditioning_embedding = nn.Parameter(torch.randn(1,model_dim), requires_grad=True)
         self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim)
-        seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
+        seq_length = 4+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
         self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
                                      n_positions=seq_length,
                                      n_ctx=seq_length,
@@ -56,14 +57,10 @@ class GptTtsHf(nn.Module):
         assert cond_inputs.shape[1] <= self.max_conditioning_inputs
         assert mel_targets.shape[1] <= self.max_mel_tokens
 
-        mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
-        mel_targets = F.pad(mel_targets, (0, self.max_mel_tokens - mel_targets.shape[1]), value=self.STOP_MEL_TOKEN)
-        mel_emb = self.gpt.get_input_embeddings()(mel_targets)
-        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
 
         text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
-        text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
-        text_emb = self.gpt.get_input_embeddings()(text_targets)
+        text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
+        text_emb = self.text_embedding(text_targets)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
 
         conds = []
@@ -72,7 +69,12 @@ class GptTtsHf(nn.Module):
         while len(conds) < self.max_conditioning_inputs:
             conds.append(conds[-1])
         conds = torch.stack(conds, dim=1)
-        conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
+        conds = conds + self.conditioning_embedding
+
+        mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN)
+        mel_targets = F.pad(mel_targets, (0,1), value=self.STOP_MEL_TOKEN)
+        mel_emb = self.gpt.get_input_embeddings()(mel_targets)
+        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device))
 
         emb = torch.cat([text_emb, conds, mel_emb], dim=1)
         gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
@@ -118,8 +120,8 @@ class GptTtsHf(nn.Module):
         self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
 
         text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN)
-        text_targets = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_targets.shape[1]), value=self.STOP_TEXT_TOKEN)
-        text_emb = self.gpt.get_input_embeddings()(text_targets)
+        text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN)
+        text_emb = self.text_embedding(text_targets)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
 
         conds = []
@@ -133,11 +135,11 @@ class GptTtsHf(nn.Module):
         emb = torch.cat([text_emb, conds], dim=1)
         self.inference_model.store_mel_emb(emb)
 
-        fake_inputs = torch.full((text_inputs.shape[0],self.max_symbols_per_phrase+self.max_conditioning_inputs+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
+        fake_inputs = torch.full((text_inputs.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
         fake_inputs[:,-1] = self.START_MEL_TOKEN
         gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
-                                            max_length=self.max_symbols_per_phrase+self.max_mel_frames, temperature=temperature, num_beams=num_beams, use_cache=True)
+                                            max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True)
         return gen[:, self.max_mel_frames:]
 
 
 def register_gpt_tts_hf(opt_net, opt):
@@ -148,7 +150,7 @@ def register_gpt_tts_hf(opt_net, opt):
 
 if __name__ == '__main__':
     gpt = GptTtsHf(model_dim=1024, heads=16)
-    l = gpt(torch.randint(high=len(symbols), size=(2,100)),
+    l = gpt(torch.randint(high=len(symbols), size=(2,200)),
             torch.randn(2,2,80,800),
-            torch.randint(high=8192, size=(2,200)),
+            torch.randint(high=8192, size=(2,250)),
             torch.tensor([150*256,195*256]))
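
Note on the sequence framing: after this patch, forward() frames each modality with exactly one start and one stop token instead of padding text and mel out to their maximum lengths, which is where seq_length's change from 2+ to 4+ comes from. Below is a minimal sketch of the resulting layout, not code from the repository: the token IDs are placeholders, and the sizes mirror the __main__ test above (200 text symbols, 2 conditioning clips, 250 mel tokens).

import torch
import torch.nn.functional as F

# Placeholder special-token IDs; the real values are class constants on GptTtsHf.
START_TEXT_TOKEN, STOP_TEXT_TOKEN = 0, 1
START_MEL_TOKEN, STOP_MEL_TOKEN = 8192, 8193

text_inputs = torch.randint(high=148, size=(2, 200))   # (batch, text_len)
mel_targets = torch.randint(high=8192, size=(2, 250))  # (batch, mel_len)

# One start and one stop token per modality, mirroring the patched forward():
text_targets = F.pad(text_inputs, (1, 0), value=START_TEXT_TOKEN)
text_targets = F.pad(text_targets, (0, 1), value=STOP_TEXT_TOKEN)
mel_targets = F.pad(mel_targets, (1, 0), value=START_MEL_TOKEN)
mel_targets = F.pad(mel_targets, (0, 1), value=STOP_MEL_TOKEN)

# Two special tokens for text plus two for mel account for the "4 +" in the
# new seq_length; the conditioning clips contribute no framing tokens.
max_symbols_per_phrase, max_conditioning_inputs, max_mel_tokens = 200, 2, 250
seq_length = 4 + max_symbols_per_phrase + max_conditioning_inputs + max_mel_tokens
assert text_targets.shape[1] + max_conditioning_inputs + mel_targets.shape[1] == seq_length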
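Note on the inference buffer sizing: generate() consumes token IDs, so inference() builds a dummy prefix whose embeddings are presumably swapped for the stored text+conditioning embedding inside GPT2InferenceModel (via store_mel_emb). Deriving the prefix length from emb.shape[1] rather than from the hard-coded max_symbols_per_phrase+max_conditioning_inputs+1 keeps it in sync now that the text framing contributes two extra tokens. A hedged sketch with hypothetical sizes:

import torch

START_MEL_TOKEN = 8192  # placeholder ID
batch, model_dim = 2, 1024
prefix_len = 204        # e.g. 200 text symbols + 2 text specials + 2 conditioning clips

# Stand-in for the text+conditioning embedding handed to store_mel_emb():
emb = torch.randn(batch, prefix_len, model_dim)

# One dummy token per prefix position, plus a final slot carrying <START_MEL>
# to kick off autoregressive mel generation:
fake_inputs = torch.full((batch, emb.shape[1] + 1), fill_value=1, dtype=torch.long)
fake_inputs[:, -1] = START_MEL_TOKEN

# Generation may then run for at most emb.shape[1] + max_mel_tokens positions,
# the same bound the patch passes as max_length:
max_mel_tokens = 250
max_length = emb.shape[1] + max_mel_tokens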