diff --git a/src/utils.py b/src/utils.py index 8991314..309b72f 100755 --- a/src/utils.py +++ b/src/utils.py @@ -31,7 +31,7 @@ import pandas as pd from datetime import datetime from datetime import timedelta -from tortoise.api import TextToSpeech, MODELS, get_model_path +from tortoise.api import TextToSpeech, MODELS, get_model_path, pad_or_truncate from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir from tortoise.utils.text import split_and_recombine_text from tortoise.utils.device import get_device_name, set_device_name @@ -89,6 +89,8 @@ def generate( if tts_loading: raise Exception("TTS is still initializing...") load_tts() + if hasattr(tts, "loading") and tts.loading: + raise Exception("TTS is still initializing...") do_gc() @@ -121,17 +123,8 @@ def generate( voice_samples, conditioning_latents = load_voice(voice) if voice_samples and len(voice_samples) > 0: + conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks) sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu() - - conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents) - if len(conditioning_latents) == 4: - conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) - - if voice != "microphone": - if hasattr(tts, 'autoregressive_model_hash'): - torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth') - else: - torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') voice_samples = None else: if conditioning_latents is not None: @@ -551,6 +544,10 @@ def update_baseline_for_latents_chunks( voice ): if not os.path.isdir(path): return 1 + dataset_file = f'./training/{voice}/train.txt' + if os.path.exists(dataset_file): + return 0 # 0 will leverage using the LJspeech dataset for computing latents + files = os.listdir(path) total = 0 @@ -565,11 +562,13 @@ def update_baseline_for_latents_chunks( voice ): total_duration += duration total = total + 1 + + # brain too fried to figure out a better way if args.autocalculate_voice_chunk_duration_size == 0: return int(total_duration / total) if total > 0 else 1 return int(total_duration / args.autocalculate_voice_chunk_duration_size) if total_duration > 0 else 1 -def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): +def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, progress=None): global tts global args @@ -581,12 +580,42 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm raise Exception("TTS is still initializing...") load_tts() - voice_samples, conditioning_latents = load_voice(voice, load_latents=False) + if hasattr(tts, "loading") and tts.loading: + raise Exception("TTS is still initializing...") + + if voice: + load_from_dataset = voice_latents_chunks == 0 + + if load_from_dataset: + dataset_path = f'./training/{voice}/train.txt' + if not os.path.exists(dataset_path): + load_from_dataset = False + else: + with open(dataset_path, 'r', encoding="utf-8") as f: + lines = f.readlines() + + print("Leveraging LJSpeech dataset for computing latents") + + voice_samples = [] + max_length = 0 + for line in lines: + filename = f'./training/{voice}/{line.split("|")[0]}' + + waveform = load_audio(filename, 22050) + max_length = max(max_length, waveform.shape[-1]) + voice_samples.append(waveform) + + for i in range(len(voice_samples)): + voice_samples[i] = pad_or_truncate(voice_samples[i], max_length) + + voice_latents_chunks = len(voice_samples) + if not load_from_dataset: + voice_samples, _ = load_voice(voice, load_latents=False) if voice_samples is None: return - conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents) + conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents, progress=progress) if len(conditioning_latents) == 4: conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) @@ -596,7 +625,7 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm else: torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') - return voice + return conditioning_latents # superfluous, but it cleans up some things class TrainingState(): @@ -1847,6 +1876,10 @@ def update_autoregressive_model(autoregressive_model_path): if tts_loading: raise Exception("TTS is still initializing...") return + + if hasattr(tts, "loading") and tts.loading: + raise Exception("TTS is still initializing...") + print(f"Loading model: {autoregressive_model_path}") tts.load_autoregressive_model(autoregressive_model_path) @@ -1867,6 +1900,9 @@ def update_vocoder_model(vocoder_model): raise Exception("TTS is still initializing...") return + if hasattr(tts, "loading") and tts.loading: + raise Exception("TTS is still initializing...") + print(f"Loading model: {vocoder_model}") tts.load_vocoder_model(vocoder_model) print(f"Loaded model: {tts.vocoder_model}") diff --git a/src/webui.py b/src/webui.py index 3090b5d..abc73b3 100755 --- a/src/webui.py +++ b/src/webui.py @@ -163,6 +163,11 @@ def history_view_results( voice ): gr.Dropdown.update(choices=sorted(files)) ) +def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): + compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress ) + return voice + + def import_voices_proxy(files, name, progress=gr.Progress(track_tqdm=True)): import_voices(files, name, progress) return gr.update() @@ -387,7 +392,7 @@ def setup_gradio(): prompt = gr.Textbox(lines=1, label="Custom Emotion") voice = gr.Dropdown(choices=voice_list_with_defaults, label="Voice", type="value", value=voice_list_with_defaults[0]) # it'd be very cash money if gradio was able to default to the first value in the list without this shit mic_audio = gr.Audio( label="Microphone Source", source="microphone", type="filepath", visible=False ) - voice_latents_chunks = gr.Slider(label="Voice Chunks", minimum=1, maximum=128, value=1, step=1) + voice_latents_chunks = gr.Number(label="Voice Chunks", precision=0, value=0) with gr.Row(): refresh_voices = gr.Button(value="Refresh Voice List") recompute_voice_latents = gr.Button(value="(Re)Compute Voice Latents") @@ -704,7 +709,7 @@ def setup_gradio(): ], ) - recompute_voice_latents.click(compute_latents, + recompute_voice_latents.click(compute_latents_proxy, inputs=[ voice, voice_latents_chunks,