diff --git a/src/utils.py b/src/utils.py index 309b72f..0aec5a8 100755 --- a/src/utils.py +++ b/src/utils.py @@ -54,6 +54,8 @@ voicefixer = None whisper_model = None training_state = None +current_voice = None + def generate( text, delimiter, @@ -117,10 +119,7 @@ def generate( else: progress(0, desc=f"Loading voice: {voice}") # nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts - if hasattr(tts, 'autoregressive_model_hash'): - voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash) - else: - voice_samples, conditioning_latents = load_voice(voice) + voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash) if voice_samples and len(voice_samples) > 0: conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks) @@ -146,6 +145,10 @@ def generate( print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.") cvvp_weight = 0 + autoregressive_model = args.autoregressive_model + if autoregressive_model == "auto": + autoregressive_model = deduce_autoregressive_model(voice) + def get_settings( override=None ): settings = { 'temperature': float(temperature), @@ -172,7 +175,7 @@ def generate( 'half_p': "Half Precision" in experimental_checkboxes, 'cond_free': "Conditioning-Free" in experimental_checkboxes, 'cvvp_amount': cvvp_weight, - 'autoregressive_model': args.autoregressive_model, + 'autoregressive_model': autoregressive_model, } # could be better to just do a ternary on everything above, but i am not a professional @@ -180,18 +183,10 @@ def generate( if 'voice' in override: voice = override['voice'] - if "autoregressive_model" in override and override["autoregressive_model"] == "auto": - dir = f'./training/{voice}-finetune/models/' - if os.path.exists(f'./training/finetunes/{voice}.pth'): - override["autoregressive_model"] = f'./training/finetunes/{voice}.pth' - elif os.path.isdir(dir): - counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ]) - names = [ f'./{dir}/{d}_gpt.pth' for d in counts ] - override["autoregressive_model"] = names[-1] - else: - override["autoregressive_model"] = None + if "autoregressive_model" in override: + if override["autoregressive_model"] == "auto": + override["autoregressive_model"] = deduce_autoregressive_model(voice) - # necessary to ensure the right model gets loaded for the latents tts.load_autoregressive_model( override["autoregressive_model"] ) fetched = fetch_voice(voice) @@ -204,8 +199,7 @@ def generate( continue settings[k] = override[k] - if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]: - tts.load_autoregressive_model( settings["autoregressive_model"] ) + tts.load_autoregressive_model( settings["autoregressive_model"] ) # clamp it down for the insane users who want this # it would be wiser to enforce the sample size to the batch size, but this is what the user wants @@ -302,7 +296,7 @@ def generate( 'datetime': datetime.now().isoformat(), 'model': tts.autoregressive_model_path, - 'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None, + 'model_hash': tts.autoregressive_model_hash } if settings is not None: @@ -331,7 +325,7 @@ def generate( else: if settings and "model_hash" in settings: latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth' - elif hasattr(tts, "autoregressive_model_hash"): + else: latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth' if latents_path and os.path.exists(latents_path): @@ -387,7 +381,7 @@ def generate( used_settings['time'] = run_time used_settings['datetime'] = datetime.now().isoformat(), used_settings['model'] = tts.autoregressive_model_path - used_settings['model_hash'] = tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None + used_settings['model_hash'] = tts.autoregressive_model_hash audio_cache[name] = { 'audio': audio, @@ -540,6 +534,9 @@ def hash_file(path, algo="md5", buffer_size=0): return "{0}".format(hash.hexdigest()) def update_baseline_for_latents_chunks( voice ): + global current_voice + current_voice = voice + path = f'{get_voice_dir()}/{voice}/' if not os.path.isdir(path): return 1 @@ -583,6 +580,9 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") + if args.autoregressive_model == "auto": + tts.load_autoregressive_model(deduce_autoregressive_model(voice)) + if voice: load_from_dataset = voice_latents_chunks == 0 @@ -620,10 +620,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog if len(conditioning_latents) == 4: conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) - if hasattr(tts, 'autoregressive_model_hash'): - torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth') - else: - torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') + torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth') return conditioning_latents @@ -1460,6 +1457,9 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False): models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ]) found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ] + if len(found) > 0 or len(additionals) > 0: + base = ["auto"] + base + res = base + additionals + found if prefixed: @@ -1815,28 +1815,29 @@ def version_check_tts( min_version ): return True return False -def load_tts( restart=False, model=None ): +def load_tts( restart=False, autoregressive_model=None ): global args global tts if restart: unload_tts() + if autoregressive_model: + args.autoregressive_model = autoregressive_model + else: + autoregressive_model = args.autoregressive_model - if model: - args.autoregressive_model = model + if autoregressive_model == "auto": + autoregressive_model = deduce_autoregressive_model() - print(f"Loading TorToiSe... (AR: {args.autoregressive_model}, vocoder: {args.vocoder_model})") + print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})") tts_loading = True try: - tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, vocoder_model=args.vocoder_model) + tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model) except Exception as e: tts = TextToSpeech(minor_optimizations=not args.low_vram) - load_autoregressive_model(args.autoregressive_model) - - if not hasattr(tts, 'autoregressive_model_hash'): - tts.autoregressive_model_hash = hash_file(tts.autoregressive_model_path) + load_autoregressive_model(autoregressive_model) tts_loading = False @@ -1858,6 +1859,37 @@ def unload_tts(): def reload_tts( model=None ): load_tts( restart=True, model=model ) +def get_current_voice(): + global current_voice + if current_voice: + return current_voice + + settings, _ = read_generate_settings("./config/generate.json", read_latents=False) + + if settings and "voice" in settings['voice']: + return settings["voice"] + + return None + +def deduce_autoregressive_model(voice=None): + if not voice: + voice = get_current_voice() + + if voice: + dir = f'./training/{voice}-finetune/models/' + if os.path.exists(f'./training/finetunes/{voice}.pth'): + return f'./training/finetunes/{voice}.pth' + + if os.path.isdir(dir): + counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ]) + names = [ f'{dir}/{d}_gpt.pth' for d in counts ] + return names[-1] + + if args.autoregressive_model != "auto": + return args.autoregressive_model + + return get_model_path('autoregressive.pth') + def update_autoregressive_model(autoregressive_model_path): match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path) if match: @@ -1880,10 +1912,13 @@ def update_autoregressive_model(autoregressive_model_path): if hasattr(tts, "loading") and tts.loading: raise Exception("TTS is still initializing...") + if autoregressive_model_path == "auto": + autoregressive_model_path = deduce_autoregressive_model() + + if autoregressive_model_path == tts.autoregressive_model_path: + return - print(f"Loading model: {autoregressive_model_path}") tts.load_autoregressive_model(autoregressive_model_path) - print(f"Loaded model: {tts.autoregressive_model_path}") do_gc() diff --git a/tortoise-tts b/tortoise-tts index e2db36a..26133c2 160000 --- a/tortoise-tts +++ b/tortoise-tts @@ -1 +1 @@ -Subproject commit e2db36af602297501132f7f68331755f5904825a +Subproject commit 26133c20314b77155e77be804b43909dab9809d6