cleaned up some model loading logic, added 'auto' mode for AR model (deduced by current voice)
This commit is contained in:
parent
3899f9b4e3
commit
d7a5ad9fd9
107
src/utils.py
107
src/utils.py
|
@ -54,6 +54,8 @@ voicefixer = None
|
|||
whisper_model = None
|
||||
training_state = None
|
||||
|
||||
current_voice = None
|
||||
|
||||
def generate(
|
||||
text,
|
||||
delimiter,
|
||||
|
@ -117,10 +119,7 @@ def generate(
|
|||
else:
|
||||
progress(0, desc=f"Loading voice: {voice}")
|
||||
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
|
||||
if hasattr(tts, 'autoregressive_model_hash'):
|
||||
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
||||
else:
|
||||
voice_samples, conditioning_latents = load_voice(voice)
|
||||
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
||||
|
||||
if voice_samples and len(voice_samples) > 0:
|
||||
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
|
||||
|
@ -146,6 +145,10 @@ def generate(
|
|||
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
|
||||
cvvp_weight = 0
|
||||
|
||||
autoregressive_model = args.autoregressive_model
|
||||
if autoregressive_model == "auto":
|
||||
autoregressive_model = deduce_autoregressive_model(voice)
|
||||
|
||||
def get_settings( override=None ):
|
||||
settings = {
|
||||
'temperature': float(temperature),
|
||||
|
@ -172,7 +175,7 @@ def generate(
|
|||
'half_p': "Half Precision" in experimental_checkboxes,
|
||||
'cond_free': "Conditioning-Free" in experimental_checkboxes,
|
||||
'cvvp_amount': cvvp_weight,
|
||||
'autoregressive_model': args.autoregressive_model,
|
||||
'autoregressive_model': autoregressive_model,
|
||||
}
|
||||
|
||||
# could be better to just do a ternary on everything above, but i am not a professional
|
||||
|
@ -180,18 +183,10 @@ def generate(
|
|||
if 'voice' in override:
|
||||
voice = override['voice']
|
||||
|
||||
if "autoregressive_model" in override and override["autoregressive_model"] == "auto":
|
||||
dir = f'./training/{voice}-finetune/models/'
|
||||
if os.path.exists(f'./training/finetunes/{voice}.pth'):
|
||||
override["autoregressive_model"] = f'./training/finetunes/{voice}.pth'
|
||||
elif os.path.isdir(dir):
|
||||
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
|
||||
names = [ f'./{dir}/{d}_gpt.pth' for d in counts ]
|
||||
override["autoregressive_model"] = names[-1]
|
||||
else:
|
||||
override["autoregressive_model"] = None
|
||||
if "autoregressive_model" in override:
|
||||
if override["autoregressive_model"] == "auto":
|
||||
override["autoregressive_model"] = deduce_autoregressive_model(voice)
|
||||
|
||||
# necessary to ensure the right model gets loaded for the latents
|
||||
tts.load_autoregressive_model( override["autoregressive_model"] )
|
||||
|
||||
fetched = fetch_voice(voice)
|
||||
|
@ -204,8 +199,7 @@ def generate(
|
|||
continue
|
||||
settings[k] = override[k]
|
||||
|
||||
if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]:
|
||||
tts.load_autoregressive_model( settings["autoregressive_model"] )
|
||||
tts.load_autoregressive_model( settings["autoregressive_model"] )
|
||||
|
||||
# clamp it down for the insane users who want this
|
||||
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
|
||||
|
@ -302,7 +296,7 @@ def generate(
|
|||
|
||||
'datetime': datetime.now().isoformat(),
|
||||
'model': tts.autoregressive_model_path,
|
||||
'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
|
||||
'model_hash': tts.autoregressive_model_hash
|
||||
}
|
||||
|
||||
if settings is not None:
|
||||
|
@ -331,7 +325,7 @@ def generate(
|
|||
else:
|
||||
if settings and "model_hash" in settings:
|
||||
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth'
|
||||
elif hasattr(tts, "autoregressive_model_hash"):
|
||||
else:
|
||||
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
|
||||
|
||||
if latents_path and os.path.exists(latents_path):
|
||||
|
@ -387,7 +381,7 @@ def generate(
|
|||
used_settings['time'] = run_time
|
||||
used_settings['datetime'] = datetime.now().isoformat(),
|
||||
used_settings['model'] = tts.autoregressive_model_path
|
||||
used_settings['model_hash'] = tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None
|
||||
used_settings['model_hash'] = tts.autoregressive_model_hash
|
||||
|
||||
audio_cache[name] = {
|
||||
'audio': audio,
|
||||
|
@ -540,6 +534,9 @@ def hash_file(path, algo="md5", buffer_size=0):
|
|||
return "{0}".format(hash.hexdigest())
|
||||
|
||||
def update_baseline_for_latents_chunks( voice ):
|
||||
global current_voice
|
||||
current_voice = voice
|
||||
|
||||
path = f'{get_voice_dir()}/{voice}/'
|
||||
if not os.path.isdir(path):
|
||||
return 1
|
||||
|
@ -583,6 +580,9 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
|
|||
if hasattr(tts, "loading") and tts.loading:
|
||||
raise Exception("TTS is still initializing...")
|
||||
|
||||
if args.autoregressive_model == "auto":
|
||||
tts.load_autoregressive_model(deduce_autoregressive_model(voice))
|
||||
|
||||
if voice:
|
||||
load_from_dataset = voice_latents_chunks == 0
|
||||
|
||||
|
@ -620,10 +620,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
|
|||
if len(conditioning_latents) == 4:
|
||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||
|
||||
if hasattr(tts, 'autoregressive_model_hash'):
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
||||
else:
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
||||
|
||||
return conditioning_latents
|
||||
|
||||
|
@ -1460,6 +1457,9 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
|
|||
models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
|
||||
found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ]
|
||||
|
||||
if len(found) > 0 or len(additionals) > 0:
|
||||
base = ["auto"] + base
|
||||
|
||||
res = base + additionals + found
|
||||
|
||||
if prefixed:
|
||||
|
@ -1815,28 +1815,29 @@ def version_check_tts( min_version ):
|
|||
return True
|
||||
return False
|
||||
|
||||
def load_tts( restart=False, model=None ):
|
||||
def load_tts( restart=False, autoregressive_model=None ):
|
||||
global args
|
||||
global tts
|
||||
|
||||
if restart:
|
||||
unload_tts()
|
||||
|
||||
if autoregressive_model:
|
||||
args.autoregressive_model = autoregressive_model
|
||||
else:
|
||||
autoregressive_model = args.autoregressive_model
|
||||
|
||||
if model:
|
||||
args.autoregressive_model = model
|
||||
if autoregressive_model == "auto":
|
||||
autoregressive_model = deduce_autoregressive_model()
|
||||
|
||||
print(f"Loading TorToiSe... (AR: {args.autoregressive_model}, vocoder: {args.vocoder_model})")
|
||||
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
|
||||
|
||||
tts_loading = True
|
||||
try:
|
||||
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, vocoder_model=args.vocoder_model)
|
||||
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)
|
||||
except Exception as e:
|
||||
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||
load_autoregressive_model(args.autoregressive_model)
|
||||
|
||||
if not hasattr(tts, 'autoregressive_model_hash'):
|
||||
tts.autoregressive_model_hash = hash_file(tts.autoregressive_model_path)
|
||||
load_autoregressive_model(autoregressive_model)
|
||||
|
||||
tts_loading = False
|
||||
|
||||
|
@ -1858,6 +1859,37 @@ def unload_tts():
|
|||
def reload_tts( model=None ):
|
||||
load_tts( restart=True, model=model )
|
||||
|
||||
def get_current_voice():
|
||||
global current_voice
|
||||
if current_voice:
|
||||
return current_voice
|
||||
|
||||
settings, _ = read_generate_settings("./config/generate.json", read_latents=False)
|
||||
|
||||
if settings and "voice" in settings['voice']:
|
||||
return settings["voice"]
|
||||
|
||||
return None
|
||||
|
||||
def deduce_autoregressive_model(voice=None):
|
||||
if not voice:
|
||||
voice = get_current_voice()
|
||||
|
||||
if voice:
|
||||
dir = f'./training/{voice}-finetune/models/'
|
||||
if os.path.exists(f'./training/finetunes/{voice}.pth'):
|
||||
return f'./training/finetunes/{voice}.pth'
|
||||
|
||||
if os.path.isdir(dir):
|
||||
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
|
||||
names = [ f'{dir}/{d}_gpt.pth' for d in counts ]
|
||||
return names[-1]
|
||||
|
||||
if args.autoregressive_model != "auto":
|
||||
return args.autoregressive_model
|
||||
|
||||
return get_model_path('autoregressive.pth')
|
||||
|
||||
def update_autoregressive_model(autoregressive_model_path):
|
||||
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
|
||||
if match:
|
||||
|
@ -1880,10 +1912,13 @@ def update_autoregressive_model(autoregressive_model_path):
|
|||
if hasattr(tts, "loading") and tts.loading:
|
||||
raise Exception("TTS is still initializing...")
|
||||
|
||||
if autoregressive_model_path == "auto":
|
||||
autoregressive_model_path = deduce_autoregressive_model()
|
||||
|
||||
if autoregressive_model_path == tts.autoregressive_model_path:
|
||||
return
|
||||
|
||||
print(f"Loading model: {autoregressive_model_path}")
|
||||
tts.load_autoregressive_model(autoregressive_model_path)
|
||||
print(f"Loaded model: {tts.autoregressive_model_path}")
|
||||
|
||||
do_gc()
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit e2db36af602297501132f7f68331755f5904825a
|
||||
Subproject commit 26133c20314b77155e77be804b43909dab9809d6
|
Loading…
Reference in New Issue
Block a user