From 26133c20314b77155e77be804b43909dab9809d6 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 7 Mar 2023 04:33:49 +0000 Subject: [PATCH] do not reload AR/vocoder if already loaded --- tortoise/api.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/tortoise/api.py b/tortoise/api.py index bacbd7a..9765a87 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -279,14 +279,16 @@ class TextToSpeech: self.tokenizer = VoiceBpeTokenizer() - self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir) if os.path.exists(f'{models_dir}/autoregressive.ptt'): # Assume this is a traced directory. self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt') self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt') else: - self.load_autoregressive_model(self.autoregressive_model_path) + if not autoregressive_model_path or not os.path.exists(autoregressive_model_path): + autoregressive_model_path = get_model_path('autoregressive.pth', models_dir) + + self.load_autoregressive_model(autoregressive_model_path) self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, @@ -316,11 +318,14 @@ class TextToSpeech: self.loading = False def load_autoregressive_model(self, autoregressive_model_path): + if hasattr(self,"autoregressive_model_path") and self.autoregressive_model_path == autoregressive_model_path: + return + self.loading = True - previous_path = self.autoregressive_model_path self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir) self.autoregressive_model_hash = hash_file(self.autoregressive_model_path) + print(f"Loading autoregressive model: {self.autoregressive_model_path}") if hasattr(self, 'autoregressive'): del self.autoregressive @@ -335,9 +340,14 @@ class TextToSpeech: self.autoregressive = self.autoregressive.to(self.device) self.loading = False + print(f"Loaded autoregressive model") def load_vocoder_model(self, vocoder_model): + if hasattr(self,"vocoder_model_path") and self.vocoder_model_path == vocoder_model: + return + self.loading = True + if hasattr(self, 'vocoder'): del self.vocoder @@ -357,14 +367,15 @@ class TextToSpeech: vocoder_key = 'model_g' self.vocoder_model_path = 'vocoder.pth' self.vocoder = UnivNetGenerator().cpu() - - print(vocoder_model, vocoder_key, self.vocoder_model_path) + + print(f"Loading vocoder model: {self.vocoder_model_path}") self.vocoder.load_state_dict(torch.load(get_model_path(self.vocoder_model_path, self.models_dir), map_location=torch.device('cpu'))[vocoder_key]) self.vocoder.eval(inference=True) if self.preloaded_tensors: self.vocoder = self.vocoder.to(self.device) self.loading = False + print(f"Loaded vocoder model") def load_cvvp(self): """Load CVVP model.""" @@ -427,11 +438,10 @@ class TextToSpeech: if slices == 0: slices = 1 - else: - if max_chunk_size is not None and chunk_size > max_chunk_size: - slices = 1 - while int(chunk_size / slices) > max_chunk_size: - slices = slices + 1 + elif max_chunk_size is not None and chunk_size > max_chunk_size: + slices = 1 + while int(chunk_size / slices) > max_chunk_size: + slices = slices + 1 chunks = torch.chunk(concat, slices, dim=1) chunk_size = chunks[0].shape[-1]