From 26133c20314b77155e77be804b43909dab9809d6 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Tue, 7 Mar 2023 04:33:49 +0000
Subject: [PATCH] do not reload autoregressive (AR) model/vocoder if already loaded

---
 tortoise/api.py | 30 ++++++++++++++++++++----------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/tortoise/api.py b/tortoise/api.py
index bacbd7a..9765a87 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -279,14 +279,16 @@ class TextToSpeech:
 
         self.tokenizer = VoiceBpeTokenizer()
 
-        self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir)
 
         if os.path.exists(f'{models_dir}/autoregressive.ptt'):
             # Assume this is a traced directory.
             self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
             self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
         else:
-            self.load_autoregressive_model(self.autoregressive_model_path)
+            if not autoregressive_model_path or not os.path.exists(autoregressive_model_path):
+                autoregressive_model_path = get_model_path('autoregressive.pth', models_dir)
+
+            self.load_autoregressive_model(autoregressive_model_path)
 
             self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                           in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
@@ -316,11 +318,14 @@ class TextToSpeech:
         self.loading = False
 
     def load_autoregressive_model(self, autoregressive_model_path):
+        if hasattr(self,"autoregressive_model_path") and self.autoregressive_model_path == autoregressive_model_path:
+            return
+
         self.loading = True
 
-        previous_path = self.autoregressive_model_path
         self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
         self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)
+        print(f"Loading autoregressive model: {self.autoregressive_model_path}")
 
         if hasattr(self, 'autoregressive'):
             del self.autoregressive
@@ -335,9 +340,14 @@ class TextToSpeech:
             self.autoregressive = self.autoregressive.to(self.device)
 
         self.loading = False
+        print(f"Loaded autoregressive model")
 
     def load_vocoder_model(self, vocoder_model):
+        if hasattr(self,"vocoder_model_path") and self.vocoder_model_path == vocoder_model:
+            return
+
         self.loading = True
+
         if hasattr(self, 'vocoder'):
             del self.vocoder
 
@@ -357,14 +367,15 @@ class TextToSpeech:
             vocoder_key = 'model_g'
             self.vocoder_model_path = 'vocoder.pth'
             self.vocoder = UnivNetGenerator().cpu()
-
-        print(vocoder_model, vocoder_key, self.vocoder_model_path)
+        
+        print(f"Loading vocoder model: {self.vocoder_model_path}")
         self.vocoder.load_state_dict(torch.load(get_model_path(self.vocoder_model_path, self.models_dir), map_location=torch.device('cpu'))[vocoder_key])
 
         self.vocoder.eval(inference=True)
         if self.preloaded_tensors:
             self.vocoder = self.vocoder.to(self.device)
         self.loading = False
+        print(f"Loaded vocoder model")
 
     def load_cvvp(self):
         """Load CVVP model."""
@@ -427,11 +438,10 @@ class TextToSpeech:
 
             if slices == 0:
                 slices = 1
-            else:
-                if max_chunk_size is not None and chunk_size > max_chunk_size:
-                    slices = 1
-                    while int(chunk_size / slices) > max_chunk_size:
-                        slices = slices + 1
+            elif max_chunk_size is not None and chunk_size > max_chunk_size:
+                slices = 1
+                while int(chunk_size / slices) > max_chunk_size:
+                    slices = slices + 1
 
             chunks = torch.chunk(concat, slices, dim=1)
             chunk_size = chunks[0].shape[-1]