diff --git a/setup.py b/setup.py
index 3a578f5..86a3938 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 setuptools.setup(
     name="TorToiSe",
     packages=setuptools.find_packages(),
-    version="2.4.4",
+    version="2.4.5",
     author="James Betker",
     author_email="james@adamant.ai",
     description="A high quality multi-voice text-to-speech library",
diff --git a/tortoise/api.py b/tortoise/api.py
index f1a60f2..965e133 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -265,7 +265,11 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
 
-    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None, minor_optimizations=True, input_sample_rate=22050, output_sample_rate=24000, autoregressive_model_path=None, vocoder_model=None):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True, device=None,
+        minor_optimizations=True,
+        input_sample_rate=22050, output_sample_rate=24000,
+        autoregressive_model_path=None, diffusion_model_path=None, vocoder_model=None, tokenizer_json=None
+    ):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -300,23 +304,23 @@ class TextToSpeech:
         if self.enable_redaction:
             self.aligner = Wav2VecAlignment(device='cpu' if get_device_name() == "dml" else self.device)
 
-        self.tokenizer = VoiceBpeTokenizer()
-
+        self.load_tokenizer_json(tokenizer_json)
 
         if os.path.exists(f'{models_dir}/autoregressive.ptt'):
-            # Assume this is a traced directory.
             self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
-            self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
         else:
             if not autoregressive_model_path or not os.path.exists(autoregressive_model_path):
                 autoregressive_model_path = get_model_path('autoregressive.pth', models_dir)
 
             self.load_autoregressive_model(autoregressive_model_path)
 
-            self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
-                                          in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
-                                          layer_drop=0, unconditioned_percentage=0).cpu().eval()
-            self.diffusion.load_state_dict(torch.load(get_model_path('diffusion_decoder.pth', models_dir)))
+        if os.path.exists(f'{models_dir}/diffusion_decoder.ptt'):
+            self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
+        else:
+            if not diffusion_model_path or not os.path.exists(diffusion_model_path):
+                diffusion_model_path = get_model_path('diffusion_decoder.pth', models_dir)
+
+            self.load_diffusion_model(diffusion_model_path)
 
 
         self.clvp = CLVP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
@@ -366,6 +370,34 @@ class TextToSpeech:
         self.loading = False
         print(f"Loaded autoregressive model")
 
+    def load_diffusion_model(self, diffusion_model_path):
+        """
+        Loads (or reloads) the diffusion decoder into self.diffusion; a no-op if this path is already loaded.
+        :param diffusion_model_path: Path to a diffusion decoder checkpoint. When it is empty or does not exist on
+                                     disk, the default 'diffusion_decoder.pth' from get_model_path is used instead.
+        """
+        if hasattr(self,"diffusion_model_path") and self.diffusion_model_path == diffusion_model_path:
+            return
+
+        self.loading = True
+
+        self.diffusion_model_path = diffusion_model_path if diffusion_model_path and os.path.exists(diffusion_model_path) else get_model_path('diffusion_decoder.pth', self.models_dir)
+        self.diffusion_model_hash = hash_file(self.diffusion_model_path)
+
+        if hasattr(self, 'diffusion'):
+            del self.diffusion
+
+        self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
+                                      in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
+                                      layer_drop=0, unconditioned_percentage=0).cpu().eval()
+        # Load the checkpoint that was resolved (and hashed) above, not unconditionally the default one.
+        self.diffusion.load_state_dict(torch.load(self.diffusion_model_path))
+        if self.preloaded_tensors:
+            self.diffusion = migrate_to_device( self.diffusion, self.device )
+
+        self.loading = False
+        print(f"Loaded diffusion model")
+
     def load_vocoder_model(self, vocoder_model):
         if hasattr(self,"vocoder_model_path") and self.vocoder_model_path == vocoder_model:
             return
@@ -375,7 +407,7 @@ class TextToSpeech:
         if hasattr(self, 'vocoder'):
             del self.vocoder
 
-        print(vocoder_model)
+        print("Loading vocoder model:", vocoder_model)
         if vocoder_model is None:
             vocoder_model = 'bigvgan_24khz_100band'
 
@@ -406,6 +438,27 @@ class TextToSpeech:
         self.loading = False
         print(f"Loaded vocoder model")
 
+    def load_tokenizer_json(self, tokenizer_json):
+        """
+        Loads (or reloads) self.tokenizer as a VoiceBpeTokenizer built from a vocabulary JSON file.
+        :param tokenizer_json: Path to the tokenizer JSON. When falsy, '../tortoise/data/tokenizer.json' relative to
+                               this file's directory is used.
+        """
+        if hasattr(self,"tokenizer_json") and self.tokenizer_json == tokenizer_json:
+            return
+
+        self.loading = True
+        self.tokenizer_json = tokenizer_json if tokenizer_json else os.path.join(os.path.dirname(os.path.realpath(__file__)), '../tortoise/data/tokenizer.json')
+        print("Loading tokenizer JSON:", self.tokenizer_json)
+
+        if hasattr(self, 'tokenizer'):
+            del self.tokenizer
+
+        self.tokenizer = VoiceBpeTokenizer(vocab_file=self.tokenizer_json)
+
+        self.loading = False
+        print(f"Loaded tokenizer")
+
     def load_cvvp(self):
         """Load CVVP model."""
         self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,