From d0caf7e695d27c1d92cad4d6d135face13d84e43 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Sun, 1 May 2022 14:51:44 -0600
Subject: [PATCH] add option to specify model directory to API

---
 api.py | 37 +++++++++++++++++++++----------------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/api.py b/api.py
index 6aa94cf..92e82be 100644
--- a/api.py
+++ b/api.py
@@ -170,35 +170,40 @@ class TextToSpeech:
     :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
                                       GPU OOM errors. Larger numbers generates slightly faster.
     """
-    def __init__(self, autoregressive_batch_size=16):
+    def __init__(self, autoregressive_batch_size=16, models_dir='.models'):
         self.autoregressive_batch_size = autoregressive_batch_size
         self.tokenizer = VoiceBpeTokenizer()
         download_models()
 
-        self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
-                                      model_dim=1024,
-                                      heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
-                                      train_solo_embeddings=False,
-                                      average_conditioning_embeddings=True).cpu().eval()
-        self.autoregressive.load_state_dict(torch.load('.models/autoregressive.pth'))
+        if os.path.exists(f'{models_dir}/autoregressive.ptt'):
+            # Assume this is a traced directory.
+            self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
+            self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
+        else:
+            self.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
+                                          model_dim=1024,
+                                          heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
+                                          train_solo_embeddings=False,
+                                          average_conditioning_embeddings=True).cpu().eval()
+            self.autoregressive.load_state_dict(torch.load(f'{models_dir}/autoregressive.pth'))
+
+            self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
+                                          in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
+                                          layer_drop=0, unconditioned_percentage=0).cpu().eval()
+            self.diffusion.load_state_dict(torch.load(f'{models_dir}/diffusion_decoder.pth'))
 
         self.clvp = CLVP(dim_text=512, dim_speech=512, dim_latent=512, num_text_tokens=256, text_enc_depth=12,
                          text_seq_len=350, text_heads=8,
                          num_speech_tokens=8192, speech_enc_depth=12, speech_heads=8, speech_seq_len=430,
                          use_xformers=True).cpu().eval()
-        self.clvp.load_state_dict(torch.load('.models/clvp.pth'))
+        self.clvp.load_state_dict(torch.load(f'{models_dir}/clvp.pth'))
 
         self.cvvp = CVVP(model_dim=512, transformer_heads=8, dropout=0, mel_codes=8192, conditioning_enc_depth=8, cond_mask_percentage=0,
                          speech_enc_depth=8, speech_mask_percentage=0, latent_multiplier=1).cpu().eval()
-        self.cvvp.load_state_dict(torch.load('.models/cvvp.pth'))
-
-        self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
-                                      in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
-                                      layer_drop=0, unconditioned_percentage=0).cpu().eval()
-        self.diffusion.load_state_dict(torch.load('.models/diffusion_decoder.pth'))
+        self.cvvp.load_state_dict(torch.load(f'{models_dir}/cvvp.pth'))
 
         self.vocoder = UnivNetGenerator().cpu()
-        self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g'])
+        self.vocoder.load_state_dict(torch.load(f'{models_dir}/vocoder.pth')['model_g'])
         self.vocoder.eval(inference=True)
 
     def tts_with_preset(self, text, voice_samples, preset='fast', **kwargs):
@@ -216,7 +221,7 @@ class TextToSpeech:
                        'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
         # Presets are defined here.
         presets = {
-            'ultra_fast': {'num_autoregressive_samples': 32, 'diffusion_iterations': 16, 'cond_free': False},
+            'ultra_fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 32, 'cond_free': False},
             'fast': {'num_autoregressive_samples': 96, 'diffusion_iterations': 32},
             'standard': {'num_autoregressive_samples': 256, 'diffusion_iterations': 128},
             'high_quality': {'num_autoregressive_samples': 512, 'diffusion_iterations': 1024},