cleaned up some model loading logic, added 'auto' mode for AR model (deduced by current voice)

This commit is contained in:
mrq 2023-03-07 04:34:39 +00:00
parent 3899f9b4e3
commit d7a5ad9fd9
2 changed files with 72 additions and 37 deletions

View File

@ -54,6 +54,8 @@ voicefixer = None
whisper_model = None whisper_model = None
training_state = None training_state = None
current_voice = None
def generate( def generate(
text, text,
delimiter, delimiter,
@ -117,10 +119,7 @@ def generate(
else: else:
progress(0, desc=f"Loading voice: {voice}") progress(0, desc=f"Loading voice: {voice}")
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts # nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
if hasattr(tts, 'autoregressive_model_hash'): voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
else:
voice_samples, conditioning_latents = load_voice(voice)
if voice_samples and len(voice_samples) > 0: if voice_samples and len(voice_samples) > 0:
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks) conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
@ -146,6 +145,10 @@ def generate(
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.") print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
cvvp_weight = 0 cvvp_weight = 0
autoregressive_model = args.autoregressive_model
if autoregressive_model == "auto":
autoregressive_model = deduce_autoregressive_model(voice)
def get_settings( override=None ): def get_settings( override=None ):
settings = { settings = {
'temperature': float(temperature), 'temperature': float(temperature),
@ -172,7 +175,7 @@ def generate(
'half_p': "Half Precision" in experimental_checkboxes, 'half_p': "Half Precision" in experimental_checkboxes,
'cond_free': "Conditioning-Free" in experimental_checkboxes, 'cond_free': "Conditioning-Free" in experimental_checkboxes,
'cvvp_amount': cvvp_weight, 'cvvp_amount': cvvp_weight,
'autoregressive_model': args.autoregressive_model, 'autoregressive_model': autoregressive_model,
} }
# could be better to just do a ternary on everything above, but i am not a professional # could be better to just do a ternary on everything above, but i am not a professional
@ -180,18 +183,10 @@ def generate(
if 'voice' in override: if 'voice' in override:
voice = override['voice'] voice = override['voice']
if "autoregressive_model" in override and override["autoregressive_model"] == "auto": if "autoregressive_model" in override:
dir = f'./training/{voice}-finetune/models/' if override["autoregressive_model"] == "auto":
if os.path.exists(f'./training/finetunes/{voice}.pth'): override["autoregressive_model"] = deduce_autoregressive_model(voice)
override["autoregressive_model"] = f'./training/finetunes/{voice}.pth'
elif os.path.isdir(dir):
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
names = [ f'./{dir}/{d}_gpt.pth' for d in counts ]
override["autoregressive_model"] = names[-1]
else:
override["autoregressive_model"] = None
# necessary to ensure the right model gets loaded for the latents
tts.load_autoregressive_model( override["autoregressive_model"] ) tts.load_autoregressive_model( override["autoregressive_model"] )
fetched = fetch_voice(voice) fetched = fetch_voice(voice)
@ -204,8 +199,7 @@ def generate(
continue continue
settings[k] = override[k] settings[k] = override[k]
if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]: tts.load_autoregressive_model( settings["autoregressive_model"] )
tts.load_autoregressive_model( settings["autoregressive_model"] )
# clamp it down for the insane users who want this # clamp it down for the insane users who want this
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants # it would be wiser to enforce the sample size to the batch size, but this is what the user wants
@ -302,7 +296,7 @@ def generate(
'datetime': datetime.now().isoformat(), 'datetime': datetime.now().isoformat(),
'model': tts.autoregressive_model_path, 'model': tts.autoregressive_model_path,
'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None, 'model_hash': tts.autoregressive_model_hash
} }
if settings is not None: if settings is not None:
@ -331,7 +325,7 @@ def generate(
else: else:
if settings and "model_hash" in settings: if settings and "model_hash" in settings:
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth' latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth'
elif hasattr(tts, "autoregressive_model_hash"): else:
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth' latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
if latents_path and os.path.exists(latents_path): if latents_path and os.path.exists(latents_path):
@ -387,7 +381,7 @@ def generate(
used_settings['time'] = run_time used_settings['time'] = run_time
used_settings['datetime'] = datetime.now().isoformat(), used_settings['datetime'] = datetime.now().isoformat(),
used_settings['model'] = tts.autoregressive_model_path used_settings['model'] = tts.autoregressive_model_path
used_settings['model_hash'] = tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None used_settings['model_hash'] = tts.autoregressive_model_hash
audio_cache[name] = { audio_cache[name] = {
'audio': audio, 'audio': audio,
@ -540,6 +534,9 @@ def hash_file(path, algo="md5", buffer_size=0):
return "{0}".format(hash.hexdigest()) return "{0}".format(hash.hexdigest())
def update_baseline_for_latents_chunks( voice ): def update_baseline_for_latents_chunks( voice ):
global current_voice
current_voice = voice
path = f'{get_voice_dir()}/{voice}/' path = f'{get_voice_dir()}/{voice}/'
if not os.path.isdir(path): if not os.path.isdir(path):
return 1 return 1
@ -583,6 +580,9 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
if hasattr(tts, "loading") and tts.loading: if hasattr(tts, "loading") and tts.loading:
raise Exception("TTS is still initializing...") raise Exception("TTS is still initializing...")
if args.autoregressive_model == "auto":
tts.load_autoregressive_model(deduce_autoregressive_model(voice))
if voice: if voice:
load_from_dataset = voice_latents_chunks == 0 load_from_dataset = voice_latents_chunks == 0
@ -620,10 +620,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
if len(conditioning_latents) == 4: if len(conditioning_latents) == 4:
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
if hasattr(tts, 'autoregressive_model_hash'): torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
else:
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
return conditioning_latents return conditioning_latents
@ -1460,6 +1457,9 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ]) models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ] found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ]
if len(found) > 0 or len(additionals) > 0:
base = ["auto"] + base
res = base + additionals + found res = base + additionals + found
if prefixed: if prefixed:
@ -1815,28 +1815,29 @@ def version_check_tts( min_version ):
return True return True
return False return False
def load_tts( restart=False, model=None ): def load_tts( restart=False, autoregressive_model=None ):
global args global args
global tts global tts
if restart: if restart:
unload_tts() unload_tts()
if autoregressive_model:
args.autoregressive_model = autoregressive_model
else:
autoregressive_model = args.autoregressive_model
if model: if autoregressive_model == "auto":
args.autoregressive_model = model autoregressive_model = deduce_autoregressive_model()
print(f"Loading TorToiSe... (AR: {args.autoregressive_model}, vocoder: {args.vocoder_model})") print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
tts_loading = True tts_loading = True
try: try:
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, vocoder_model=args.vocoder_model) tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)
except Exception as e: except Exception as e:
tts = TextToSpeech(minor_optimizations=not args.low_vram) tts = TextToSpeech(minor_optimizations=not args.low_vram)
load_autoregressive_model(args.autoregressive_model) load_autoregressive_model(autoregressive_model)
if not hasattr(tts, 'autoregressive_model_hash'):
tts.autoregressive_model_hash = hash_file(tts.autoregressive_model_path)
tts_loading = False tts_loading = False
@ -1858,6 +1859,37 @@ def unload_tts():
def reload_tts( model=None ): def reload_tts( model=None ):
load_tts( restart=True, model=model ) load_tts( restart=True, model=model )
def get_current_voice():
global current_voice
if current_voice:
return current_voice
settings, _ = read_generate_settings("./config/generate.json", read_latents=False)
if settings and "voice" in settings['voice']:
return settings["voice"]
return None
def deduce_autoregressive_model(voice=None):
if not voice:
voice = get_current_voice()
if voice:
dir = f'./training/{voice}-finetune/models/'
if os.path.exists(f'./training/finetunes/{voice}.pth'):
return f'./training/finetunes/{voice}.pth'
if os.path.isdir(dir):
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
names = [ f'{dir}/{d}_gpt.pth' for d in counts ]
return names[-1]
if args.autoregressive_model != "auto":
return args.autoregressive_model
return get_model_path('autoregressive.pth')
def update_autoregressive_model(autoregressive_model_path): def update_autoregressive_model(autoregressive_model_path):
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path) match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
if match: if match:
@ -1880,10 +1912,13 @@ def update_autoregressive_model(autoregressive_model_path):
if hasattr(tts, "loading") and tts.loading: if hasattr(tts, "loading") and tts.loading:
raise Exception("TTS is still initializing...") raise Exception("TTS is still initializing...")
if autoregressive_model_path == "auto":
autoregressive_model_path = deduce_autoregressive_model()
if autoregressive_model_path == tts.autoregressive_model_path:
return
print(f"Loading model: {autoregressive_model_path}")
tts.load_autoregressive_model(autoregressive_model_path) tts.load_autoregressive_model(autoregressive_model_path)
print(f"Loaded model: {tts.autoregressive_model_path}")
do_gc() do_gc()

@ -1 +1 @@
Subproject commit e2db36af602297501132f7f68331755f5904825a Subproject commit 26133c20314b77155e77be804b43909dab9809d6