cleaned up some model loading logic, added 'auto' mode for AR model (deduced by current voice)

This commit is contained in:
mrq 2023-03-07 04:34:39 +00:00
parent 3899f9b4e3
commit d7a5ad9fd9
2 changed files with 72 additions and 37 deletions

View File

@ -54,6 +54,8 @@ voicefixer = None
whisper_model = None
training_state = None
current_voice = None
def generate(
text,
delimiter,
@ -117,10 +119,7 @@ def generate(
else:
progress(0, desc=f"Loading voice: {voice}")
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
if hasattr(tts, 'autoregressive_model_hash'):
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
else:
voice_samples, conditioning_latents = load_voice(voice)
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
if voice_samples and len(voice_samples) > 0:
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
@ -146,6 +145,10 @@ def generate(
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
cvvp_weight = 0
autoregressive_model = args.autoregressive_model
if autoregressive_model == "auto":
autoregressive_model = deduce_autoregressive_model(voice)
def get_settings( override=None ):
settings = {
'temperature': float(temperature),
@ -172,7 +175,7 @@ def generate(
'half_p': "Half Precision" in experimental_checkboxes,
'cond_free': "Conditioning-Free" in experimental_checkboxes,
'cvvp_amount': cvvp_weight,
'autoregressive_model': args.autoregressive_model,
'autoregressive_model': autoregressive_model,
}
# could be better to just do a ternary on everything above, but i am not a professional
@ -180,18 +183,10 @@ def generate(
if 'voice' in override:
voice = override['voice']
if "autoregressive_model" in override and override["autoregressive_model"] == "auto":
dir = f'./training/{voice}-finetune/models/'
if os.path.exists(f'./training/finetunes/{voice}.pth'):
override["autoregressive_model"] = f'./training/finetunes/{voice}.pth'
elif os.path.isdir(dir):
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
names = [ f'./{dir}/{d}_gpt.pth' for d in counts ]
override["autoregressive_model"] = names[-1]
else:
override["autoregressive_model"] = None
if "autoregressive_model" in override:
if override["autoregressive_model"] == "auto":
override["autoregressive_model"] = deduce_autoregressive_model(voice)
# necessary to ensure the right model gets loaded for the latents
tts.load_autoregressive_model( override["autoregressive_model"] )
fetched = fetch_voice(voice)
@ -204,8 +199,7 @@ def generate(
continue
settings[k] = override[k]
if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]:
tts.load_autoregressive_model( settings["autoregressive_model"] )
tts.load_autoregressive_model( settings["autoregressive_model"] )
# clamp it down for the insane users who want this
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
@ -302,7 +296,7 @@ def generate(
'datetime': datetime.now().isoformat(),
'model': tts.autoregressive_model_path,
'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
'model_hash': tts.autoregressive_model_hash
}
if settings is not None:
@ -331,7 +325,7 @@ def generate(
else:
if settings and "model_hash" in settings:
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth'
elif hasattr(tts, "autoregressive_model_hash"):
else:
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
if latents_path and os.path.exists(latents_path):
@ -387,7 +381,7 @@ def generate(
used_settings['time'] = run_time
used_settings['datetime'] = datetime.now().isoformat(),
used_settings['model'] = tts.autoregressive_model_path
used_settings['model_hash'] = tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None
used_settings['model_hash'] = tts.autoregressive_model_hash
audio_cache[name] = {
'audio': audio,
@ -540,6 +534,9 @@ def hash_file(path, algo="md5", buffer_size=0):
return "{0}".format(hash.hexdigest())
def update_baseline_for_latents_chunks( voice ):
global current_voice
current_voice = voice
path = f'{get_voice_dir()}/{voice}/'
if not os.path.isdir(path):
return 1
@ -583,6 +580,9 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
if hasattr(tts, "loading") and tts.loading:
raise Exception("TTS is still initializing...")
if args.autoregressive_model == "auto":
tts.load_autoregressive_model(deduce_autoregressive_model(voice))
if voice:
load_from_dataset = voice_latents_chunks == 0
@ -620,10 +620,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
if len(conditioning_latents) == 4:
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
if hasattr(tts, 'autoregressive_model_hash'):
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
else:
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
return conditioning_latents
@ -1460,6 +1457,9 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ]
if len(found) > 0 or len(additionals) > 0:
base = ["auto"] + base
res = base + additionals + found
if prefixed:
@ -1815,28 +1815,29 @@ def version_check_tts( min_version ):
return True
return False
def load_tts( restart=False, model=None ):
def load_tts( restart=False, autoregressive_model=None ):
global args
global tts
if restart:
unload_tts()
if autoregressive_model:
args.autoregressive_model = autoregressive_model
else:
autoregressive_model = args.autoregressive_model
if model:
args.autoregressive_model = model
if autoregressive_model == "auto":
autoregressive_model = deduce_autoregressive_model()
print(f"Loading TorToiSe... (AR: {args.autoregressive_model}, vocoder: {args.vocoder_model})")
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
tts_loading = True
try:
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, vocoder_model=args.vocoder_model)
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)
except Exception as e:
tts = TextToSpeech(minor_optimizations=not args.low_vram)
load_autoregressive_model(args.autoregressive_model)
if not hasattr(tts, 'autoregressive_model_hash'):
tts.autoregressive_model_hash = hash_file(tts.autoregressive_model_path)
load_autoregressive_model(autoregressive_model)
tts_loading = False
@ -1858,6 +1859,37 @@ def unload_tts():
def reload_tts( model=None ):
load_tts( restart=True, model=model )
def get_current_voice():
global current_voice
if current_voice:
return current_voice
settings, _ = read_generate_settings("./config/generate.json", read_latents=False)
if settings and "voice" in settings['voice']:
return settings["voice"]
return None
def deduce_autoregressive_model(voice=None):
if not voice:
voice = get_current_voice()
if voice:
dir = f'./training/{voice}-finetune/models/'
if os.path.exists(f'./training/finetunes/{voice}.pth'):
return f'./training/finetunes/{voice}.pth'
if os.path.isdir(dir):
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
names = [ f'{dir}/{d}_gpt.pth' for d in counts ]
return names[-1]
if args.autoregressive_model != "auto":
return args.autoregressive_model
return get_model_path('autoregressive.pth')
def update_autoregressive_model(autoregressive_model_path):
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
if match:
@ -1880,10 +1912,13 @@ def update_autoregressive_model(autoregressive_model_path):
if hasattr(tts, "loading") and tts.loading:
raise Exception("TTS is still initializing...")
if autoregressive_model_path == "auto":
autoregressive_model_path = deduce_autoregressive_model()
if autoregressive_model_path == tts.autoregressive_model_path:
return
print(f"Loading model: {autoregressive_model_path}")
tts.load_autoregressive_model(autoregressive_model_path)
print(f"Loaded model: {tts.autoregressive_model_path}")
do_gc()

@ -1 +1 @@
Subproject commit e2db36af602297501132f7f68331755f5904825a
Subproject commit 26133c20314b77155e77be804b43909dab9809d6