From 2726d98ee1505312a57767074a650a1326313468 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 7 Mar 2023 05:35:21 +0000
Subject: [PATCH] fried my brain trying to nail down bugs involving using
 solely ar model=auto

---
 src/utils.py | 108 +++++++++++++++++++--------------------------------
 1 file changed, 41 insertions(+), 67 deletions(-)

diff --git a/src/utils.py b/src/utils.py
index 0aec5a8..fecdb0a 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -96,20 +96,17 @@ def generate(
 
 	do_gc()
 
-	voices = {}
-	voice_samples = None
 	conditioning_latents =None
 	sample_voice = None
 
-	def fetch_voice( requested ):
-		voice = requested
-
-		if voice in voices:
-			return voices[voice]
+	if seed == 0:
+		seed = None
 
+	def fetch_voice( voice ):
 		print(f"Loading voice: {voice}")
+		sample_voice = None
 		if voice == "microphone":
 			if mic_audio is None:
 				raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
@@ -117,37 +114,19 @@ def generate(
 		elif voice == "random":
 			voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
 		else:
-			progress(0, desc=f"Loading voice: {voice}")
-			# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
-			voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
+			if progress is not None:
+				progress(0, desc=f"Loading voice: {voice}")
+			voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
+
 		if voice_samples and len(voice_samples) > 0:
-			conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
+			if conditioning_latents is None:
+				conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=voice_latents_chunks)
+
 			sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
 			voice_samples = None
-		else:
-			if conditioning_latents is not None:
-				sample_voice, _ = load_voice(voice, load_latents=False)
-			if sample_voice and len(sample_voice) > 0:
-				sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu()
-			else:
-				sample_voice = None
-		voices[voice] = (voice_samples, conditioning_latents, sample_voice)
-		return voices[voice]
-
-	voice_samples, conditioning_latents, sample_voice = fetch_voice(voice)
-
-	if seed == 0:
-		seed = None
-
-	if conditioning_latents is not None and len(conditioning_latents) == 2 and cvvp_weight > 0:
-		print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
-		cvvp_weight = 0
-
-	autoregressive_model = args.autoregressive_model
-	if autoregressive_model == "auto":
-		autoregressive_model = deduce_autoregressive_model(voice)
+		return (voice_samples, conditioning_latents, sample_voice)
 
 	def get_settings( override=None ):
 		settings = {
@@ -163,8 +142,8 @@ def generate(
 			'sample_batch_size': args.sample_batch_size,
 			'diffusion_iterations': diffusion_iterations,
 
-			'voice_samples': voice_samples,
-			'conditioning_latents': conditioning_latents,
+			'voice_samples': None,
+			'conditioning_latents': None,
 
 			'use_deterministic_seed': seed,
 			'return_deterministic_state': True,
@@ -175,31 +154,26 @@ def generate(
 			'half_p': "Half Precision" in experimental_checkboxes,
 			'cond_free': "Conditioning-Free" in experimental_checkboxes,
 			'cvvp_amount': cvvp_weight,
-			'autoregressive_model': autoregressive_model,
+			'autoregressive_model': args.autoregressive_model,
 		}
 
 		# could be better to just do a ternary on everything above, but i am not a professional
+		selected_voice = voice
 		if override is not None:
 			if 'voice' in override:
-				voice = override['voice']
-
-			if "autoregressive_model" in override:
-				if override["autoregressive_model"] == "auto":
-					override["autoregressive_model"] = deduce_autoregressive_model(voice)
-
-				tts.load_autoregressive_model( override["autoregressive_model"] )
-
-			fetched = fetch_voice(voice)
-
-			settings['voice_samples'] = fetched[0]
-			settings['conditioning_latents'] = fetched[1]
+				selected_voice = override['voice']
 
 			for k in override:
 				if k not in settings:
 					continue
 				settings[k] = override[k]
 
-		tts.load_autoregressive_model( settings["autoregressive_model"] )
+		if settings['autoregressive_model'] is not None:
+			if settings['autoregressive_model'] == "auto":
+				settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice)
+			tts.load_autoregressive_model(settings['autoregressive_model'])
+
+		settings['voice_samples'], settings['conditioning_latents'], _ = fetch_voice(voice=selected_voice)
 
 		# clamp it down for the insane users who want this
 		# it would be wiser to enforce the sample size to the batch size, but this is what the user wants
@@ -208,11 +182,13 @@ def generate(
 			sample_batch_size = tts.autoregressive_batch_size
 		if num_autoregressive_samples < sample_batch_size:
 			settings['sample_batch_size'] = num_autoregressive_samples
+
+		if settings['conditioning_latents'] is not None and len(settings['conditioning_latents']) == 2 and settings['cvvp_amount'] > 0:
+			print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.")
+			settings['cvvp_amount'] = 0
 
 		return settings
 
-	settings = get_settings()
-
 	if not delimiter:
 		delimiter = "\n"
 	elif delimiter == "\\n":
@@ -355,16 +331,12 @@ def generate(
 			match = match[0]
 			try:
 				override = json.loads(match[0])
+				cut_text = match[1].strip()
 			except Exception as e:
-				print(e)
 				raise Exception("Prompt settings editing requested, but received invalid JSON")
 
-			cut_text = match[1].strip()
-			used_settings = get_settings( override )
-		else:
-			used_settings = settings.copy()
-
-		gen, additionals = tts.tts(cut_text, **used_settings )
+		settings = get_settings( override=override )
+		gen, additionals = tts.tts(cut_text, **settings )
 		seed = additionals[0]
 		run_time = time.time()-start_time
@@ -377,15 +349,15 @@ def generate(
 			audio = g.squeeze(0).cpu()
 			name = get_name(line=line, candidate=j)
 
-			used_settings['text'] = cut_text
-			used_settings['time'] = run_time
-			used_settings['datetime'] = datetime.now().isoformat(),
-			used_settings['model'] = tts.autoregressive_model_path
-			used_settings['model_hash'] = tts.autoregressive_model_hash
+			settings['text'] = cut_text
+			settings['time'] = run_time
+			settings['datetime'] = datetime.now().isoformat()
+			settings['model'] = tts.autoregressive_model_path
+			settings['model_hash'] = tts.autoregressive_model_hash
 
 			audio_cache[name] = {
 				'audio': audio,
-				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=used_settings)
+				'settings': get_info(voice=override['voice'] if override and 'voice' in override else voice, settings=settings)
 			}
 			# save here in case some error happens mid-batch
 			torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, tts.output_sample_rate)
@@ -485,7 +457,7 @@ def generate(
 	info = get_info(voice=voice, latents=False)
 	print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
 
-	info['seed'] = settings['use_deterministic_seed']
+	info['seed'] = seed
 	if 'latents' in info:
 		del info['latents']
@@ -619,8 +591,10 @@ def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, prog
 	if len(conditioning_latents) == 4:
 		conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
-
-	torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+
+	outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
+	torch.save(conditioning_latents, outfile)
+	print(f'Saved voice latents: {outfile}')
 
 	return conditioning_latents
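
Note: the crux of the fix is the ordering now enforced inside get_settings() --
apply any per-line override first, only then resolve "auto" to a concrete
model and load it, and fetch voice data last, since load_voice() keys latents
to the hash of the currently loaded autoregressive model. Below is a minimal,
self-contained sketch of that flow; every name in it is a hypothetical
stand-in, not the real implementation from src/utils.py:

    # Sketch of the override -> model -> latents ordering; the helpers
    # here are stubs standing in for the real ones in src/utils.py.
    def deduce_autoregressive_model(voice):
        # stand-in: the real helper picks a voice-specific finetune,
        # falling back to the base model
        return f"finetune-for-{voice}"

    def fetch_voice(voice):
        # stand-in: the real helper returns (voice_samples,
        # conditioning_latents, sample_voice) keyed to the loaded model
        return (None, f"latents-for-{voice}", None)

    class TTSStub:
        def load_autoregressive_model(self, path):
            print(f"loading model: {path}")

    def resolve_settings(settings, voice, override, tts):
        # 1. apply the per-line override first, so "auto" is resolved
        #    against the voice that will actually be synthesized
        selected_voice = voice
        if override is not None:
            selected_voice = override.get('voice', voice)
            for k, v in override.items():
                if k in settings:
                    settings[k] = v
        # 2. only now resolve "auto" and swap models; resolving it once
        #    up front against the default voice was the source of the bugs
        if settings['autoregressive_model'] == "auto":
            settings['autoregressive_model'] = deduce_autoregressive_model(selected_voice)
        tts.load_autoregressive_model(settings['autoregressive_model'])
        # 3. fetch voice data last, so the latents match the model
        #    loaded in step 2
        settings['voice_samples'], settings['conditioning_latents'], _ = fetch_voice(selected_voice)
        return settings

    settings = {'autoregressive_model': "auto", 'voice_samples': None, 'conditioning_latents': None}
    resolve_settings(settings, "narrator", {'voice': "guest"}, TTSStub())
    # resolves "auto" -> "finetune-for-guest", then fetches latents for "guest"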