forked from mrq/tortoise-tts
Add a way to get deterministic behavior from tortoise and add debug states for reporting
This commit is contained in:
parent
9eac62598a
commit
aef86d21bf
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -132,3 +132,4 @@ dmypy.json
|
||||||
.models/*
|
.models/*
|
||||||
.custom/*
|
.custom/*
|
||||||
results/*
|
results/*
|
||||||
|
debug_states/*
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import uuid
|
import uuid
|
||||||
|
from time import time
|
||||||
from urllib import request
|
from urllib import request
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -304,7 +305,8 @@ class TextToSpeech:
|
||||||
kwargs.update(presets[preset])
|
kwargs.update(presets[preset])
|
||||||
return self.tts(text, **kwargs)
|
return self.tts(text, **kwargs)
|
||||||
|
|
||||||
def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True,
|
def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
|
||||||
|
return_deterministic_state=False,
|
||||||
# autoregressive generation parameters follow
|
# autoregressive generation parameters follow
|
||||||
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
|
num_autoregressive_samples=512, temperature=.8, length_penalty=1, repetition_penalty=2.0, top_p=.8, max_mel_tokens=500,
|
||||||
# CLVP & CVVP parameters
|
# CLVP & CVVP parameters
|
||||||
|
@ -359,6 +361,8 @@ class TextToSpeech:
|
||||||
:return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
|
:return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
|
||||||
Sample rate is 24kHz.
|
Sample rate is 24kHz.
|
||||||
"""
|
"""
|
||||||
|
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
|
||||||
|
|
||||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
|
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).cuda()
|
||||||
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
||||||
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
|
assert text_tokens.shape[-1] < 400, 'Too much text provided. Break the text up into separate segments and re-try inference.'
|
||||||
|
@ -465,7 +469,26 @@ class TextToSpeech:
|
||||||
return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
|
return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
|
||||||
return clip
|
return clip
|
||||||
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
|
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
|
||||||
if len(wav_candidates) > 1:
|
|
||||||
return wav_candidates
|
|
||||||
return wav_candidates[0]
|
|
||||||
|
|
||||||
|
if len(wav_candidates) > 1:
|
||||||
|
res = wav_candidates
|
||||||
|
else:
|
||||||
|
res = wav_candidates[0]
|
||||||
|
|
||||||
|
if return_deterministic_state:
|
||||||
|
return res, (deterministic_seed, text, voice_samples, conditioning_latents)
|
||||||
|
else:
|
||||||
|
return res
|
||||||
|
|
||||||
|
def deterministic_state(self, seed=None):
|
||||||
|
"""
|
||||||
|
Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
|
||||||
|
reproduced.
|
||||||
|
"""
|
||||||
|
seed = int(time()) if seed is None else seed
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
# Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
|
||||||
|
# torch.use_deterministic_algorithms(True)
|
||||||
|
|
||||||
|
return seed
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
from api import TextToSpeech
|
from api import TextToSpeech
|
||||||
|
@ -19,6 +20,8 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||||
'should only be specified if you have custom checkpoints.', default='.models')
|
'should only be specified if you have custom checkpoints.', default='.models')
|
||||||
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
|
parser.add_argument('--candidates', type=int, help='How many output candidates to produce per-voice.', default=3)
|
||||||
|
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
|
||||||
|
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
os.makedirs(args.output_path, exist_ok=True)
|
os.makedirs(args.output_path, exist_ok=True)
|
||||||
|
|
||||||
|
@ -27,11 +30,16 @@ if __name__ == '__main__':
|
||||||
selected_voices = args.voice.split(',')
|
selected_voices = args.voice.split(',')
|
||||||
for k, voice in enumerate(selected_voices):
|
for k, voice in enumerate(selected_voices):
|
||||||
voice_samples, conditioning_latents = load_voice(voice)
|
voice_samples, conditioning_latents = load_voice(voice)
|
||||||
gen = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
|
gen, dbg_state = tts.tts_with_preset(args.text, k=args.candidates, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
|
||||||
preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
|
preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider,
|
||||||
|
use_deterministic_seed=args.seed, return_deterministic_state=True)
|
||||||
if isinstance(gen, list):
|
if isinstance(gen, list):
|
||||||
for j, g in enumerate(gen):
|
for j, g in enumerate(gen):
|
||||||
torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
|
torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}_{j}.wav'), g.squeeze(0).cpu(), 24000)
|
||||||
else:
|
else:
|
||||||
torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
|
torchaudio.save(os.path.join(args.output_path, f'{voice}_{k}.wav'), gen.squeeze(0).cpu(), 24000)
|
||||||
|
|
||||||
|
if args.produce_debug_state:
|
||||||
|
os.makedirs('debug_states', exist_ok=True)
|
||||||
|
torch.save(dbg_state, f'debug_states/do_tts_debug_{voice}.pth')
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
from time import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
@ -22,6 +23,9 @@ if __name__ == '__main__':
|
||||||
default=.5)
|
default=.5)
|
||||||
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
parser.add_argument('--model_dir', type=str, help='Where to find pretrained model checkpoints. Tortoise automatically downloads these to .models, so this'
|
||||||
'should only be specified if you have custom checkpoints.', default='.models')
|
'should only be specified if you have custom checkpoints.', default='.models')
|
||||||
|
parser.add_argument('--seed', type=int, help='Random seed which can be used to reproduce results.', default=None)
|
||||||
|
parser.add_argument('--produce_debug_state', type=bool, help='Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.', default=True)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
tts = TextToSpeech(models_dir=args.model_dir)
|
tts = TextToSpeech(models_dir=args.model_dir)
|
||||||
|
|
||||||
|
@ -41,6 +45,7 @@ if __name__ == '__main__':
|
||||||
else:
|
else:
|
||||||
texts = split_and_recombine_text(text)
|
texts = split_and_recombine_text(text)
|
||||||
|
|
||||||
|
seed = int(time()) if args.seed is None else args.seed
|
||||||
for selected_voice in selected_voices:
|
for selected_voice in selected_voices:
|
||||||
voice_outpath = os.path.join(outpath, selected_voice)
|
voice_outpath = os.path.join(outpath, selected_voice)
|
||||||
os.makedirs(voice_outpath, exist_ok=True)
|
os.makedirs(voice_outpath, exist_ok=True)
|
||||||
|
@ -57,10 +62,17 @@ if __name__ == '__main__':
|
||||||
all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000))
|
all_parts.append(load_audio(os.path.join(voice_outpath, f'{j}.wav'), 24000))
|
||||||
continue
|
continue
|
||||||
gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
|
gen = tts.tts_with_preset(text, voice_samples=voice_samples, conditioning_latents=conditioning_latents,
|
||||||
preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider)
|
preset=args.preset, clvp_cvvp_slider=args.voice_diversity_intelligibility_slider,
|
||||||
|
use_deterministic_seed=seed)
|
||||||
gen = gen.squeeze(0).cpu()
|
gen = gen.squeeze(0).cpu()
|
||||||
torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000)
|
torchaudio.save(os.path.join(voice_outpath, f'{j}.wav'), gen, 24000)
|
||||||
all_parts.append(gen)
|
all_parts.append(gen)
|
||||||
|
|
||||||
full_audio = torch.cat(all_parts, dim=-1)
|
full_audio = torch.cat(all_parts, dim=-1)
|
||||||
torchaudio.save(os.path.join(voice_outpath, 'combined.wav'), full_audio, 24000)
|
torchaudio.save(os.path.join(voice_outpath, 'combined.wav'), full_audio, 24000)
|
||||||
|
|
||||||
|
if args.produce_debug_state:
|
||||||
|
os.makedirs('debug_states', exist_ok=True)
|
||||||
|
dbg_state = (seed, texts, voice_samples, conditioning_latents)
|
||||||
|
torch.save(dbg_state, f'debug_states/read_debug_{selected_voice}.pth')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user