From 7798767fc6a00601b1b2ca487738e9c54a1f4db2 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 6 Mar 2023 21:48:34 +0000 Subject: [PATCH] added settings editing (will add a guide on what to do later, and an example) --- src/utils.py | 244 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 153 insertions(+), 91 deletions(-) diff --git a/src/utils.py b/src/utils.py index c4eea7c..0074bc6 100755 --- a/src/utils.py +++ b/src/utils.py @@ -90,46 +90,59 @@ def generate( do_gc() - if voice != "microphone": - voices = [voice] - else: - voices = [] + voices = {} - if voice == "microphone": - if mic_audio is None: - raise Exception("Please provide audio from mic when choosing `microphone` as a voice input") - mic = load_audio(mic_audio, tts.input_sample_rate) - voice_samples, conditioning_latents = [mic], None - elif voice == "random": - voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents() - else: - progress(0, desc="Loading voice...") - # nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts - if hasattr(tts, 'autoregressive_model_hash'): - voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash) + voice_samples = None + conditioning_latents =None + sample_voice = None + + def fetch_voice( requested ): + voice = requested + + if voice in voices: + return voices[voice] + + print(f"Loading voice: {voice}") + + if voice == "microphone": + if mic_audio is None: + raise Exception("Please provide audio from mic when choosing `microphone` as a voice input") + voice_samples, conditioning_latents = [load_audio(mic_audio, tts.input_sample_rate)], None + elif voice == "random": + voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents() else: - voice_samples, conditioning_latents = load_voice(voice) - - if voice_samples and len(voice_samples) > 0: - sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu() - - conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents) - if len(conditioning_latents) == 4: - conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) - - if voice != "microphone": + progress(0, desc=f"Loading voice: {voice}") + # nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts if hasattr(tts, 'autoregressive_model_hash'): - torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth') + voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash) else: - torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') - voice_samples = None - else: - if conditioning_latents is not None: - sample_voice, _ = load_voice(voice, load_latents=False) - if sample_voice and len(sample_voice) > 0: - sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu() + voice_samples, conditioning_latents = load_voice(voice) + + if voice_samples and len(voice_samples) > 0: + sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu() + + conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, progress=progress, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents) + if len(conditioning_latents) == 4: + conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None) + + if voice != "microphone": + if hasattr(tts, 'autoregressive_model_hash'): + torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth') + else: + torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth') + voice_samples = None else: - sample_voice = None + if conditioning_latents is not None: + sample_voice, _ = load_voice(voice, load_latents=False) + if sample_voice and len(sample_voice) > 0: + sample_voice = torch.cat(sample_voice, dim=-1).squeeze().cpu() + else: + sample_voice = None + + voices[voice] = (voice_samples, conditioning_latents, sample_voice) + return voices[voice] + + voice_samples, conditioning_latents, sample_voice = fetch_voice(voice) if seed == 0: seed = None @@ -138,42 +151,80 @@ def generate( print("Requesting weighing against CVVP weight, but voice latents are missing some extra data. Please regenerate your voice latents.") cvvp_weight = 0 + def get_settings( override=None ): + settings = { + 'temperature': float(temperature), - settings = { - 'temperature': float(temperature), + 'top_p': float(top_p), + 'diffusion_temperature': float(diffusion_temperature), + 'length_penalty': float(length_penalty), + 'repetition_penalty': float(repetition_penalty), + 'cond_free_k': float(cond_free_k), - 'top_p': float(top_p), - 'diffusion_temperature': float(diffusion_temperature), - 'length_penalty': float(length_penalty), - 'repetition_penalty': float(repetition_penalty), - 'cond_free_k': float(cond_free_k), + 'num_autoregressive_samples': num_autoregressive_samples, + 'sample_batch_size': args.sample_batch_size, + 'diffusion_iterations': diffusion_iterations, - 'num_autoregressive_samples': num_autoregressive_samples, - 'sample_batch_size': args.sample_batch_size, - 'diffusion_iterations': diffusion_iterations, + 'voice_samples': voice_samples, + 'conditioning_latents': conditioning_latents, - 'voice_samples': voice_samples, - 'conditioning_latents': conditioning_latents, - 'use_deterministic_seed': seed, - 'return_deterministic_state': True, - 'k': candidates, - 'diffusion_sampler': diffusion_sampler, - 'breathing_room': breathing_room, - 'progress': progress, - 'half_p': "Half Precision" in experimental_checkboxes, - 'cond_free': "Conditioning-Free" in experimental_checkboxes, - 'cvvp_amount': cvvp_weight, - } + 'use_deterministic_seed': seed, + 'return_deterministic_state': True, + 'k': candidates, + 'diffusion_sampler': diffusion_sampler, + 'breathing_room': breathing_room, + 'progress': progress, + 'half_p': "Half Precision" in experimental_checkboxes, + 'cond_free': "Conditioning-Free" in experimental_checkboxes, + 'cvvp_amount': cvvp_weight, + 'autoregressive_model': args.autoregressive_model, + } - # clamp it down for the insane users who want this - # it would be wiser to enforce the sample size to the batch size, but this is what the user wants - sample_batch_size = args.sample_batch_size - if not sample_batch_size: - sample_batch_size = tts.autoregressive_batch_size - if num_autoregressive_samples < sample_batch_size: - settings['sample_batch_size'] = num_autoregressive_samples + # could be better to just do a ternary on everything above, but i am not a professional + if override is not None: + if 'voice' in override: + voice = override['voice'] - if delimiter is None: + if "autoregressive_model" in override and override["autoregressive_model"] == "auto": + dir = f'./training/{voice}-finetune/models/' + if os.path.exists(f'./training/finetunes/{voice}.pth'): + override["autoregressive_model"] = f'./training/finetunes/{voice}.pth' + elif os.path.isdir(dir): + counts = sorted([ int(d[:-8]) for d in os.listdir(dir) if d[-8:] == "_gpt.pth" ]) + names = [ f'./{dir}/{d}_gpt.pth' for d in counts ] + override["autoregressive_model"] = names[-1] + else: + override["autoregressive_model"] = None + + # necessary to ensure the right model gets loaded for the latents + tts.load_autoregressive_model( override["autoregressive_model"] ) + + fetched = fetch_voice(voice) + + settings['voice_samples'] = fetched[0] + settings['conditioning_latents'] = fetched[1] + + for k in override: + if k not in settings: + continue + settings[k] = override[k] + + if hasattr(tts, 'autoregressive_model_path') and tts.autoregressive_model_path != settings["autoregressive_model"]: + tts.load_autoregressive_model( settings["autoregressive_model"] ) + + # clamp it down for the insane users who want this + # it would be wiser to enforce the sample size to the batch size, but this is what the user wants + sample_batch_size = args.sample_batch_size + if not sample_batch_size: + sample_batch_size = tts.autoregressive_batch_size + if num_autoregressive_samples < sample_batch_size: + settings['sample_batch_size'] = num_autoregressive_samples + + return settings + + settings = get_settings() + + if not delimiter: delimiter = "\n" elif delimiter == "\\n": delimiter = "\n" @@ -189,7 +240,6 @@ def generate( os.makedirs(outdir, exist_ok=True) audio_cache = {} - resample = None if tts.output_sample_rate != args.output_sample_rate: @@ -238,12 +288,28 @@ def generate( cut_text = f"[{prompt},] {cut_text}" elif emotion != "None": cut_text = f"[I am really {emotion.lower()},] {cut_text}" - + progress.msg_prefix = f'[{str(line+1)}/{str(len(texts))}]' print(f"{progress.msg_prefix} Generating line: {cut_text}") - start_time = time.time() - gen, additionals = tts.tts(cut_text, **settings ) + + # do setting editing + match = re.findall(r'^(\{.+\}) (.+?)$', cut_text) + if match and len(match) > 0: + match = match[0] + try: + override = json.loads(match[0]) + except Exception as e: + print(e) + raise Exception("Prompt settings editing requested, but received invalid JSON") + + cut_text = match[1].strip() + new_settings = get_settings( override ) + + gen, additionals = tts.tts(cut_text, **new_settings ) + else: + gen, additionals = tts.tts(cut_text, **settings ) + seed = additionals[0] run_time = time.time()-start_time print(f"Generating line took {run_time} seconds") @@ -327,19 +393,6 @@ def generate( 'model_hash': tts.autoregressive_model_hash if hasattr(tts, 'autoregressive_model_hash') else None, } - """ - # kludgy yucky codesmells - for name in audio_cache: - if 'output' not in audio_cache[name]: - continue - - #output_voices.append(f'{outdir}/{voice}_{name}.wav') - output_voices.append(name) - if not args.embed_output_metadata: - with open(f'{outdir}/{voice}_{name}.json', 'w', encoding="utf-8") as f: - f.write(json.dumps(info, indent='\t') ) - """ - if args.voice_fixer: if not voicefixer: progress(0, "Loading voicefix...") @@ -1057,8 +1110,8 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres files = sorted(files) previous_list = [] - parsed_list = [] if skip_existings and os.path.exists(f'{outdir}/train.txt'): + parsed_list = [] with open(f'{outdir}/train.txt', 'r', encoding="utf-8") as f: parsed_list = f.readlines() @@ -1103,20 +1156,13 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres line = f"{sliced_name}|{segment['text'].strip()}" transcription.append(line) with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f: - f.write(f'{line}\n') + f.write(f'\n{line}') do_gc() with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f: f.write(json.dumps(results, indent='\t')) - if len(parsed_list) > 0: - transcription = parsed_list + transcription - - joined = '\n'.join(transcription) - with open(f'{outdir}/train.txt', 'w', encoding="utf-8") as f: - f.write(joined) - unload_whisper() return f"Processed dataset to: {outdir}\n{joined}" @@ -1688,6 +1734,22 @@ def read_generate_settings(file, read_latents=True): latents, ) +def version_check_tts( min_version ): + global tts + if not tts: + raise Exception("TTS is not initialized") + + if not hasattr(tts, 'version'): + return False + + if min_version[0] > tts.version[0]: + return True + if min_version[1] > tts.version[1]: + return True + if min_version[2] >= tts.version[2]: + return True + return False + def load_tts( restart=False, model=None ): global args global tts