diff --git a/api.py b/api.py index fa1a010..04c3af8 100644 --- a/api.py +++ b/api.py @@ -157,10 +157,23 @@ class TextToSpeech: self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, model_dim=1024, - heads=16, number_text_tokens=256, start_text_token=255, checkpointing=False, + heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, train_solo_embeddings=False, average_conditioning_embeddings=True).cpu().eval() self.autoregressive.load_state_dict(torch.load('.models/autoregressive.pth')) + ''' + self.autoregressive = UnifiedVoice(max_mel_tokens=2048, max_text_tokens=1024, max_conditioning_inputs=1, layers=42, + model_dim=1152, heads=18, number_text_tokens=256, train_solo_embeddings=False, + average_conditioning_embeddings=True, types=2).cpu().eval() + self.autoregressive.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_tts_xl\\models\\15250_gpt_ema.pth')) + ''' + + self.autoregressive_for_diffusion = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, + model_dim=1024, + heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, + train_solo_embeddings=False, + average_conditioning_embeddings=True).cpu().eval() + self.autoregressive_for_diffusion.load_state_dict(torch.load('.models/autoregressive.pth')) self.clip = VoiceCLIP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, text_seq_len=350, text_heads=8, @@ -202,7 +215,7 @@ class TextToSpeech: def tts(self, text, voice_samples, k=1, # autoregressive generation parameters follow - num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, + num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, # diffusion generation parameters follow diffusion_iterations=100, cond_free=True, cond_free_k=2, diffusion_temperature=1.0, **hf_generate_kwargs): @@ -232,8 +245,9 @@ class TextToSpeech: num_return_sequences=self.autoregressive_batch_size, length_penalty=length_penalty, repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, **hf_generate_kwargs) - padding_needed = self.autoregressive.max_mel_tokens - codes.shape[1] + padding_needed = max_mel_tokens - codes.shape[1] codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) samples.append(codes) self.autoregressive = self.autoregressive.cpu() @@ -253,11 +267,11 @@ class TextToSpeech: # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # results, but will increase memory usage. - self.autoregressive = self.autoregressive.cuda() - best_latents = self.autoregressive(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results, - torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device), + self.autoregressive_for_diffusion = self.autoregressive_for_diffusion.cuda() + best_latents = self.autoregressive_for_diffusion(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results, + torch.tensor([best_results.shape[-1]*self.autoregressive_for_diffusion.mel_length_compression], device=conds.device), return_latent=True, clip_inputs=False) - self.autoregressive = self.autoregressive.cpu() + self.autoregressive_for_diffusion = self.autoregressive_for_diffusion.cpu() print("Performing vocoding..") wav_candidates = [] diff --git a/do_tts.py b/do_tts.py index 3448942..ec21641 100644 --- a/do_tts.py +++ b/do_tts.py @@ -1,35 +1,17 @@ import argparse import os -import torch -import torch.nn.functional as F import torchaudio -from api import TextToSpeech, load_conditioning -from utils.audio import load_audio -from utils.tokenizer import VoiceBpeTokenizer +from api import TextToSpeech +from utils.audio import load_audio, get_voices if __name__ == '__main__': - # These are voices drawn randomly from the training set. You are free to substitute your own voices in, but testing - # has shown that the model does not generalize to new voices very well. - preselected_cond_voices = { - # Male voices - 'dotrice': ['voices/dotrice/1.wav', 'voices/dotrice/2.wav'], - 'harris': ['voices/harris/1.wav', 'voices/harris/2.wav'], - 'lescault': ['voices/lescault/1.wav', 'voices/lescault/2.wav'], - 'otto': ['voices/otto/1.wav', 'voices/otto/2.wav'], - 'obama': ['voices/obama/1.wav', 'voices/obama/2.wav'], - # Female voices - 'atkins': ['voices/atkins/1.wav', 'voices/atkins/2.wav'], - 'grace': ['voices/grace/1.wav', 'voices/grace/2.wav'], - 'kennard': ['voices/kennard/1.wav', 'voices/kennard/2.wav'], - 'mol': ['voices/mol/1.wav', 'voices/mol/2.wav'], - } - parser = argparse.ArgumentParser() parser.add_argument('--text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.") - parser.add_argument('--voice', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='obama,dotrice,harris,lescault,otto,atkins,grace,kennard,mol') - parser.add_argument('--num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=128) + parser.add_argument('--voice', type=str, help='Selects the voice to use for generation. See options in voices/ directory (and add your own!) ' + 'Use the & character to join two voices together. Use a comma to perform inference on multiple voices.', default='patrick_stewart') + parser.add_argument('--num_samples', type=int, help='How many total outputs the autoregressive transformer should produce.', default=256) parser.add_argument('--batch_size', type=int, help='How many samples to process at once in the autoregressive model.', default=16) parser.add_argument('--num_diffusion_samples', type=int, help='Number of outputs that progress to the diffusion stage.', default=16) parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/') @@ -38,8 +20,10 @@ if __name__ == '__main__': tts = TextToSpeech(autoregressive_batch_size=args.batch_size) - for voice in args.voice.split(','): - cond_paths = preselected_cond_voices[voice] + voices = get_voices() + selected_voices = args.voice.split(',') + for voice in selected_voices: + cond_paths = voices[voice] conds = [] for cond_path in cond_paths: c = load_audio(cond_path, 22050)