diff --git a/tortoise/api.py b/tortoise/api.py index 965e133..3aeb5e1 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -565,6 +565,8 @@ class TextToSpeech: num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, sample_batch_size=None, autoregressive_model=None, + diffusion_model=None, + tokenizer_json=None, # CVVP parameters follow cvvp_amount=.0, # diffusion generation parameters follow @@ -632,6 +634,16 @@ class TextToSpeech: elif autoregressive_model != self.autoregressive_model_path: self.load_autoregressive_model(autoregressive_model) + if diffusion_model is None: + diffusion_model = self.diffusion_model_path + elif diffusion_model != self.diffusion_model_path: + self.load_diffusion_model(diffusion_model) + + if tokenizer_json is None: + tokenizer_json = self.tokenizer_json + elif tokenizer_json != self.tokenizer_json: + self.load_tokenizer_json(tokenizer_json) + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0) text_tokens = migrate_to_device( text_tokens, self.device )