From d8c673982066e786d045d03a5a97ebaa0db76454 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 18 Feb 2023 14:08:45 +0000
Subject: [PATCH] added constructor argument and function to load a user-specified autoregressive model

---
 README.md       |  2 +-
 main.py         |  2 ++
 tortoise/api.py | 22 ++++++++++++++++++++--
 3 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 2c748e6..26781af 100755
--- a/README.md
+++ b/README.md
@@ -2,4 +2,4 @@
 
 This repo is for my modifications to [neonbjb/tortoise-tts](https://github.com/neonbjb/tortoise-tts).
 
-For the original repo, please go to [mrq/ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning).
\ No newline at end of file
+Please migrate to [mrq/ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning) for future additions.
\ No newline at end of file
diff --git a/main.py b/main.py
index 18bc815..947b81c 100755
--- a/main.py
+++ b/main.py
@@ -1,6 +1,8 @@
 import os
 import webui as mrq
 
+print('DEPRECATION WARNING: this repo has been refactored to focus entirely on tortoise-tts. Please migrate to https://git.ecker.tech/mrq/ai-voice-cloning if you seek new features.')
+
 if 'TORTOISE_MODELS_DIR' not in os.environ:
 	os.environ['TORTOISE_MODELS_DIR'] = os.path.realpath(os.path.join(os.getcwd(), './models/tortoise/'))
 
diff --git a/tortoise/api.py b/tortoise/api.py
index 3806cb4..e4babca 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -203,7 +203,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
 
-    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000, autoregressive_model_path=None):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -238,6 +238,8 @@ class TextToSpeech:
 
         self.tokenizer = VoiceBpeTokenizer()
 
+        self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir)
+
         if os.path.exists(f'{models_dir}/autoregressive.ptt'):
             # Assume this is a traced directory.
             self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
@@ -247,7 +249,7 @@ class TextToSpeech:
                                               model_dim=1024,
                                               heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
                                               train_solo_embeddings=False).cpu().eval()
-            self.autoregressive.load_state_dict(torch.load(get_model_path('autoregressive.pth', models_dir)))
+            self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
             self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
 
         self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
@@ -277,6 +279,22 @@ class TextToSpeech:
             self.clvp = self.clvp.to(self.device)
             self.vocoder = self.vocoder.to(self.device)
 
+    def load_autoregressive_model(self, autoregressive_model_path):
+        previous_path = self.autoregressive_model_path
+        self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
+
+        del self.autoregressive
+        self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
+                                           model_dim=1024,
+                                           heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
+                                           train_solo_embeddings=False).cpu().eval()
+        self.autoregressive.load_state_dict(torch.load(self.autoregressive_model_path))
+        self.autoregressive.post_init_gpt2_config(kv_cache=self.use_kv_cache)
+        if self.preloaded_tensors:
+            self.autoregressive = self.autoregressive.to(self.device)
+
+        return previous_path != self.autoregressive_model_path
+
     def load_cvvp(self):
         """Load CVVP model."""
         self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,