From f9c45d70f0fcc169cdeef6f2cd199e265cb5b9cb Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Sat, 18 Dec 2021 17:18:06 -0700
Subject: [PATCH] Fix mel terminator

---
 codes/models/gpt_voice/gpt_tts_hf.py   | 6 ++++--
 codes/scripts/audio/gen/use_gpt_tts.py | 4 ++--
 codes/train.py                         | 2 +-
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py
index 6bb6c506..733bc8bc 100644
--- a/codes/models/gpt_voice/gpt_tts_hf.py
+++ b/codes/models/gpt_voice/gpt_tts_hf.py
@@ -23,7 +23,7 @@ class GptTtsHf(nn.Module):
     STOP_MEL_TOKEN = 8193
 
     def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=200, max_mel_tokens=250, max_conditioning_inputs=3,
-                 checkpointing=True, mel_length_compression=256):
+                 checkpointing=True, mel_length_compression=1024):
         super().__init__()
         self.max_mel_tokens = max_mel_tokens
         self.max_symbols_per_phrase = max_symbols_per_phrase
@@ -109,6 +109,8 @@ class GptTtsHf(nn.Module):
         return loss_text.mean(), loss_mel.mean(), mel_logits
 
     def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8):
+        text_inputs, cond_inputs = torch.load("debug_text_and_cond.pt")
+
         if not hasattr(self, 'inference_model'):
             self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
 
@@ -132,7 +134,7 @@ class GptTtsHf(nn.Module):
         fake_inputs[:,-1] = self.START_MEL_TOKEN
 
         gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
-                          max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True, repetition_penalty=.2)
+                          max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True)
         return gen[:, fake_inputs.shape[1]:]
 
 
diff --git a/codes/scripts/audio/gen/use_gpt_tts.py b/codes/scripts/audio/gen/use_gpt_tts.py
index 9d3f25d9..08a451d7 100644
--- a/codes/scripts/audio/gen/use_gpt_tts.py
+++ b/codes/scripts/audio/gen/use_gpt_tts.py
@@ -46,7 +46,7 @@ if __name__ == '__main__':
     parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
     parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml')
     parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
-    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\23500_gpt.pth')
+    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\32000_gpt.pth')
     parser.add_argument('-text', type=str, help='Text to speak.', default="I'm a language model that has learned to speak.")
     parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Z:\\clips\\books1\\3042_18_Holden__000000000')
     parser.add_argument('-num_cond', type=int, help='Number of conditioning samples to load.', default=3)
@@ -63,7 +63,7 @@ if __name__ == '__main__':
     conds, cond_wav = load_conditioning_candidates(args.cond_path, args.num_cond)
 
     print("Performing GPT inference..")
-    codes = gpt.inference(text, conds, num_beams=4)
+    codes = gpt.inference(text, conds, num_beams=32)
 
     # Delete the GPT TTS model to free up GPU memory
     del gpt
diff --git a/codes/train.py b/codes/train.py
index ca775e8b..45c473cb 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -286,7 +286,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_bench.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()