diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py index c56bb0bd..6ac96008 100644 --- a/codes/models/gpt_voice/gpt_tts_hf.py +++ b/codes/models/gpt_voice/gpt_tts_hf.py @@ -34,10 +34,10 @@ class GptTtsHf(nn.Module): self.mel_length_compression = mel_length_compression self.conditioning_encoder = AudioMiniEncoder(80, model_dim) self.text_embedding = nn.Embedding(self.NUMBER_TEXT_TOKENS, model_dim) - self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 2, model_dim) + self.text_pos_embedding = nn.Embedding(self.max_symbols_per_phrase + 1, model_dim) self.conditioning_embedding = nn.Parameter(torch.randn(1,model_dim), requires_grad=True) - self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim) - seq_length = 4+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens + self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim) + seq_length = 2+self.max_symbols_per_phrase+self.max_conditioning_inputs+self.max_mel_tokens self.gpt_config = GPT2Config(vocab_size=self.NUMBER_MEL_CODES, n_positions=seq_length, n_ctx=seq_length, @@ -59,7 +59,6 @@ class GptTtsHf(nn.Module): text_targets = F.pad(text_inputs, (1,0), value=self.START_TEXT_TOKEN) - text_targets = F.pad(text_targets, (0,1), value=self.STOP_TEXT_TOKEN) text_emb = self.text_embedding(text_targets) text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device)) @@ -72,7 +71,6 @@ class GptTtsHf(nn.Module): conds = conds + self.conditioning_embedding mel_targets = F.pad(mel_targets, (1,0), value=self.START_MEL_TOKEN) - mel_targets = F.pad(mel_targets, (0,1), value=self.STOP_MEL_TOKEN) mel_emb = self.gpt.get_input_embeddings()(mel_targets) mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_targets.device)) @@ -82,10 +80,10 @@ class GptTtsHf(nn.Module): return gpt_out.attentions enc = gpt_out.last_hidden_state - text_logits = self.final_norm(enc[:, :self.max_symbols_per_phrase]) + text_logits = self.final_norm(enc[:, :self.max_symbols_per_phrase+1]) text_logits = self.text_head(text_logits) text_logits = text_logits.permute(0,2,1) - mel_logits = self.final_norm(enc[:, -self.max_mel_tokens:]) + mel_logits = self.final_norm(enc[:, -(self.max_mel_tokens+1):]) mel_logits = self.mel_head(mel_logits) mel_logits = mel_logits.permute(0,2,1) @@ -109,9 +107,9 @@ class GptTtsHf(nn.Module): if return_attentions: return mel_logits - text_targets = F.pad(text_inputs, (0,self.max_symbols_per_phrase-text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN) + text_targets = F.pad(text_inputs, (0,self.max_symbols_per_phrase-text_inputs.shape[1]+1), value=self.STOP_TEXT_TOKEN) loss_text = F.cross_entropy(text_logits, text_targets.long()) - mel_targets = F.pad(mel_targets, (0,self.max_mel_tokens-mel_targets.shape[1]), value=self.STOP_MEL_TOKEN) + mel_targets = F.pad(mel_targets, (0,self.max_mel_tokens-mel_targets.shape[1]+1), value=self.STOP_MEL_TOKEN) loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_text.mean(), loss_mel.mean(), mel_logits