fried my brain trying to nail out bugs involving using solely ar model=auto

This commit is contained in:
mrq 2023-03-07 05:35:21 +00:00
parent d7a5ad9fd9
commit 2726d98ee1

View File

@ -96,20 +96,17 @@ def generate(
do_gc() do_gc()
voices = {}
voice_samples = None voice_samples = None
conditioning_latents =None conditioning_latents =None
sample_voice = None sample_voice = None
def fetch_voice( requested ): if seed == 0:
voice = requested seed = None
if voice in voices:
return voices[voice]
def fetch_voice( voice ):
print(f"Loading voice: {voice}") print(f"Loading voice: {voice}")
sample_voice = None
if voice == "microphone": if voice == "microphone":
if mic_audio is None: if mic_audio is None:
raise Exception("Please provide audio from mic when choosing `microphone` as a voice input") raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
@ -117,37 +114,19 @@ def generate(
elif voice == "random": elif voice == "random":
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents() voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
else: else:
progress(0, desc=f"Loading voice: {voice}") if progress is not None:
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts progress(0, desc=f"Loading voice: {voice}")
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash) voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
if voice_samples and len(voice_samples) > 0: if voice_samples and len(voice_samples) > 0:
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks) if conditioning_latents is None:
conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu() sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
voice_samples = None voice_samples = None
else:
if conditioning_latents is not None:
sample_voice, _ = load_voice(voice, load_latents=False)
if sample_voice and len(sample_voice) > 0:
sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu()
else:
sample_voice = None
voices[voice] = (voice_samples, conditioning_latents, sample_voice) return (voice_samples, conditioning_latents, sample_voice)
return voices[voice]
voice_samples, conditioning_latents, sample_voice = fetch_voice(voice)
if seed == 0:
seed = None
if conditioning_latents is not None and len(conditioning_latents) == 2 and cvvp_weight > 0:
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
cvvp_weight = 0
autoregressive_model = args.autoregressive_model
if autoregressive_model == "auto":
autoregressive_model = deduce_autoregressive_model(voice)
def get_settings( override=None ): def get_settings( override=None ):
settings = { settings = {
@ -163,8 +142,8 @@ def generate(
'sample_batch_size': args.sample_batch_size, 'sample_batch_size': args.sample_batch_size,
'diffusion_iterations': diffusion_iterations, 'diffusion_iterations': diffusion_iterations,
'voice_samples': voice_samples, 'voice_samples': None,
'conditioning_latents': conditioning_latents, 'conditioning_latents': None,
'use_deterministic_seed': seed, 'use_deterministic_seed': seed,
'return_deterministic_state': True, 'return_deterministic_state': True,
@ -175,31 +154,26 @@ def generate(
'half_p': "Half Precision" in experimental_checkboxes, 'half_p': "Half Precision" in experimental_checkboxes,
'cond_free': "Conditioning-Free" in experimental_checkboxes, 'cond_free': "Conditioning-Free" in experimental_checkboxes,
'cvvp_amount': cvvp_weight, 'cvvp_amount': cvvp_weight,
'autoregressive_model': autoregressive_model, 'autoregressive_model': args.autoregressive_model,
} }
# could be better to just do a ternary on everything above, but i am not a professional # could be better to just do a ternary on everything above, but i am not a professional
selected_voice = voice
if override is not None: if override is not None:
if 'voice' in override: if 'voice' in override:
voice = override['voice'] selected_voice = override['voice']
if "autoregressive_model" in override:
if override["autoregressive_model"] == "auto":
override["autoregressive_model"] = deduce_autoregressive_model(voice)
tts.load_autoregressive_model( override["autoregressive_model"] )
fetched = fetch_voice(voice)
settings['voice_samples'] = fetched[0]
settings['conditioning_latents'] = fetched[1]
for k in override: for k in override:
if k not in settings: if k not in settings:
continue continue
settings[k] = override[k] settings[k] = override[k]
tts.load_autoregressive_model( settings["autoregressive_model"] ) if settings['autoregressive_model'] is not None:
if settings['autoregressive_model'] == "auto":
settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice)
tts.load_autoregressive_model(settings['autoregressive_model'])
settings['voice_samples'], settings['conditioning_latents'], _ = fetch_voice(voice=selected_voice)
# clamp it down for the insane users who want this # clamp it down for the insane users who want this
# it would be wiser to enforce the sample size to the batch size, but this is what the user wants # it would be wiser to enforce the sample size to the batch size, but this is what the user wants
@ -209,9 +183,11 @@ def generate(
if num_autoregressive_samples < sample_batch_size: if num_autoregressive_samples < sample_batch_size:
settings['sample_batch_size'] = num_autoregressive_samples settings['sample_batch_size'] = num_autoregressive_samples
return settings if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0:
print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
settings['cvvp_amount'] = 0
settings = get_settings() return settings
if not delimiter: if not delimiter:
delimiter = "\n" delimiter = "\n"
@ -355,16 +331,12 @@ def generate(
match = match[0] match = match[0]
try: try:
override = json.loads(match[0]) override = json.loads(match[0])
cut_text = match[1].strip()
except Exception as e: except Exception as e:
print(e)
raise Exception("Prompt settings editing requested, but received invalid JSON") raise Exception("Prompt settings editing requested, but received invalid JSON")
cut_text = match[1].strip() settings = get_settings( override=override )
used_settings = get_settings( override ) gen, additionals = tts.tts(cut_text, **settings )
else:
used_settings = settings.copy()
gen, additionals = tts.tts(cut_text, **used_settings )
seed = additionals[0] seed = additionals[0]
run_time = time.time()-start_time run_time = time.time()-start_time
@ -377,15 +349,15 @@ def generate(
audio = g.squeeze(0).cpu() audio = g.squeeze(0).cpu()
name = get_name(line=line, candidate=j) name = get_name(line=line, candidate=j)
used_settings['text'] = cut_text settings['text'] = cut_text
used_settings['time'] = run_time settings['time'] = run_time
used_settings['datetime'] = datetime.now().isoformat(), settings['datetime'] = datetime.now().isoformat(),
used_settings['model'] = tts.autoregressive_model_path settings['model'] = tts.autoregressive_model_path
used_settings['model_hash'] = tts.autoregressive_model_hash settings['model_hash'] = tts.autoregressive_model_hash
audio_cache[name] = { audio_cache[name] = {
'audio': audio, 'audio': audio,
'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=used_settings) 'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
} }
# save here in case some error happens mid-batch # save here in case some error happens mid-batch
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate) torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
@ -485,7 +457,7 @@ def generate(
info = get_info(voice=voice, latents=False) info = get_info(voice=voice, latents=False)
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n") print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
info['seed'] = settings['use_deterministic_seed'] info['seed'] = seed
if 'latents' in info: if 'latents' in info:
del info['latents'] del info['latents']
@ -620,7 +592,9 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
if len(conditioning_latents) == 4: if len(conditioning_latents) == 4:
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth') outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
torch.save(conditioning_latents, outfile)
print(f'Saved voice latents: {outfile}')
return conditioning_latents return conditioning_latents