From b1fc2b13c90db4a7e1438c3b1620de5730427462 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Sun, 1 May 2022 17:29:25 -0600
Subject: [PATCH] add support for specifying the model_dir

---
 tortoise/do_tts.py | 4 +++-
 tortoise/read.py   | 5 ++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py
index e35be36..6f2bd88 100644
--- a/tortoise/do_tts.py
+++ b/tortoise/do_tts.py
@@ -16,10 +16,12 @@ if __name__ == '__main__':
                         help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
                         default=.5)
     parser.add_argument('--output_path', type=str, help='Where to store outputs.', default='results/')
+    parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
+                                                      'should only be specified if you have custom checkpoints.', default='.models')
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
 
-    tts = TextToSpeech()
+    tts = TextToSpeech(models_dir=args.model_dir)
 
     selected_voices = args.voice.split(',')
     for voice in selected_voices:
diff --git a/tortoise/read.py b/tortoise/read.py
index ce65c05..b22f62e 100644
--- a/tortoise/read.py
+++ b/tortoise/read.py
@@ -37,13 +37,17 @@ if __name__ == '__main__':
     parser.add_argument('--voice_diversity_intelligibility_slider', type=float,
                         help='How to balance vocal diversity with the quality/intelligibility of the spoken text. 0 means highly diverse voice (not recommended), 1 means maximize intellibility',
                         default=.5)
+    parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
+                                                      'should only be specified if you have custom checkpoints.', default='.models')
     args = parser.parse_args()
+    tts = TextToSpeech(models_dir=args.model_dir)
 
     outpath = args.output_path
     selected_voices = args.voice.split(',')
     regenerate = args.regenerate
     if regenerate is not None:
         regenerate = [int(e) for e in regenerate.split(',')]
+
     for selected_voice in selected_voices:
         voice_outpath = os.path.join(outpath, selected_voice)
         os.makedirs(voice_outpath, exist_ok=True)
@@ -51,7 +55,6 @@ if __name__ == '__main__':
         with open(args.textfile, 'r', encoding='utf-8') as f:
             text = ''.join([l for l in f.readlines()])
         texts = split_and_recombine_text(text)
-        tts = TextToSpeech()
 
         if '&' in selected_voice:
             voice_sel = selected_voice.split('&')