forked from mrq/ai-voice-cloning
fried my brain trying to nail out bugs involving using solely ar model=auto
This commit is contained in:
parent
d7a5ad9fd9
commit
2726d98ee1
108
src/utils.py
108
src/utils.py
|
@ -96,20 +96,17 @@ def generate(
|
||||||
|
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
voices = {}
|
|
||||||
|
|
||||||
voice_samples = None
|
voice_samples = None
|
||||||
conditioning_latents =None
|
conditioning_latents =None
|
||||||
sample_voice = None
|
sample_voice = None
|
||||||
|
|
||||||
def fetch_voice( requested ):
|
if seed == 0:
|
||||||
voice = requested
|
seed = None
|
||||||
|
|
||||||
if voice in voices:
|
|
||||||
return voices[voice]
|
|
||||||
|
|
||||||
|
def fetch_voice( voice ):
|
||||||
print(f"Loading voice: {voice}")
|
print(f"Loading voice: {voice}")
|
||||||
|
|
||||||
|
sample_voice = None
|
||||||
if voice == "microphone":
|
if voice == "microphone":
|
||||||
if mic_audio is None:
|
if mic_audio is None:
|
||||||
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
|
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
|
||||||
|
@ -117,37 +114,19 @@ def generate(
|
||||||
elif voice == "random":
|
elif voice == "random":
|
||||||
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
|
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
|
||||||
else:
|
else:
|
||||||
progress(0, desc=f"Loading voice: {voice}")
|
if progress is not None:
|
||||||
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
|
progress(0, desc=f"Loading voice: {voice}")
|
||||||
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
|
||||||
|
|
||||||
|
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
|
||||||
|
|
||||||
if voice_samples and len(voice_samples) > 0:
|
if voice_samples and len(voice_samples) > 0:
|
||||||
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
|
if conditioning_latents is None:
|
||||||
|
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
|
||||||
|
|
||||||
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
|
||||||
voice_samples = None
|
voice_samples = None
|
||||||
else:
|
|
||||||
if conditioning_latents is not None:
|
|
||||||
sample_voice, _ = load_voice(voice, load_latents=False)
|
|
||||||
if sample_voice and len(sample_voice) > 0:
|
|
||||||
sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu()
|
|
||||||
else:
|
|
||||||
sample_voice = None
|
|
||||||
|
|
||||||
voices[voice] = (voice_samples, conditioning_latents, sample_voice)
|
return (voice_samples, conditioning_latents, sample_voice)
|
||||||
return voices[voice]
|
|
||||||
|
|
||||||
voice_samples, conditioning_latents, sample_voice = fetch_voice(voice)
|
|
||||||
|
|
||||||
if seed == 0:
|
|
||||||
seed = None
|
|
||||||
|
|
||||||
if conditioning_latents is not None and len(conditioning_latents) == 2 and cvvp_weight > 0:
|
|
||||||
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
|
|
||||||
cvvp_weight = 0
|
|
||||||
|
|
||||||
autoregressive_model = args.autoregressive_model
|
|
||||||
if autoregressive_model == "auto":
|
|
||||||
autoregressive_model = deduce_autoregressive_model(voice)
|
|
||||||
|
|
||||||
def get_settings( override=None ):
|
def get_settings( override=None ):
|
||||||
settings = {
|
settings = {
|
||||||
|
@ -163,8 +142,8 @@ def generate(
|
||||||
'sample_batch_size': args.sample_batch_size,
|
'sample_batch_size': args.sample_batch_size,
|
||||||
'diffusion_iterations': diffusion_iterations,
|
'diffusion_iterations': diffusion_iterations,
|
||||||
|
|
||||||
'voice_samples': voice_samples,
|
'voice_samples': None,
|
||||||
'conditioning_latents': conditioning_latents,
|
'conditioning_latents': None,
|
||||||
|
|
||||||
'use_deterministic_seed': seed,
|
'use_deterministic_seed': seed,
|
||||||
'return_deterministic_state': True,
|
'return_deterministic_state': True,
|
||||||
|
@ -175,31 +154,26 @@ def generate(
|
||||||
'half_p': "Half Precision" in experimental_checkboxes,
|
'half_p': "Half Precision" in experimental_checkboxes,
|
||||||
'cond_free': "Conditioning-Free" in experimental_checkboxes,
|
'cond_free': "Conditioning-Free" in experimental_checkboxes,
|
||||||
'cvvp_amount': cvvp_weight,
|
'cvvp_amount': cvvp_weight,
|
||||||
'autoregressive_model': autoregressive_model,
|
'autoregressive_model': args.autoregressive_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
# could be better to just do a ternary on everything above, but i am not a professional
|
# could be better to just do a ternary on everything above, but i am not a professional
|
||||||
|
selected_voice = voice
|
||||||
if override is not None:
|
if override is not None:
|
||||||
if 'voice' in override:
|
if 'voice' in override:
|
||||||
voice = override['voice']
|
selected_voice = override['voice']
|
||||||
|
|
||||||
if "autoregressive_model" in override:
|
|
||||||
if override["autoregressive_model"] == "auto":
|
|
||||||
override["autoregressive_model"] = deduce_autoregressive_model(voice)
|
|
||||||
|
|
||||||
tts.load_autoregressive_model( override["autoregressive_model"] )
|
|
||||||
|
|
||||||
fetched = fetch_voice(voice)
|
|
||||||
|
|
||||||
settings['voice_samples'] = fetched[0]
|
|
||||||
settings['conditioning_latents'] = fetched[1]
|
|
||||||
|
|
||||||
for k in override:
|
for k in override:
|
||||||
if k not in settings:
|
if k not in settings:
|
||||||
continue
|
continue
|
||||||
settings[k] = override[k]
|
settings[k] = override[k]
|
||||||
|
|
||||||
tts.load_autoregressive_model( settings["autoregressive_model"] )
|
if settings['autoregressive_model'] is not None:
|
||||||
|
if settings['autoregressive_model'] == "auto":
|
||||||
|
settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice)
|
||||||
|
tts.load_autoregressive_model(settings['autoregressive_model'])
|
||||||
|
|
||||||
|
settings['voice_samples'], settings['conditioning_latents'], _ = fetch_voice(voice=selected_voice)
|
||||||
|
|
||||||
# clamp it down for the insane users who want this
|
# clamp it down for the insane users who want this
|
||||||
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
|
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
|
||||||
|
@ -208,11 +182,13 @@ def generate(
|
||||||
sample_batch_size = tts.autoregressive_batch_size
|
sample_batch_size = tts.autoregressive_batch_size
|
||||||
if num_autoregressive_samples < sample_batch_size:
|
if num_autoregressive_samples < sample_batch_size:
|
||||||
settings['sample_batch_size'] = num_autoregressive_samples
|
settings['sample_batch_size'] = num_autoregressive_samples
|
||||||
|
|
||||||
|
if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0:
|
||||||
|
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
|
||||||
|
settings['cvvp_amount'] = 0
|
||||||
|
|
||||||
return settings
|
return settings
|
||||||
|
|
||||||
settings = get_settings()
|
|
||||||
|
|
||||||
if not delimiter:
|
if not delimiter:
|
||||||
delimiter = "\n"
|
delimiter = "\n"
|
||||||
elif delimiter == "\\n":
|
elif delimiter == "\\n":
|
||||||
|
@ -355,16 +331,12 @@ def generate(
|
||||||
match = match[0]
|
match = match[0]
|
||||||
try:
|
try:
|
||||||
override = json.loads(match[0])
|
override = json.loads(match[0])
|
||||||
|
cut_text = match[1].strip()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
|
||||||
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
raise Exception("Prompt settings editing requested, but received invalid JSON")
|
||||||
|
|
||||||
cut_text = match[1].strip()
|
settings = get_settings( override=override )
|
||||||
used_settings = get_settings( override )
|
gen, additionals = tts.tts(cut_text, **settings )
|
||||||
else:
|
|
||||||
used_settings = settings.copy()
|
|
||||||
|
|
||||||
gen, additionals = tts.tts(cut_text, **used_settings )
|
|
||||||
|
|
||||||
seed = additionals[0]
|
seed = additionals[0]
|
||||||
run_time = time.time()-start_time
|
run_time = time.time()-start_time
|
||||||
|
@ -377,15 +349,15 @@ def generate(
|
||||||
audio = g.squeeze(0).cpu()
|
audio = g.squeeze(0).cpu()
|
||||||
name = get_name(line=line, candidate=j)
|
name = get_name(line=line, candidate=j)
|
||||||
|
|
||||||
used_settings['text'] = cut_text
|
settings['text'] = cut_text
|
||||||
used_settings['time'] = run_time
|
settings['time'] = run_time
|
||||||
used_settings['datetime'] = datetime.now().isoformat(),
|
settings['datetime'] = datetime.now().isoformat(),
|
||||||
used_settings['model'] = tts.autoregressive_model_path
|
settings['model'] = tts.autoregressive_model_path
|
||||||
used_settings['model_hash'] = tts.autoregressive_model_hash
|
settings['model_hash'] = tts.autoregressive_model_hash
|
||||||
|
|
||||||
audio_cache[name] = {
|
audio_cache[name] = {
|
||||||
'audio': audio,
|
'audio': audio,
|
||||||
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=used_settings)
|
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
|
||||||
}
|
}
|
||||||
# save here in case some error happens mid-batch
|
# save here in case some error happens mid-batch
|
||||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
|
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
|
||||||
|
@ -485,7 +457,7 @@ def generate(
|
||||||
info = get_info(voice=voice, latents=False)
|
info = get_info(voice=voice, latents=False)
|
||||||
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
|
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
|
||||||
|
|
||||||
info['seed'] = settings['use_deterministic_seed']
|
info['seed'] = seed
|
||||||
if 'latents' in info:
|
if 'latents' in info:
|
||||||
del info['latents']
|
del info['latents']
|
||||||
|
|
||||||
|
@ -619,8 +591,10 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
|
||||||
|
|
||||||
if len(conditioning_latents) == 4:
|
if len(conditioning_latents) == 4:
|
||||||
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
|
||||||
|
|
||||||
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
|
outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
|
||||||
|
torch.save(conditioning_latents, outfile)
|
||||||
|
print(f'Saved voice latents: {outfile}')
|
||||||
|
|
||||||
return conditioning_latents
|
return conditioning_latents
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user