diff --git a/tortoise/api.py b/tortoise/api.py
index 32016cd..bacbd7a 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -44,6 +44,7 @@ MODELS = {
     'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
     'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
     'bigvgan_base_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.pth',
+    #'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth',
 }
 
 def hash_file(path, algo="md5", buffer_size=0):
@@ -241,7 +242,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
 
-    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000, autoregressive_model_path=None, use_bigvgan=True):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000, autoregressive_model_path=None, vocoder_model=None):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -253,6 +254,7 @@ class TextToSpeech:
                                  Default is true.
         :param device: Device to use when running the model. If omitted, the device will be automatically chosen.
         """
+        self.loading = True
         if device is None:
             device = get_device(verbose=True)
 
@@ -278,19 +280,13 @@ class TextToSpeech:
         self.tokenizer = VoiceBpeTokenizer()
 
         self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir)
-        self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)
 
         if os.path.exists(f'{models_dir}/autoregressive.ptt'):
             # Assume this is a traced directory.
             self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
             self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
         else:
-            self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
-                                          model_dim=1024,
-                                          heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
-                                          train_solo_embeddings=False).cpu().eval()
-            self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
-            self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
+            self.load_autoregressive_model(self.autoregressive_model_path)
 
             self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, in_latent_channels=1024, in_tokens=8193, dropout=0,
                                           use_fp16=False, num_heads=16,
@@ -305,14 +301,8 @@ class TextToSpeech:
         self.clvp.load_state_dict(torch.load(get_model_path('clvp2.pth', models_dir)))
         self.cvvp = None # CVVP model is only loaded if used.
 
-        if use_bigvgan:
-            # credit to https://github.com/deviandice / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52
-            self.vocoder = BigVGAN().cpu()
-            self.vocoder.load_state_dict(torch.load(get_model_path('bigvgan_base_24khz_100band.pth', models_dir), map_location=torch.device('cpu'))['generator'])
-        else:
-            self.vocoder = UnivNetGenerator().cpu()
-            self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
-        self.vocoder.eval(inference=True)
+        self.vocoder_model = vocoder_model
+        self.load_vocoder_model(self.vocoder_model)
 
         # Random latent generators (RLGs) are loaded lazily.
         self.rlg_auto = None
@@ -323,13 +313,18 @@ class TextToSpeech:
             self.diffusion = self.diffusion.to(self.device)
             self.clvp = self.clvp.to(self.device)
             self.vocoder = self.vocoder.to(self.device)
+        self.loading = False
 
     def load_autoregressive_model(self, autoregressive_model_path):
+        self.loading = True
+
         previous_path = self.autoregressive_model_path
         self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
         self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)
 
-        del self.autoregressive
+        if hasattr(self, 'autoregressive'):
+            del self.autoregressive
+
         self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
                                           model_dim=1024,
                                           heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
@@ -339,8 +334,37 @@ class TextToSpeech:
 
         if self.preloaded_tensors:
             self.autoregressive = self.autoregressive.to(self.device)
+        self.loading = False
 
-        return previous_path != self.autoregressive_model_path
+    def load_vocoder_model(self, vocoder_model):
+        self.loading = True
+        if hasattr(self, 'vocoder'):
+            del self.vocoder
+
+        print(vocoder_model)
+        if vocoder_model is None:
+            vocoder_model = 'bigvgan_24khz_100band'
+
+        if 'bigvgan' in vocoder_model:
+            # credit to https://github.com/deviandice / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52
+            vocoder_key = 'generator'
+            self.vocoder_model_path = 'bigvgan_24khz_100band.pth'
+            if f'{vocoder_model}.pth' in MODELS:
+                self.vocoder_model_path = f'{vocoder_model}.pth'
+            self.vocoder = BigVGAN().cpu()
+        #elif vocoder_model == "univnet":
+        else:
+            vocoder_key = 'model_g'
+            self.vocoder_model_path = 'vocoder.pth'
+            self.vocoder = UnivNetGenerator().cpu()
+
+        print(vocoder_model, vocoder_key, self.vocoder_model_path)
+        self.vocoder.load_state_dict(torch.load(get_model_path(self.vocoder_model_path, self.models_dir), map_location=torch.device('cpu'))[vocoder_key])
+
+        self.vocoder.eval(inference=True)
+        if self.preloaded_tensors:
+            self.vocoder = self.vocoder.to(self.device)
+        self.loading = False
 
     def load_cvvp(self):
         """Load CVVP model."""
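
A minimal usage sketch of the reworked interface, assuming the weight files listed in MODELS have already been fetched into models_dir; the model names below are illustrative, not new API:

    from tortoise.api import TextToSpeech

    # vocoder_model=None falls back to the BigVGAN default inside load_vocoder_model();
    # 'bigvgan_base_24khz_100band' is one of the names present in MODELS.
    tts = TextToSpeech(vocoder_model='bigvgan_base_24khz_100band')

    # Any name without 'bigvgan' takes the UnivNet branch ('vocoder.pth' / 'model_g').
    # The vocoder can also be swapped after construction; self.loading flags the reload.
    tts.load_vocoder_model('vocoder')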