forked from mrq/ai-voice-cloning
cleaned up some model loading logic, added 'auto' mode for AR model (deduced by current voice)
This commit is contained in:
parent
3899f9b4e3
commit
d7a5ad9fd9
107
src/utils.py
107
src/utils.py
|
@ -54,6 +54,8 @@ voicefixer = None
|
||||||
whisper_model = None
|
whisper_model = None
|
||||||
training_state = None
|
training_state = None
|
||||||
|
|
||||||
|
current_voice = None
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
text,
|
text,
|
||||||
delimiter,
|
delimiter,
|
||||||
|
@ -117,10 +119,7 @@ def generate(
|
||||||
else:
|
else:
|
||||||
progress(0, desc=f"Loading voice: {voice}")
|
progress(0, desc=f"Loading voice: {voice}")
|
||||||
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
|
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
|
||||||
if hasattr(tts, 'autoregressive_model_hash'):
|
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
||||||
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
|
||||||
else:
|
|
||||||
voice_samples, conditioning_latents = load_voice(voice)
|
|
||||||
|
|
||||||
if voice_samples and len(voice_samples) > 0:
|
if voice_samples and len(voice_samples) > 0:
|
||||||
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
|
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
|
||||||
|
@ -146,6 +145,10 @@ def generate(
|
||||||
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
|
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
|
||||||
cvvp_weight = 0
|
cvvp_weight = 0
|
||||||
|
|
||||||
|
autoregressive_model = args.autoregressive_model
|
||||||
|
if autoregressive_model == "auto":
|
||||||
|
autoregressive_model = deduce_autoregressive_model(voice)
|
||||||
|
|
||||||
def get_settings( override=None ):
|
def get_settings( override=None ):
|
||||||
settings = {
|
settings = {
|
||||||
'temperature': float(temperature),
|
'temperature': float(temperature),
|
||||||
|
@ -172,7 +175,7 @@ def generate(
|
||||||
'half_p': "Half Precision" in experimental_checkboxes,
|
'half_p': "Half Precision" in experimental_checkboxes,
|
||||||
'cond_free': "Conditioning-Free" in experimental_checkboxes,
|
'cond_free': "Conditioning-Free" in experimental_checkboxes,
|
||||||
'cvvp_amount': cvvp_weight,
|
'cvvp_amount': cvvp_weight,
|
||||||
'autoregressive_model': args.autoregressive_model,
|
'autoregressive_model': autoregressive_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
# could be better to just do a ternary on everything above, but i am not a professional
|
# could be better to just do a ternary on everything above, but i am not a professional
|
||||||
|
@ -180,18 +183,10 @@ def generate(
|
||||||
if 'voice' in override:
|
if 'voice' in override:
|
||||||
voice = override['voice']
|
voice = override['voice']
|
||||||
|
|
||||||
if "autoregressive_model" in override and override["autoregressive_model"] == "auto":
|
if "autoregressive_model" in override:
|
||||||
dir = f'./training/{voice}-finetune/models/'
|
if override["autoregressive_model"] == "auto":
|
||||||
if os.path.exists(f'./training/finetunes/{voice}.pth'):
|
override["autoregressive_model"] = deduce_autoregressive_model(voice)
|
||||||
override["autoregressive_model"] = f'./training/finetunes/{voice}.pth'
|
|
||||||
elif os.path.isdir(dir):
|
|
||||||
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
|
|
||||||
names = [ f'./{dir}/{d}_gpt.pth' for d in counts ]
|
|
||||||
override["autoregressive_model"] = names[-1]
|
|
||||||
else:
|
|
||||||
override["autoregressive_model"] = None
|
|
||||||
|
|
||||||
# necessary to ensure the right model gets loaded for the latents
|
|
||||||
tts.load_autoregressive_model( override["autoregressive_model"] )
|
tts.load_autoregressive_model( override["autoregressive_model"] )
|
||||||
|
|
||||||
fetched = fetch_voice(voice)
|
fetched = fetch_voice(voice)
|
||||||
|
@ -204,8 +199,7 @@ def generate(
|
||||||
continue
|
continue
|
||||||
settings[k] = override[k]
|
settings[k] = override[k]
|
||||||
|
|
||||||
if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]:
|
tts.load_autoregressive_model( settings["autoregressive_model"] )
|
||||||
tts.load_autoregressive_model( settings["autoregressive_model"] )
|
|
||||||
|
|
||||||
# clamp it down for the insane users who want this
|
# clamp it down for the insane users who want this
|
||||||
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
|
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
|
||||||
|
@ -302,7 +296,7 @@ def generate(
|
||||||
|
|
||||||
'datetime': datetime.now().isoformat(),
|
'datetime': datetime.now().isoformat(),
|
||||||
'model': tts.autoregressive_model_path,
|
'model': tts.autoregressive_model_path,
|
||||||
'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None,
|
'model_hash': tts.autoregressive_model_hash
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings is not None:
|
if settings is not None:
|
||||||
|
@ -331,7 +325,7 @@ def generate(
|
||||||
else:
|
else:
|
||||||
if settings and "model_hash" in settings:
|
if settings and "model_hash" in settings:
|
||||||
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth'
|
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{settings["model_hash"][:8]}.pth'
|
||||||
elif hasattr(tts, "autoregressive_model_hash"):
|
else:
|
||||||
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
|
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
|
||||||
|
|
||||||
if latents_path and os.path.exists(latents_path):
|
if latents_path and os.path.exists(latents_path):
|
||||||
|
@ -387,7 +381,7 @@ def generate(
|
||||||
used_settings['time'] = run_time
|
used_settings['time'] = run_time
|
||||||
used_settings['datetime'] = datetime.now().isoformat(),
|
used_settings['datetime'] = datetime.now().isoformat(),
|
||||||
used_settings['model'] = tts.autoregressive_model_path
|
used_settings['model'] = tts.autoregressive_model_path
|
||||||
used_settings['model_hash'] = tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None
|
used_settings['model_hash'] = tts.autoregressive_model_hash
|
||||||
|
|
||||||
audio_cache[name] = {
|
audio_cache[name] = {
|
||||||
'audio': audio,
|
'audio': audio,
|
||||||
|
@ -540,6 +534,9 @@ def hash_file(path, algo="md5", buffer_size=0):
|
||||||
return "{0}".format(hash.hexdigest())
|
return "{0}".format(hash.hexdigest())
|
||||||
|
|
||||||
def update_baseline_for_latents_chunks( voice ):
|
def update_baseline_for_latents_chunks( voice ):
|
||||||
|
global current_voice
|
||||||
|
current_voice = voice
|
||||||
|
|
||||||
path = f'{get_voice_dir()}/{voice}/'
|
path = f'{get_voice_dir()}/{voice}/'
|
||||||
if not os.path.isdir(path):
|
if not os.path.isdir(path):
|
||||||
return 1
|
return 1
|
||||||
|
@ -583,6 +580,9 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
|
||||||
if hasattr(tts, "loading") and tts.loading:
|
if hasattr(tts, "loading") and tts.loading:
|
||||||
raise Exception("TTS is still initializing...")
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
|
if args.autoregressive_model == "auto":
|
||||||
|
tts.load_autoregressive_model(deduce_autoregressive_model(voice))
|
||||||
|
|
||||||
if voice:
|
if voice:
|
||||||
load_from_dataset = voice_latents_chunks == 0
|
load_from_dataset = voice_latents_chunks == 0
|
||||||
|
|
||||||
|
@ -620,10 +620,7 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
|
||||||
if len(conditioning_latents) == 4:
|
if len(conditioning_latents) == 4:
|
||||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||||
|
|
||||||
if hasattr(tts, 'autoregressive_model_hash'):
|
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
||||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
|
||||||
else:
|
|
||||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
|
|
||||||
|
|
||||||
return conditioning_latents
|
return conditioning_latents
|
||||||
|
|
||||||
|
@ -1460,6 +1457,9 @@ def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
|
||||||
models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
|
models = sorted([ int(d[:-8]) for d in os.listdir(f'./training/{training}/models/') if d[-8:] == "_gpt.pth" ])
|
||||||
found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ]
|
found = found + [ f'./training/{training}/models/{d}_gpt.pth' for d in models ]
|
||||||
|
|
||||||
|
if len(found) > 0 or len(additionals) > 0:
|
||||||
|
base = ["auto"] + base
|
||||||
|
|
||||||
res = base + additionals + found
|
res = base + additionals + found
|
||||||
|
|
||||||
if prefixed:
|
if prefixed:
|
||||||
|
@ -1815,28 +1815,29 @@ def version_check_tts( min_version ):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_tts( restart=False, model=None ):
|
def load_tts( restart=False, autoregressive_model=None ):
|
||||||
global args
|
global args
|
||||||
global tts
|
global tts
|
||||||
|
|
||||||
if restart:
|
if restart:
|
||||||
unload_tts()
|
unload_tts()
|
||||||
|
|
||||||
|
if autoregressive_model:
|
||||||
|
args.autoregressive_model = autoregressive_model
|
||||||
|
else:
|
||||||
|
autoregressive_model = args.autoregressive_model
|
||||||
|
|
||||||
if model:
|
if autoregressive_model == "auto":
|
||||||
args.autoregressive_model = model
|
autoregressive_model = deduce_autoregressive_model()
|
||||||
|
|
||||||
print(f"Loading TorToiSe... (AR: {args.autoregressive_model}, vocoder: {args.vocoder_model})")
|
print(f"Loading TorToiSe... (AR: {autoregressive_model}, vocoder: {args.vocoder_model})")
|
||||||
|
|
||||||
tts_loading = True
|
tts_loading = True
|
||||||
try:
|
try:
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, vocoder_model=args.vocoder_model)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=autoregressive_model, vocoder_model=args.vocoder_model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||||
load_autoregressive_model(args.autoregressive_model)
|
load_autoregressive_model(autoregressive_model)
|
||||||
|
|
||||||
if not hasattr(tts, 'autoregressive_model_hash'):
|
|
||||||
tts.autoregressive_model_hash = hash_file(tts.autoregressive_model_path)
|
|
||||||
|
|
||||||
tts_loading = False
|
tts_loading = False
|
||||||
|
|
||||||
|
@ -1858,6 +1859,37 @@ def unload_tts():
|
||||||
def reload_tts( model=None ):
|
def reload_tts( model=None ):
|
||||||
load_tts( restart=True, model=model )
|
load_tts( restart=True, model=model )
|
||||||
|
|
||||||
|
def get_current_voice():
|
||||||
|
global current_voice
|
||||||
|
if current_voice:
|
||||||
|
return current_voice
|
||||||
|
|
||||||
|
settings, _ = read_generate_settings("./config/generate.json", read_latents=False)
|
||||||
|
|
||||||
|
if settings and "voice" in settings['voice']:
|
||||||
|
return settings["voice"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def deduce_autoregressive_model(voice=None):
|
||||||
|
if not voice:
|
||||||
|
voice = get_current_voice()
|
||||||
|
|
||||||
|
if voice:
|
||||||
|
dir = f'./training/{voice}-finetune/models/'
|
||||||
|
if os.path.exists(f'./training/finetunes/{voice}.pth'):
|
||||||
|
return f'./training/finetunes/{voice}.pth'
|
||||||
|
|
||||||
|
if os.path.isdir(dir):
|
||||||
|
counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ])
|
||||||
|
names = [ f'{dir}/{d}_gpt.pth' for d in counts ]
|
||||||
|
return names[-1]
|
||||||
|
|
||||||
|
if args.autoregressive_model != "auto":
|
||||||
|
return args.autoregressive_model
|
||||||
|
|
||||||
|
return get_model_path('autoregressive.pth')
|
||||||
|
|
||||||
def update_autoregressive_model(autoregressive_model_path):
|
def update_autoregressive_model(autoregressive_model_path):
|
||||||
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
|
match = re.findall(r'^\[[a-fA-F0-9]{8}\] (.+?)$', autoregressive_model_path)
|
||||||
if match:
|
if match:
|
||||||
|
@ -1880,10 +1912,13 @@ def update_autoregressive_model(autoregressive_model_path):
|
||||||
if hasattr(tts, "loading") and tts.loading:
|
if hasattr(tts, "loading") and tts.loading:
|
||||||
raise Exception("TTS is still initializing...")
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
|
if autoregressive_model_path == "auto":
|
||||||
|
autoregressive_model_path = deduce_autoregressive_model()
|
||||||
|
|
||||||
|
if autoregressive_model_path == tts.autoregressive_model_path:
|
||||||
|
return
|
||||||
|
|
||||||
print(f"Loading model: {autoregressive_model_path}")
|
|
||||||
tts.load_autoregressive_model(autoregressive_model_path)
|
tts.load_autoregressive_model(autoregressive_model_path)
|
||||||
print(f"Loaded model: {tts.autoregressive_model_path}")
|
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit e2db36af602297501132f7f68331755f5904825a
|
Subproject commit 26133c20314b77155e77be804b43909dab9809d6
|
Loading…
Reference in New Issue
Block a user