From a1ae84c49db14f3545fd1e5be5232db83893456d Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 17 May 2022 12:11:18 -0600
Subject: [PATCH] Add a way to get deterministic behavior from tortoise and add
 debug states for reporting

---
 .gitignore         |  3 ++-
 tortoise/api.py    | 31 +++++++++++++++++++++++++++----
 tortoise/do_tts.py | 12 ++++++++++--
 tortoise/read.py   | 14 +++++++++++++-
 4 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/.gitignore b/.gitignore
index 82504f8..7693938 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,4 +131,5 @@ dmypy.json
 .idea/*
 .models/*
 .custom/*
-results/*
\ No newline at end of file
+results/*
+debug_states/*
\ No newline at end of file
diff --git a/tortoise/api.py b/tortoise/api.py
index fa915b4..5abcb95 100644
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -1,6 +1,7 @@
 import os
 import random
 import uuid
+from time import time
 from urllib import request
 
 import torch
@@ -304,7 +305,8 @@ class TextToSpeech:
         kwargs.update(presets[preset])
         return self.tts(text, **kwargs)
 
-    def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True,
+    def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
+            return_deterministic_state=False,
             # autoregressive generation parameters follow
             num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
             # CLVP & CVVP parameters
@@ -359,6 +361,8 @@ class TextToSpeech:
         :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
                  Sample rate is 24kHz.
         """
+        deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
+
         text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
         text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
         assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
@@ -465,7 +469,26 @@ class TextToSpeech:
                     return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
                 return clip
             wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
-            if len(wav_candidates) > 1:
-                return wav_candidates
-            return wav_candidates[0]
 
+            if len(wav_candidates) > 1:
+                res = wav_candidates
+            else:
+                res = wav_candidates[0]
+
+            if return_deterministic_state:
+                return res, (deterministic_seed, text, voice_samples, conditioning_latents)
+            else:
+                return res
+
+    def deterministic_state(self, seed=None):
+        """
+        Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
+        reproduced.
+        """
+        seed = int(time()) if seed is None else seed
+        torch.manual_seed(seed)
+        random.seed(seed)
+        # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
+        # torch.use_deterministic_algorithms(True)
+
+        return seed
diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py
index b74466c..eb5af04 100644
--- a/tortoise/do_tts.py
+++ b/tortoise/do_tts.py
@@ -1,6 +1,7 @@
 import argparse
 import os
 
+import torch
 import torchaudio
 
 from api import TextToSpeech
@@ -19,6 +20,8 @@ if __name__ == '__main__':
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
                                                       'should only be specified if you have custom checkpoints.', default='.models')
     parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
+    parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
+    parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
     args = parser.parse_args()
     os.makedirs(args.output_path, exist_ok=True)
 
@@ -27,11 +30,16 @@ if __name__ == '__main__':
     selected_voices = args.voice.split(',')
     for k, voice in enumerate(selected_voices):
         voice_samples, conditioning_latents = load_voice(voice)
-        gen = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
-                                  preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
+        gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
+                                  preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider,
+                                  use_deterministic_seed=args.seed, return_deterministic_state=True)
         if isinstance(gen, list):
             for j, g in enumerate(gen):
                 torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
         else:
             torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
 
+        if args.produce_debug_state:
+            os.makedirs('debug_states', exist_ok=True)
+            torch.save(dbg_state, f'debug_states/do_tts_debug_{voice}.pth')
+
diff --git a/tortoise/read.py b/tortoise/read.py
index e81bd71..ae68202 100644
--- a/tortoise/read.py
+++ b/tortoise/read.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+from time import time
 
 import torch
 import torchaudio
@@ -22,6 +23,9 @@ if __name__ == '__main__':
                         default=.5)
     parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
                                                       'should only be specified if you have custom checkpoints.', default='.models')
+    parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
+    parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
+
     args = parser.parse_args()
     tts = TextToSpeech(models_dir=args.model_dir)
 
@@ -41,6 +45,7 @@ if __name__ == '__main__':
     else:
         texts = split_and_recombine_text(text)
 
+    seed = int(time()) if args.seed is None else args.seed
     for selected_voice in selected_voices:
         voice_outpath = os.path.join(outpath, selected_voice)
         os.makedirs(voice_outpath, exist_ok=True)
@@ -57,10 +62,17 @@ if __name__ == '__main__':
                 all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000))
                 continue
             gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
-                                      preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
+                                      preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider,
+                                      use_deterministic_seed=seed)
             gen = gen.squeeze(0).cpu()
             torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000)
             all_parts.append(gen)
+
         full_audio = torch.cat(all_parts, dim=-1)
         torchaudio.save(os.path.join(voice_outpath, 'combined.wav'), full_audio, 24000)
 
+        if args.produce_debug_state:
+            os.makedirs('debug_states', exist_ok=True)
+            dbg_state = (seed, texts, voice_samples, conditioning_latents)
+            torch.save(dbg_state, f'debug_states/read_debug_{selected_voice}.pth')
+