From 20220893af7fc2e3541e7642a802142b38f8ec47 Mon Sep 17 00:00:00 2001
From: Johan Nordberg <its@johan-nordberg.com>
Date: Thu, 19 May 2022 11:31:02 +0000
Subject: [PATCH] Allow setting models path from environment variable
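
Read the models directory from the TORTOISE_MODELS_DIR environment
variable, falling back to the previous hard-coded '.models'. The new
module-level MODELS_DIR constant in tortoise/api.py is used by
download_models(), classify_audio_clip() and the random-latent loaders,
becomes the default for the TextToSpeech constructor's models_dir
argument, and is the new default for --model_dir in do_tts.py and
read.py.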

---
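Usage sketch (not part of the commit message): MODELS_DIR is resolved once,
when tortoise.api is imported, so the environment variable has to be set
before that import. A minimal sketch of the intended use, assuming the
package is imported as tortoise.api; the directory path is only a
hypothetical example.

    import os

    # Set the variable before importing tortoise.api, because MODELS_DIR is
    # read from the environment at import time.
    os.environ['TORTOISE_MODELS_DIR'] = '/data/tortoise-models'  # hypothetical path

    from tortoise.api import TextToSpeech, MODELS_DIR

    print(MODELS_DIR)     # -> /data/tortoise-models
    tts = TextToSpeech()  # models_dir now defaults to MODELS_DIR
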
 tortoise/api.py    | 16 +++++++++-------
 tortoise/do_tts.py |  6 +++---
 tortoise/read.py   |  6 +++---
 3 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/tortoise/api.py b/tortoise/api.py
index fa915b4..5707b99 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -25,6 +25,7 @@ from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
 
 pbar = None
 
+MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR', '.models')  # checkpoint directory; defaults to '.models' unless overridden in the environment
 
 def download_models(specific_models=None):
     """
@@ -40,7 +41,7 @@ def download_models(specific_models=None):
         'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
         'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
     }
-    os.makedirs('.models', exist_ok=True)
+    os.makedirs(MODELS_DIR, exist_ok=True)
     def show_progress(block_num, block_size, total_size):
         global pbar
         if pbar is None:
@@ -56,10 +57,11 @@ def download_models(specific_models=None):
     for model_name, url in MODELS.items():
         if specific_models is not None and model_name not in specific_models:
             continue
-        if os.path.exists(f'.models/{model_name}'):
+        model_path = os.path.join(MODELS_DIR, model_name)
+        if os.path.exists(model_path):
             continue
         print(f'Downloading {model_name} from {url}...')
-        request.urlretrieve(url, f'.models/{model_name}', show_progress)
+        request.urlretrieve(url, model_path, show_progress)
         print('Done.')
 
 
@@ -154,7 +156,7 @@ def classify_audio_clip(clip):
     classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
                                                     resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
                                                     dropout=0, kernel_size=5, distribute_zero_label=False)
-    classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu')))
+    classifier.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'classifier.pth'), map_location=torch.device('cpu')))
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
@@ -181,7 +183,7 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
 
-    def __init__(self, autoregressive_batch_size=None, models_dir='.models', enable_redaction=True):
+    def __init__(self, autoregressive_batch_size=None, models_dir=MODELS_DIR, enable_redaction=True):
         """
         Constructor
         :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -276,9 +278,9 @@ class TextToSpeech:
         # Lazy-load the RLG models.
         if self.rlg_auto is None:
             self.rlg_auto = RandomLatentConverter(1024).eval()
-            self.rlg_auto.load_state_dict(torch.load('.models/rlg_auto.pth', map_location=torch.device('cpu')))
+            self.rlg_auto.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_auto.pth'), map_location=torch.device('cpu')))
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
-            self.rlg_diffusion.load_state_dict(torch.load('.models/rlg_diffuser.pth', map_location=torch.device('cpu')))
+            self.rlg_diffusion.load_state_dict(torch.load(os.path.join(MODELS_DIR, 'rlg_diffuser.pth'), map_location=torch.device('cpu')))
         with torch.no_grad():
             return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
 
diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py
index b74466c..091781f 100644
--- a/tortoise/do_tts.py
+++ b/tortoise/do_tts.py
@@ -3,8 +3,8 @@ import os
 
 import torchaudio
 
-from api import TextToSpeech
-from tortoise.utils.audio import load_audio, get_voices, load_voice
+from api import TextToSpeech, MODELS_DIR
+from utils.audio import load_voice
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
@@ -17,7 +17,7 @@ if __name__ == '__main__':
                         default=.5)
     parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
-                                                      'should only be specified if you have custom checkpoints.', default='.models')
+                                                      ' should only be specified if you have custom checkpoints.', default=MODELS_DIR)
     parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
diff --git a/tortoise/read.py b/tortoise/read.py
index e81bd71..ac284cc 100644
--- a/tortoise/read.py
+++ b/tortoise/read.py
@@ -4,8 +4,8 @@ import os
 import torch
 import torchaudio
 
-from api import TextToSpeech
-from utils.audio import load_audio, get_voices, load_voices
+from api import TextToSpeech, MODELS_DIR
+from utils.audio import load_audio, load_voices
 from utils.text import split_and_recombine_text
 
 
@@ -21,7 +21,7 @@ if __name__ == '__main__':
                         help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
                         default=.5)
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
-                                                      'should only be specified if you have custom checkpoints.', default='.models')
+                                                      ' should only be specified if you have custom checkpoints.', default=MODELS_DIR)
     args = parser.parse_args()
     tts = TextToSpeech(models_dir=args.model_dir)