Fixed model setting not getting updated when TTS is unloaded, for when you change it and then load TTS (sorry for that brain worm)

This commit is contained in:
mrq 2023-02-19 16:24:06 +00:00
parent 092dd7b2d7
commit 4f79b3724b

View File

@ -838,6 +838,10 @@ def update_whisper_model(name):
whisper_model = whisper.load_model(args.whisper_model) whisper_model = whisper.load_model(args.whisper_model)
def update_autoregressive_model(path_name): def update_autoregressive_model(path_name):
args.autoregressive_model = path_name
save_args_settings()
print(f'Stored autoregressive model to settings: {path_name}')
global tts global tts
if not tts: if not tts:
raise Exception("TTS is uninitialized or still initializing...") raise Exception("TTS is uninitialized or still initializing...")
@ -847,6 +851,7 @@ def update_autoregressive_model(path_name):
if hasattr(tts, 'load_autoregressive_model') and tts.load_autoregressive_model(path_name): if hasattr(tts, 'load_autoregressive_model') and tts.load_autoregressive_model(path_name):
tts.load_autoregressive_model(path_name) tts.load_autoregressive_model(path_name)
# polyfill in case a user did NOT update the packages # polyfill in case a user did NOT update the packages
# this shouldn't happen anymore, as I just clone mrq/tortoise-tts, and inject it into sys.path
else: else:
from tortoise.models.autoregressive import UnifiedVoice from tortoise.models.autoregressive import UnifiedVoice
@ -864,9 +869,6 @@ def update_autoregressive_model(path_name):
print(f"Loaded model: {tts.autoregressive_model_path}") print(f"Loaded model: {tts.autoregressive_model_path}")
args.autoregressive_model = path_name
save_args_settings()
return path_name return path_name
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ): def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume ):