From 7f98727ad576ee8dd8d46b190b686c4168bd21f7 Mon Sep 17 00:00:00 2001 From: mrq <barry.quiggles@protonmail.com> Date: Mon, 6 Mar 2023 20:31:19 +0000 Subject: [PATCH] added option to specify autoregressive model at tts generation time (for a spicy feature later) --- setup.py | 2 +- tortoise/api.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cc708c8..3a578f5 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ with open("README.md", "r", encoding="utf-8") as fh: setuptools.setup( name="TorToiSe", packages=setuptools.find_packages(), - version="2.4.3", + version="2.4.4", author="James Betker", author_email="james@adamant.ai", description="A high quality multi-voice text-to-speech library", diff --git a/tortoise/api.py b/tortoise/api.py index 56bdc04..5ed7843 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -256,6 +256,7 @@ class TextToSpeech: if device is None: device = get_device(verbose=True) + self.version = [2,4,4] # to-do, autograb this from setup.py, or have setup.py autograb this self.input_sample_rate = input_sample_rate self.output_sample_rate = output_sample_rate self.minor_optimizations = minor_optimizations @@ -475,6 +476,7 @@ class TextToSpeech: # autoregressive generation parameters follow num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, sample_batch_size=None, + autoregressive_model=None, # CVVP parameters follow cvvp_amount=.0, # diffusion generation parameters follow @@ -537,6 +539,11 @@ class TextToSpeech: self.diffusion.enable_fp16 = half_p deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) + if autoregressive_model is None: + autoregressive_model = self.autoregressive_model_path + elif autoregressive_model != self.autoregressive_model_path: + load_autoregressive_model(autoregressive_model) + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'