diff --git a/codes/data/audio/gpt_tts_dataset.py b/codes/data/audio/gpt_tts_dataset.py
index 7d1aff98..bd84a281 100644
--- a/codes/data/audio/gpt_tts_dataset.py
+++ b/codes/data/audio/gpt_tts_dataset.py
@@ -73,10 +73,14 @@ class GptTtsCollater():
         filenames = [j[2] for j in batch]
 
+        padded_qmel_gt = torch.stack(qmels)[:, 1:-1]
+        padded_qmel_gt = padded_qmel_gt * (padded_qmel_gt < 512)
+
         return {
             'padded_text': torch.stack(texts),
             'input_lengths': LongTensor(text_lens),
             'padded_qmel': torch.stack(qmels),
+            'padded_qmel_gt': padded_qmel_gt,
             'output_lengths': LongTensor(mel_lens),
             'filenames': filenames
         }
diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py
index 5db11c8b..2c678f07 100644
--- a/codes/models/gpt_voice/gpt_tts.py
+++ b/codes/models/gpt_voice/gpt_tts.py
@@ -73,24 +73,27 @@ class GptTts(nn.Module):
         text_emb = self.text_embedding(text_inputs)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
 
-        mel_seq = [self.MEL_START_TOKEN, 0]
-        while mel_seq[-1] != self.MEL_STOP_TOKEN and len(mel_seq) < self.max_mel_frames:
-            mel_seq.append(0)
-            mel_emb = self.mel_embedding(torch.tensor(mel_seq, dtype=torch.long, device=text_inputs.device)).unsqueeze(0)
+        mel_seq = torch.full((text_emb.shape[0],1), fill_value=self.MEL_START_TOKEN, device=text_emb.device)
+        stop_encountered = torch.zeros((text_emb.shape[0],), device=text_emb.device)
+        while not torch.all(stop_encountered) and len(mel_seq) < self.max_mel_frames:
+            mel_emb = self.mel_embedding(mel_seq)
             mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
             emb = torch.cat([text_emb, mel_emb], dim=1)
             enc = self.gpt(emb)
             mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
             mel_logits = self.mel_head(mel_logits)
             mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
-            mel_seq[-1] = mel_codes[-1]
+            mel_seq = torch.cat([mel_seq, mel_codes[:, -1].unsqueeze(1)], dim=1)
+            stop_encountered = torch.logical_or(stop_encountered, mel_seq[:,-1] == self.MEL_STOP_TOKEN)
 
         if len(mel_seq) >= self.max_mel_frames:
             print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
 
-        # Prevent sending invalid tokens to the VAE
-        mel_seq = [s if s < 512 else 0 for s in mel_seq]
-        return mel_seq[:-1]
+        # Prevent sending invalid tokens to the VAE. Also pad to a length of 3, which is required by the DVAE.
+        mel_seq = mel_seq * (mel_seq < 512)
+        padding_needed = 3-(mel_seq.shape[1]%3)
+        mel_seq = F.pad(mel_seq, (0,padding_needed))
+        return mel_seq
 
 
 @register_model
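For context, here is a minimal, self-contained sketch of the idea the new inference loop in gpt_tts.py implements: batched greedy decoding where each row of the batch carries its own "stopped" flag, and generation ends once every row has emitted a stop token or a frame limit is reached. The `model` callable, the token values and `max_frames` below are placeholders rather than the repository's actual API; the sketch also uses `mel_seq.shape[1]` for the length check, since `len()` on a 2-D tensor returns the batch dimension rather than the sequence length.

# Illustrative sketch only -- `model`, the token ids and `max_frames` are assumptions.
import torch

MEL_START_TOKEN = 512   # assumed special-token ids; the real values live on GptTts
MEL_STOP_TOKEN = 513
max_frames = 900

def greedy_decode(model, text_emb):
    batch = text_emb.shape[0]
    # One start token per batch row; `stopped` tracks which rows have finished.
    mel_seq = torch.full((batch, 1), MEL_START_TOKEN, dtype=torch.long, device=text_emb.device)
    stopped = torch.zeros(batch, dtype=torch.bool, device=text_emb.device)
    # shape[1] is the current sequence length (len() would give the batch size).
    while not torch.all(stopped) and mel_seq.shape[1] < max_frames:
        logits = model(text_emb, mel_seq)           # (batch, seq, vocab)
        next_tok = logits[:, -1].argmax(dim=-1)     # greedy pick for the newest position
        mel_seq = torch.cat([mel_seq, next_tok.unsqueeze(1)], dim=1)
        stopped |= next_tok == MEL_STOP_TOKEN
    # Zero out special tokens (>= 512) before the codes are handed to the VAE.
    return mel_seq * (mel_seq < 512)

# Example with a stand-in model that returns random logits:
if __name__ == "__main__":
    dummy = lambda text_emb, mel_seq: torch.randn(mel_seq.shape[0], mel_seq.shape[1], 514)
    print(greedy_decode(dummy, torch.randn(2, 16, 64)).shape)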