From b4ddcd7111d60b6a56d10668a5d00e47d1147a2c Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 19 Dec 2021 09:01:19 -0700 Subject: [PATCH] More inference improvements --- codes/models/gpt_voice/gpt_tts_hf.py | 8 ++++---- codes/scripts/audio/gen/use_gpt_tts.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/codes/models/gpt_voice/gpt_tts_hf.py b/codes/models/gpt_voice/gpt_tts_hf.py index 733bc8bc..405fb64b 100644 --- a/codes/models/gpt_voice/gpt_tts_hf.py +++ b/codes/models/gpt_voice/gpt_tts_hf.py @@ -108,8 +108,8 @@ class GptTtsHf(nn.Module): loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_text.mean(), loss_mel.mean(), mel_logits - def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8): - text_inputs, cond_inputs = torch.load("debug_text_and_cond.pt") + def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8, repetition_penalty=1): + #text_inputs, cond_inputs = torch.load("debug_text_and_cond.pt") if not hasattr(self, 'inference_model'): self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head) @@ -134,8 +134,8 @@ class GptTtsHf(nn.Module): fake_inputs[:,-1] = self.START_MEL_TOKEN gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN, - max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True) - return gen[:, fake_inputs.shape[1]:] + max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True, repetition_penalty=repetition_penalty) + return gen[:, fake_inputs.shape[1]:-1] @register_model diff --git a/codes/scripts/audio/gen/use_gpt_tts.py b/codes/scripts/audio/gen/use_gpt_tts.py index 08a451d7..798cbbf0 100644 --- a/codes/scripts/audio/gen/use_gpt_tts.py +++ b/codes/scripts/audio/gen/use_gpt_tts.py @@ -46,9 +46,9 @@ if __name__ == '__main__': parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae') parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml') parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt') - parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\32000_gpt.pth') - parser.add_argument('-text', type=str, help='Text to speak.', default="I'm a language model that has learned to speak.") - parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Z:\\clips\\books1\\3042_18_Holden__000000000') + parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\48000_gpt.pth') + parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.") + parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Y:\\clips\\podcasts-0\\8816_20210511-Pay Taxes Less Frequently_ We\'re Interested') parser.add_argument('-num_cond', type=int, help='Number of conditioning samples to load.', default=3) args = parser.parse_args() @@ -63,7 +63,7 @@ if __name__ == '__main__': conds, cond_wav = load_conditioning_candidates(args.cond_path, args.num_cond) print("Performing GPT inference..") - codes = gpt.inference(text, conds, num_beams=32) + codes = gpt.inference(text, conds, num_beams=32, repetition_penalty=10.0) # Delete the GPT TTS model to free up GPU memory del gpt @@ -72,7 +72,7 @@ if __name__ == '__main__': dvae = load_model_from_config(args.opt_diffuse, args.dvae_model_name) print("Loading Diffusion Model..") diffusion = load_model_from_config(args.opt_diffuse, args.diffusion_model_name, also_load_savepoint=False, load_path=args.diffusion_model_path) - diffuser = load_discrete_vocoder_diffuser() + diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=50) print("Performing vocoding..") wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, codes, cond_wav, spectrogram_compression_factor=128, plt_spec=False)