From a2afb25e4240c6631744aa9ca4f77a292909e275 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 7 Aug 2021 20:11:10 -0600 Subject: [PATCH] Fix inference, always flow full text tokens through transformer --- codes/data/audio/gpt_tts_dataset.py | 3 ++- codes/models/gpt_voice/gpt_tts.py | 26 ++++++-------------------- codes/scripts/audio/test_audio_gen.py | 2 +- 3 files changed, 9 insertions(+), 22 deletions(-) diff --git a/codes/data/audio/gpt_tts_dataset.py b/codes/data/audio/gpt_tts_dataset.py index bd84a281..24a7cee0 100644 --- a/codes/data/audio/gpt_tts_dataset.py +++ b/codes/data/audio/gpt_tts_dataset.py @@ -57,7 +57,8 @@ class GptTtsCollater(): def __call__(self, batch): text_lens = [len(x[0]) for x in batch] - max_text_len = max(text_lens) + #max_text_len = max(text_lens) + max_text_len = self.MAX_SYMBOLS_PER_PHRASE # This forces all outputs to have the full 200 characters. Testing if this makes a difference. mel_lens = [len(x[1]) for x in batch] max_mel_len = max(mel_lens) texts = [] diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py index 5f76d26b..a6e6cd2f 100644 --- a/codes/models/gpt_voice/gpt_tts.py +++ b/codes/models/gpt_voice/gpt_tts.py @@ -71,11 +71,12 @@ class GptTts(nn.Module): return loss_text.mean(), loss_mel.mean(), mel_codes, mel_targets def inference(self, text_inputs): + b, _ = text_inputs.shape text_emb = self.text_embedding(text_inputs) text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) - mel_seq = torch.full((text_emb.shape[0],1), fill_value=self.MEL_START_TOKEN, device=text_emb.device) - stop_encountered = torch.zeros((text_emb.shape[0],), device=text_emb.device) + mel_seq = torch.full((b,1), fill_value=self.MEL_START_TOKEN, device=text_emb.device) + stop_encountered = torch.zeros((b,), device=text_emb.device) while not torch.all(stop_encountered) and len(mel_seq) < self.max_mel_frames: mel_emb = self.mel_embedding(mel_seq) mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) @@ -91,25 +92,10 @@ class GptTts(nn.Module): print("Warning! Encountered frame limit before a stop token. Output is likely wrong.") # Format mel_seq so that the DVAE can actually use it (it is a two-tiered DVAE) - cleaned = [] - for j in range(mel_seq.shape[0]): - s = mel_seq[j][1:-1] # Strip out BOS and EOS tokens. - gt = s >= 512 - l = (len(s)) // 3 - for i in reversed(range(l)): - if gt[i]: - l = i+1 - break - top = s[:l] - top = top + (top < 512) * 512 - bottom = s[l:l*3] - bottom = bottom * (bottom < 512) - combined = torch.cat([top,bottom], dim=0) - assert not torch.any(combined < 0) - combined = combined * (combined < 1024) - cleaned.append(combined) + mel_seq = mel_seq[:, 1:-1] # Remove first and last tokens, which were artificially added for GPT + mel_seq = mel_seq * (mel_seq < 512) # The DVAE doesn't understand BOS/EOS/PAD tokens. - return torch.stack(cleaned) + return mel_seq @register_model diff --git a/codes/scripts/audio/test_audio_gen.py b/codes/scripts/audio/test_audio_gen.py index 783564cc..32ab641f 100644 --- a/codes/scripts/audio/test_audio_gen.py +++ b/codes/scripts/audio/test_audio_gen.py @@ -54,7 +54,7 @@ if __name__ == "__main__": torch.backends.cudnn.benchmark = True want_metrics = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) utils.util.loaded_options = opt