From 9afa71542bfbf9810bcd533489b5ca0c5b30fdee Mon Sep 17 00:00:00 2001
From: mrq <mrq@ecker.tech>
Date: Fri, 11 Aug 2023 04:02:36 +0000
Subject: [PATCH] little sloppy hack to try and not load the same model when it
 was already loaded

---
 tortoise/api.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/tortoise/api.py b/tortoise/api.py
index a8de979..0293076 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -337,13 +337,18 @@ class TextToSpeech:
         self.loading = False
 
     def load_autoregressive_model(self, autoregressive_model_path):
-        if hasattr(self,"autoregressive_model_path") and self.autoregressive_model_path == autoregressive_model_path:
+        if hasattr(self,"autoregressive_model_path") and os.path.samefile(self.autoregressive_model_path, autoregressive_model_path):
             return
 
-        self.loading = True
-
         self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
-        self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)
+        new_hash = hash_file(self.autoregressive_model_path)
+
+        if hasattr(self,"autoregressive_model_hash") and self.autoregressive_model_hash == new_hash:
+            return
+
+        self.autoregressive_model_hash = new_hash
+
+        self.loading = True
         print(f"Loading autoregressive model: {self.autoregressive_model_path}")
 
         if hasattr(self, 'autoregressive'):
@@ -362,7 +367,7 @@ class TextToSpeech:
         print(f"Loaded autoregressive model")
 
     def load_diffusion_model(self, diffusion_model_path):
-        if hasattr(self,"diffusion_model_path") and self.diffusion_model_path == diffusion_model_path:
+        if hasattr(self,"diffusion_model_path") and os.path.samefile(self.diffusion_model_path, diffusion_model_path):
             return
 
         self.loading = True
@@ -384,7 +389,7 @@ class TextToSpeech:
         print(f"Loaded diffusion model")
 
     def load_vocoder_model(self, vocoder_model):
-        if hasattr(self,"vocoder_model_path") and self.vocoder_model_path == vocoder_model:
+        if hasattr(self,"vocoder_model_path") and os.path.samefile(self.vocoder_model_path, vocoder_model):
             return
 
         self.loading = True
@@ -424,7 +429,7 @@ class TextToSpeech:
         print(f"Loaded vocoder model")
 
     def load_tokenizer_json(self, tokenizer_json):
-        if hasattr(self,"tokenizer_json") and self.tokenizer_json == tokenizer_json:
+        if hasattr(self,"tokenizer_json") and os.path.samefile(self.tokenizer_json, tokenizer_json):
             return
         
         self.loading = True