From aef86d21bfa0cf2da378f12fb810afa97366acb5 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 17 May 2022 12:11:18 -0600 Subject: [PATCH] Add a way to get deterministic behavior from tortoise and add debug states for reporting --- .gitignore | 3 ++- tortoise/api.py | 31 +++++++++++++++++++++++++++---- tortoise/do_tts.py | 12 ++++++++++-- tortoise/read.py | 14 +++++++++++++- 4 files changed, 52 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 82504f8..7693938 100644 --- a/.gitignore +++ b/.gitignore @@ -131,4 +131,5 @@ dmypy.json .idea/* .models/* .custom/* -results/* \ No newline at end of file +results/* +debug_states/* \ No newline at end of file diff --git a/tortoise/api.py b/tortoise/api.py index fa915b4..5abcb95 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -1,6 +1,7 @@ import os import random import uuid +from time import time from urllib import request import torch @@ -304,7 +305,8 @@ class TextToSpeech: kwargs.update(presets[preset]) return self.tts(text, **kwargs) - def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, + def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None, + return_deterministic_state=False, # autoregressive generation parameters follow num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500, # CLVP & CVVP parameters @@ -359,6 +361,8 @@ class TextToSpeech: :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. Sample rate is 24kHz. """ + deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda() text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.' @@ -465,7 +469,26 @@ class TextToSpeech: return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1) return clip wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates] - if len(wav_candidates) > 1: - return wav_candidates - return wav_candidates[0] + if len(wav_candidates) > 1: + res = wav_candidates + else: + res = wav_candidates[0] + + if return_deterministic_state: + return res, (deterministic_seed, text, voice_samples, conditioning_latents) + else: + return res + + def deterministic_state(self, seed=None): + """ + Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be + reproduced. + """ + seed = int(time()) if seed is None else seed + torch.manual_seed(seed) + random.seed(seed) + # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary. + # torch.use_deterministic_algorithms(True) + + return seed diff --git a/tortoise/do_tts.py b/tortoise/do_tts.py index b74466c..eb5af04 100644 --- a/tortoise/do_tts.py +++ b/tortoise/do_tts.py @@ -1,6 +1,7 @@ import argparse import os +import torch import torchaudio from api import TextToSpeech @@ -19,6 +20,8 @@ if __name__ == '__main__': parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' 'should only be specified if you have custom checkpoints.', default='.models') parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3) + parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None) + parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True) args = parser.parse_args() os.makedirs(args.output_path, exist_ok=True) @@ -27,11 +30,16 @@ if __name__ == '__main__': selected_voices = args.voice.split(',') for k, voice in enumerate(selected_voices): voice_samples, conditioning_latents = load_voice(voice) - gen = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents, - preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider) + gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents, + preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider, + use_deterministic_seed=args.seed, return_deterministic_state=True) if isinstance(gen, list): for j, g in enumerate(gen): torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000) else: torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000) + if args.produce_debug_state: + os.makedirs('debug_states', exist_ok=True) + torch.save(dbg_state, f'debug_states/do_tts_debug_{voice}.pth') + diff --git a/tortoise/read.py b/tortoise/read.py index e81bd71..ae68202 100644 --- a/tortoise/read.py +++ b/tortoise/read.py @@ -1,5 +1,6 @@ import argparse import os +from time import time import torch import torchaudio @@ -22,6 +23,9 @@ if __name__ == '__main__': default=.5) parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this' 'should only be specified if you have custom checkpoints.', default='.models') + parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None) + parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True) + args = parser.parse_args() tts = TextToSpeech(models_dir=args.model_dir) @@ -41,6 +45,7 @@ if __name__ == '__main__': else: texts = split_and_recombine_text(text) + seed = int(time()) if args.seed is None else args.seed for selected_voice in selected_voices: voice_outpath = os.path.join(outpath, selected_voice) os.makedirs(voice_outpath, exist_ok=True) @@ -57,10 +62,17 @@ if __name__ == '__main__': all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000)) continue gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents, - preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider) + preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider, + use_deterministic_seed=seed) gen = gen.squeeze(0).cpu() torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000) all_parts.append(gen) + full_audio = torch.cat(all_parts, dim=-1) torchaudio.save(os.path.join(voice_outpath, 'combined.wav'), full_audio, 24000) + if args.produce_debug_state: + os.makedirs('debug_states', exist_ok=True) + dbg_state = (seed, texts, voice_samples, conditioning_latents) + torch.save(dbg_state, f'debug_states/read_debug_{selected_voice}.pth') +