diff --git a/api.py b/api.py index 6aa94cf..92e82be 100644 --- a/api.py +++ b/api.py @@ -170,35 +170,40 @@ class TextToSpeech: :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing GPU OOM errors. Larger numbers generates slightly faster. """ - def __init__(self, autoregressive_batch_size=16): + def __init__(self, autoregressive_batch_size=16, models_dir='.models'): self.autoregressive_batch_size = autoregressive_batch_size self.tokenizer = VoiceBpeTokenizer() download_models() - self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, - model_dim=1024, - heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, - train_solo_embeddings=False, - average_conditioning_embeddings=True).cpu().eval() - self.autoregressive.load_state_dict(torch.load('.models/autoregressive.pth')) + if os.path.exists(f'{models_dir}/autoregressive.ptt'): + # Assume this is a traced directory. + self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt') + self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt') + else: + self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, + model_dim=1024, + heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, + train_solo_embeddings=False, + average_conditioning_embeddings=True).cpu().eval() + self.autoregressive.load_state_dict(torch.load(f'{models_dir}/autoregressive.pth')) + + self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, + in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, + layer_drop=0, unconditioned_percentage=0).cpu().eval() + self.diffusion.load_state_dict(torch.load(f'{models_dir}/diffusion_decoder.pth')) self.clvp = CLVP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12, text_seq_len=350, text_heads=8, num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430, use_xformers=True).cpu().eval() - self.clvp.load_state_dict(torch.load('.models/clvp.pth')) + self.clvp.load_state_dict(torch.load(f'{models_dir}/clvp.pth')) self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0, speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval() - self.cvvp.load_state_dict(torch.load('.models/cvvp.pth')) - - self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, - in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, - layer_drop=0, unconditioned_percentage=0).cpu().eval() - self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder.pth')) + self.cvvp.load_state_dict(torch.load(f'{models_dir}/cvvp.pth')) self.vocoder = UnivNetGenerator().cpu() - self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g']) + self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g']) self.vocoder.eval(inference=True) def tts_with_preset(self, text, voice_samples, preset='fast', **kwargs): @@ -216,7 +221,7 @@ class TextToSpeech: 'cond_free_k': 2.0, 'diffusion_temperature': 1.0}) # Presets are defined here. presets = { - 'ultra_fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 16, 'cond_free': False}, + 'ultra_fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 32, 'cond_free': False}, 'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 32}, 'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 128}, 'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 1024},