From c813befd535345bcd214e4a18efb7daa47d8b3be Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sun, 19 Dec 2021 09:01:31 -0700
Subject: [PATCH] Remove dedicated positioning embeddings

---
 codes/models/gpt_voice/gpt_tts_hf.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py
index 405fb64b..652b4cee 100644
--- a/codes/models/gpt_voice/gpt_tts_hf.py
+++ b/codes/models/gpt_voice/gpt_tts_hf.py
@@ -32,9 +32,6 @@ class GptTtsHf(nn.Module):
         self.mel_length_compression = mel_length_compression
         self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
         self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim)
-        self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim)
-        self.conditioning_embedding = nn.Parameter(torch.randn(1,model_dim), requires_grad=True)
-        self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
         seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens
         self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES,
                                      n_positions=seq_length,
@@ -57,7 +54,6 @@ class GptTtsHf(nn.Module):
 
     def get_logits(self, text_inputs, cond_inputs, mel_inputs, get_attns=False):
         text_emb = self.text_embedding(text_inputs)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
 
         conds = []
         for k in range(cond_inputs.shape[1]):
@@ -65,10 +61,8 @@ class GptTtsHf(nn.Module):
             conds.append(self.conditioning_encoder(cond_inputs[:, k]))
         while len(conds) < self.max_conditioning_inputs:
             conds.append(conds[-1])
         conds = torch.stack(conds, dim=1)
-        conds = conds + self.conditioning_embedding
 
         mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
-        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_inputs.device))
         emb = torch.cat([text_emb, conds, mel_emb], dim=1)
         gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
@@ -117,7 +111,6 @@ class GptTtsHf(nn.Module):
         text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
         text_emb = self.text_embedding(text_inputs)
-        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
 
         conds = []
         for k in range(cond_inputs.shape[1]):
@@ -125,7 +118,6 @@ class GptTtsHf(nn.Module):
             conds.append(self.conditioning_encoder(cond_inputs[:, k]))
         while len(conds) < self.max_conditioning_inputs:
             conds.append(conds[-1])
         conds = torch.stack(conds, dim=1)
-        conds = conds + self.conditioning_embedding
         emb = torch.cat([text_emb, conds], dim=1)
         self.inference_model.store_mel_emb(emb)
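
Note (not part of the original patch): a minimal sketch of why dropping these modules is safe, assuming standard Hugging Face GPT-2 behavior. GPT2Model adds its own learned position embeddings (wpe) internally even when it is called with pre-computed inputs_embeds, so the concatenated text/conditioning/mel embeddings still receive positional information. The config values below are illustrative placeholders, not the model's real hyperparameters.

    # Sketch only: shows that GPT-2 applies wpe to inputs_embeds on its own.
    import torch
    from transformers import GPT2Config, GPT2Model

    config = GPT2Config(vocab_size=8194, n_positions=600, n_embd=512, n_layer=8, n_head=8)
    gpt = GPT2Model(config)

    emb = torch.randn(1, 10, 512)       # stand-in for torch.cat([text_emb, conds, mel_emb], dim=1)
    out = gpt(inputs_embeds=emb, return_dict=True)
    print(out.last_hidden_state.shape)  # torch.Size([1, 10, 512])
    print(gpt.wpe)                      # Embedding(600, 512) -- positions handled inside GPT-2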