we do a little garbage collection

This commit is contained in:
mrq 2023-02-18 20:37:37 +00:00
parent 58c981d714
commit fc5b303319
2 changed files with 40 additions and 47 deletions

View File

@ -15,6 +15,7 @@ import base64
import re import re
import urllib.request import urllib.request
import signal import signal
import gc
import tqdm import tqdm
import torch import torch
@ -40,6 +41,9 @@ webui = None
voicefixer = None voicefixer = None
whisper_model = None whisper_model = None
def do_gc():
gc.collect()
def get_args(): def get_args():
global args global args
return args return args
@ -152,6 +156,8 @@ def generate(
if not tts: if not tts:
raise Exception("TTS is uninitialized or still initializing...") raise Exception("TTS is uninitialized or still initializing...")
do_gc()
if voice != "microphone": if voice != "microphone":
voices = [voice] voices = [voice]
else: else:
@ -307,6 +313,9 @@ def generate(
# save here in case some error happens mid-batch # save here in case some error happens mid-batch
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate) torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
del gen
do_gc()
for k in audio_cache: for k in audio_cache:
audio = audio_cache[k]['audio'] audio = audio_cache[k]['audio']
@ -480,19 +489,44 @@ def setup_tortoise(restart=False):
global args global args
global tts global tts
if args.voice_fixer and not restart: do_gc()
if args.voice_fixer:
setup_voicefixer(restart=restart) setup_voicefixer(restart=restart)
if restart: if restart:
del tts del tts
tts = None tts = None
print("Initializating TorToiSe...") print(f"Initializating TorToiSe... (using model: {args.autoregressive_model})")
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model) tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model)
get_model_path('dvae.pth') get_model_path('dvae.pth')
print("TorToiSe initialized, ready for generation.") print("TorToiSe initialized, ready for generation.")
return tts return tts
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
global tts
global args
if not tts:
raise Exception("TTS is uninitialized or still initializing...")
do_gc()
voice_samples, conditioning_latents = load_voice(voice, load_latents=False)
if voice_samples is None:
return
conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
if len(conditioning_latents) == 4:
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
return voice
def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ): def save_training_settings( iterations=None, batch_size=None, learning_rate=None, print_rate=None, save_rate=None, name=None, dataset_name=None, dataset_path=None, validation_name=None, validation_path=None, output_name=None ):
settings = { settings = {
"iterations": iterations if iterations else 500, "iterations": iterations if iterations else 500,
@ -737,31 +771,11 @@ def update_autoregressive_model(path_name):
raise Exception("TTS is uninitialized or still initializing...") raise Exception("TTS is uninitialized or still initializing...")
print(f"Loading model: {path_name}") print(f"Loading model: {path_name}")
if hasattr(tts, 'load_autoregressive_model') and tts.load_autoregressive_model(path_name): tts.load_autoregressive_model(path_name)
args.autoregressive_model = path_name
save_args_settings()
# polyfill in case a user did NOT update the packages
else:
from tortoise.models.autoregressive import UnifiedVoice
previous_path = tts.autoregressive_model_path
tts.autoregressive_model_path = path_name if path_name and os.path.exists(path_name) else get_model_path('autoregressive.pth')
del tts.autoregressive
tts.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30,
model_dim=1024,
heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False,
train_solo_embeddings=False).cpu().eval()
tts.autoregressive.load_state_dict(torch.load(tts.autoregressive_model_path))
tts.autoregressive.post_init_gpt2_config(kv_cache=tts.use_kv_cache)
if tts.preloaded_tensors:
tts.autoregressive = tts.autoregressive.to(tts.device)
if previous_path != tts.autoregressive_model_path:
args.autoregressive_model = path_name
save_args_settings()
print(f"Loaded model: {tts.autoregressive_model_path}") print(f"Loaded model: {tts.autoregressive_model_path}")
args.autoregressive_model = path_name
save_args_settings()
return path_name return path_name

View File

@ -86,27 +86,6 @@ def run_generation(
gr.update(value=stats, visible=True), gr.update(value=stats, visible=True),
) )
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
global tts
global args
if not tts:
raise Exception("TTS is uninitialized or still initializing...")
voice_samples, conditioning_latents = load_voice(voice, load_latents=False)
if voice_samples is None:
return
conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents)
if len(conditioning_latents) == 4:
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
return voice
def update_presets(value): def update_presets(value):
PRESETS = { PRESETS = {
'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False}, 'Ultra Fast': {'num_autoregressive_samples': 16, 'diffusion_iterations': 30, 'cond_free': False},