From 9afa71542bfbf9810bcd533489b5ca0c5b30fdee Mon Sep 17 00:00:00 2001 From: mrq <mrq@ecker.tech> Date: Fri, 11 Aug 2023 04:02:36 +0000 Subject: [PATCH] little sloppy hack to try and not load the same model when it was already loaded --- tortoise/api.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tortoise/api.py b/tortoise/api.py index a8de979..0293076 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -337,13 +337,18 @@ class TextToSpeech: self.loading = False def load_autoregressive_model(self, autoregressive_model_path): - if hasattr(self,"autoregressive_model_path") and self.autoregressive_model_path == autoregressive_model_path: + if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path): return - self.loading = True - self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir) - self.autoregressive_model_hash = hash_file(self.autoregressive_model_path) + new_hash = hash_file(self.autoregressive_model_path) + + if hasattr(self,"autoregressive_model_hash") and self.autoregressive_model_hash == new_hash: + return + + self.autoregressive_model_hash = new_hash + + self.loading = True print(f"Loading autoregressive model: {self.autoregressive_model_path}") if hasattr(self, 'autoregressive'): @@ -362,7 +367,7 @@ class TextToSpeech: print(f"Loaded autoregressive model") def load_diffusion_model(self, diffusion_model_path): - if hasattr(self,"diffusion_model_path") and self.diffusion_model_path == diffusion_model_path: + if hasattr(self,"diffusion_model_path") and os.path.samefile(self.diffusion_model_path, diffusion_model_path): return self.loading = True @@ -384,7 +389,7 @@ class TextToSpeech: print(f"Loaded diffusion model") def load_vocoder_model(self, vocoder_model): - if hasattr(self,"vocoder_model_path") and self.vocoder_model_path == vocoder_model: + if hasattr(self,"vocoder_model_path") and os.path.samefile(self.vocoder_model_path, vocoder_model): return self.loading = True @@ -424,7 +429,7 @@ class TextToSpeech: print(f"Loaded vocoder model") def load_tokenizer_json(self, tokenizer_json): - if hasattr(self,"tokenizer_json") and self.tokenizer_json == tokenizer_json: + if hasattr(self,"tokenizer_json") and os.path.samefile(self.tokenizer_json, tokenizer_json): return self.loading = True