From 7f98727ad576ee8dd8d46b190b686c4168bd21f7 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Mon, 6 Mar 2023 20:31:19 +0000
Subject: [PATCH] added option to specify autoregressive model at tts
 generation time (for a spicy feature later)

---
 setup.py        | 2 +-
 tortoise/api.py | 7 +++++++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index cc708c8..3a578f5 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 setuptools.setup(
     name="TorToiSe",
     packages=setuptools.find_packages(),
-    version="2.4.3",
+    version="2.4.4",
     author="James Betker",
     author_email="james@adamant.ai",
     description="A high quality multi-voice text-to-speech library",
diff --git a/tortoise/api.py b/tortoise/api.py
index 56bdc04..5ed7843 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -256,6 +256,7 @@ class TextToSpeech:
         if device is None:
             device = get_device(verbose=True)
 
+        self.version = [2,4,4] # TODO: autograb this from setup.py, or have setup.py autograb this
         self.input_sample_rate = input_sample_rate
         self.output_sample_rate = output_sample_rate
         self.minor_optimizations = minor_optimizations
@@ -475,6 +476,7 @@ class TextToSpeech:
             # autoregressive generation parameters follow
             num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
             sample_batch_size=None,
+            autoregressive_model=None,
             # CVVP parameters follow
             cvvp_amount=.0,
             # diffusion generation parameters follow
@@ -537,6 +539,11 @@ class TextToSpeech:
         self.diffusion.enable_fp16 = half_p
         deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
 
+        if autoregressive_model is None:
+            autoregressive_model = self.autoregressive_model_path
+        elif autoregressive_model != self.autoregressive_model_path:
+            self.load_autoregressive_model(autoregressive_model)
+
         text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
         text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
         assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'