diff --git a/src/utils.py b/src/utils.py index 82fe106..f9ce8d3 100755 --- a/src/utils.py +++ b/src/utils.py @@ -15,6 +15,7 @@ import base64 import re import urllib.request import signal +import gc import tqdm import torch @@ -40,6 +41,9 @@ webui = None voicefixer = None whisper_model = None +def do_gc(): + gc.collect() + def get_args(): global args return args @@ -152,6 +156,8 @@ def generate( if not tts: raise Exception("TTS is uninitialized or still initializing...") + do_gc() + if voice != "microphone": voices = [voice] else: @@ -307,6 +313,9 @@ def generate( # save here in case some error happens mid-batch torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate) + del gen + do_gc() + for k in audio_cache: audio = audio_cache[k]['audio'] @@ -480,19 +489,44 @@ def setup_tortoise(restart=False): global args global tts - if args.voice_fixer and not restart: + do_gc() + + if args.voice_fixer: setup_voicefixer(restart=restart) if restart: del tts tts = None - print("Initializating TorToiSe...") + print(f"Initializating TorToiSe... (using model: {args.autoregressive_model})") tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model) get_model_path('dvae.pth') print("TorToiSe initialized, ready for generation.") return tts +def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): + global tts + global args + + if not tts: + raise Exception("TTS is uninitialized or still initializing...") + + do_gc() + + voice_samples, conditioning_latents = load_voice(voice, load_latents=False) + + if voice_samples is None: + return + + conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents) + + if len(conditioning_latents) == 4: + conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) + + torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') + + return voice + def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ): settings = { "iterations": iterations if iterations else 500, @@ -737,31 +771,11 @@ def update_autoregressive_model(path_name): raise Exception("TTS is uninitialized or still initializing...") print(f"Loading model: {path_name}") - if hasattr(tts, 'load_autoregressive_model') and tts.load_autoregressive_model(path_name): - args.autoregressive_model = path_name - save_args_settings() - # polyfill in case a user did NOT update the packages - else: - from tortoise.models.autoregressive import UnifiedVoice - - previous_path = tts.autoregressive_model_path - tts.autoregressive_model_path = path_name if path_name and os.path.exists(path_name) else get_model_path('autoregressive.pth') - - del tts.autoregressive - tts.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, - model_dim=1024, - heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, - train_solo_embeddings=False).cpu().eval() - tts.autoregressive.load_state_dict(torch.load(tts.autoregressive_model_path)) - tts.autoregressive.post_init_gpt2_config(kv_cache=tts.use_kv_cache) - if tts.preloaded_tensors: - tts.autoregressive = tts.autoregressive.to(tts.device) - - if previous_path != tts.autoregressive_model_path: - args.autoregressive_model = path_name - save_args_settings() - + tts.load_autoregressive_model(path_name) print(f"Loaded model: {tts.autoregressive_model_path}") + + args.autoregressive_model = path_name + save_args_settings() return path_name diff --git a/src/webui.py b/src/webui.py index 2cca424..cab4055 100755 --- a/src/webui.py +++ b/src/webui.py @@ -86,27 +86,6 @@ def run_generation( gr.update(value=stats, visible=True), ) -def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): - global tts - global args - - if not tts: - raise Exception("TTS is uninitialized or still initializing...") - - voice_samples, conditioning_latents = load_voice(voice, load_latents=False) - - if voice_samples is None: - return - - conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents) - - if len(conditioning_latents) == 4: - conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) - - torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') - - return voice - def update_presets(value): PRESETS = { 'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},