From f44239a85aaa4a36266377624fa415e23fadce31 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 19 Feb 2023 05:10:08 +0000 Subject: [PATCH] added polyfill for loading autoregressive models in case mrq/tortoise-tts absolutely refuses to update --- src/utils.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/utils.py b/src/utils.py index 401794a..d4b9c89 100755 --- a/src/utils.py +++ b/src/utils.py @@ -543,7 +543,12 @@ def setup_tortoise(restart=False): tts = None print(f"Initializating TorToiSe... (using model: {args.autoregressive_model})") - tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model) + try: + tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model) + except Exception as e: + tts = TextToSpeech(minor_optimizations=not args.low_vram) + load_autoregressive_model(args.autoregressive_model) + get_model_path('dvae.pth') print("TorToiSe initialized, ready for generation.") return tts @@ -826,7 +831,25 @@ def update_autoregressive_model(path_name): raise Exception("TTS is uninitialized or still initializing...") print(f"Loading model: {path_name}") - tts.load_autoregressive_model(path_name) + + if hasattr(tts, 'load_autoregressive_model') and tts.load_autoregressive_model(path_name): + tts.load_autoregressive_model(path_name) + # polyfill in case a user did NOT update the packages + else: + from tortoise.models.autoregressive import UnifiedVoice + + tts.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', tts.models_dir) + + del tts.autoregressive + tts.autoregressive = UnifiedVoice(max_mel_tokens=604, max_text_tokens=402, max_conditioning_inputs=2, layers=30, + model_dim=1024, + heads=16, number_text_tokens=255, start_text_token=255, checkpointing=False, + train_solo_embeddings=False).cpu().eval() + tts.autoregressive.load_state_dict(torch.load(tts.autoregressive_model_path)) + tts.autoregressive.post_init_gpt2_config(kv_cache=tts.use_kv_cache) + if tts.preloaded_tensors: + tts.autoregressive = tts.autoregressive.to(tts.device) + print(f"Loaded model: {tts.autoregressive_model_path}") args.autoregressive_model = path_name