From bbc2d26289871a924767eabd07e6ae4c88be4362 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 21 Feb 2023 03:00:45 +0000 Subject: [PATCH] I finally figured out how to fix gr.Dropdown.change, so a lot of dumb UI decisions are fixed and makes sense --- src/utils.py | 54 +++++++++++++++++++++++++--------- src/webui.py | 82 ++++++++++++++++++++++++++++------------------------ 2 files changed, 86 insertions(+), 50 deletions(-) diff --git a/src/utils.py b/src/utils.py index edb6487..bfaae25 100755 --- a/src/utils.py +++ b/src/utils.py @@ -38,11 +38,13 @@ MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/370 args = None tts = None +tts_loading = False webui = None voicefixer = None whisper_model = None training_process = None + def generate( text, delimiter, @@ -70,15 +72,17 @@ def generate( global args global tts + unload_whisper() + unload_voicefixer() + if not tts: # should check if it's loading or unloaded, and load it if it's unloaded - raise Exception("TTS is uninitialized or still initializing...") + if tts_loading: + raise Exception("TTS is still initializing...") + load_tts() do_gc() - unload_whisper() - unload_voicefixer() - if voice != "microphone": voices = [voice] else: @@ -365,13 +369,15 @@ def cancel_generate(): def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): global tts global args - - if not tts: - raise Exception("TTS is uninitialized or still initializing...") - + unload_whisper() unload_voicefixer() + if not tts: + if tts_loading: + raise Exception("TTS is still initializing...") + load_tts() + voice_samples, conditioning_latents = load_voice(voice, load_latents=False) if voice_samples is None: @@ -953,10 +959,13 @@ def reset_generation_settings(): f.write(json.dumps({}, indent='\t') ) return import_generate_settings() -def read_generate_settings(file, read_latents=True, read_json=True): +def read_generate_settings(file, read_latents=True): j = None latents = None + if isinstance(file, list) and len(file) == 1: + file = file[0] + if file is not None: if hasattr(file, 'name'): file = file.name @@ -981,24 +990,33 @@ def read_generate_settings(file, read_latents=True, read_json=True): if "time" in j: j["time"] = "{:.3f}".format(j["time"]) + + return ( j, latents, ) -def load_tts(restart=False): +def load_tts( restart=False, model=None ): global args global tts if restart: unload_tts() + + if model: + args.autoregressive_model = model + print(f"Loading TorToiSe... (using model: {args.autoregressive_model})") + + tts_loading = True try: tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model) except Exception as e: tts = TextToSpeech(minor_optimizations=not args.low_vram) load_autoregressive_model(args.autoregressive_model) + tts_loading = False get_model_path('dvae.pth') print("Loaded TorToiSe, ready for generation.") @@ -1015,17 +1033,24 @@ def unload_tts(): tts = None do_gc() -def reload_tts(): - setup_tortoise(restart=True) +def reload_tts( model=None ): + load_tts( restart=True, model=model ) def update_autoregressive_model(autoregressive_model_path): + if not autoregressive_model_path or not os.path.exists(autoregressive_model_path): + return + args.autoregressive_model = autoregressive_model_path save_args_settings() print(f'Stored autoregressive model to settings: {autoregressive_model_path}') global tts if not tts: - raise Exception("TTS is uninitialized or still initializing...") + if tts_loading: + raise Exception("TTS is still initializing...") + + load_tts( model=autoregressive_model_path ) + return # redundant to proceed onward print(f"Loading model: {autoregressive_model_path}") @@ -1099,6 +1124,9 @@ def unload_whisper(): do_gc() def update_whisper_model(name, progress=None): + if not name: + return + global whisper_model if whisper_model: unload_whisper() diff --git a/src/webui.py b/src/webui.py index 83aff5f..49c6210 100755 --- a/src/webui.py +++ b/src/webui.py @@ -82,7 +82,6 @@ def run_generation( outputs[0], gr.update(value=sample, visible=sample is not None), gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True), - gr.update(visible=len(outputs) > 1), gr.update(value=stats, visible=True), ) @@ -170,6 +169,8 @@ def read_generate_settings_proxy(file, saveAs='.temp'): latents = f'{outdir}/cond_latents.pth' + print(j, latents) + return ( gr.update(value=j, visible=j is not None), gr.update(visible=j is not None), @@ -244,12 +245,6 @@ def update_voices(): def history_copy_settings( voice, file ): return import_generate_settings( f"./results/{voice}/{file}" ) -def update_model_settings( autoregressive_model, whisper_model ): - update_autoregressive_model(autoregressive_model) - update_whisper_model(whisper_model) - - save_args_settings() - def setup_gradio(): global args global ui @@ -318,19 +313,30 @@ def setup_gradio(): generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"], visible=False) source_sample = gr.Audio(label="Source Sample", visible=False) output_audio = gr.Audio(label="Output") - candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False) - output_pick = gr.Button(value="Select Candidate", visible=False) + candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False, choices=[""], value="") + # output_pick = gr.Button(value="Select Candidate", visible=False) + + def change_candidate( val ): + if not val: + return + print(val) + return val + + candidates_list.change( + fn=change_candidate, + inputs=candidates_list, + outputs=output_audio, + ) with gr.Tab("History"): with gr.Row(): with gr.Column(): history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys())) with gr.Row(): with gr.Column(): - history_voices = gr.Dropdown(choices=get_voice_list("./results/"), label="Voice", type="value") - history_view_results_button = gr.Button(value="View Files") + result_voices = get_voice_list("./results/") + history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0]) with gr.Column(): - history_results_list = gr.Dropdown(label="Results",type="value", interactive=True) - history_view_result_button = gr.Button(value="View File") + history_results_list = gr.Dropdown(label="Results",type="value", interactive=True, value="") with gr.Column(): history_audio = gr.Audio() history_copy_settings_button = gr.Button(value="Copy Settings") @@ -407,10 +413,10 @@ def setup_gradio(): gr.Checkbox(label="Low VRAM", value=args.low_vram), gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata), gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean), - gr.Checkbox(label="Voice Fixer", value=args.voice_fixer), + gr.Checkbox(label="Use Voice Fixer on Generated Output", value=args.voice_fixer), gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda), gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), - gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load), + gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load), gr.Textbox(label="Device Override", value=args.device_override), ] with gr.Column(): @@ -421,12 +427,28 @@ def setup_gradio(): gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume), ] - autoregressive_model_dropdown = gr.Dropdown(get_autoregressive_models(), label="Autoregressive Model", value=args.autoregressive_model) + autoregressive_models = get_autoregressive_models() + autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0]) whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model) - save_settings_button = gr.Button(value="Save Settings") + + + autoregressive_model_dropdown.change( + fn=update_autoregressive_model, + inputs=autoregressive_model_dropdown, + outputs=None + ) + whisper_model_dropdown.change( + fn=update_whisper_model, + inputs=whisper_model_dropdown, + outputs=None + ) gr.Button(value="Check for Updates").click(check_for_updates) - gr.Button(value="(Re)Load TTS").click(reload_tts) + gr.Button(value="(Re)Load TTS").click( + reload_tts, + inputs=autoregressive_model_dropdown, + outputs=None + ) for i in exec_inputs: i.change( fn=update_args, inputs=exec_inputs ) @@ -457,7 +479,7 @@ def setup_gradio(): experimental_checkboxes, ] - history_view_results_button.click( + history_voices.change( fn=history_view_results, inputs=history_voices, outputs=[ @@ -465,7 +487,7 @@ def setup_gradio(): history_results_list, ] ) - history_view_result_button.click( + history_results_list.change( fn=lambda voice, file: f"./results/{voice}/{file}", inputs=[ history_voices, @@ -531,20 +553,14 @@ def setup_gradio(): ] ) - output_pick.click( - lambda x: x, - inputs=candidates_list, - outputs=output_audio, - ) - submit.click( - lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)), - outputs=[source_sample, candidates_list, output_pick, generation_results], + lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)), + outputs=[source_sample, candidates_list, generation_results], ) submit_event = submit.click(run_generation, inputs=input_settings, - outputs=[output_audio, source_sample, candidates_list, output_pick, generation_results], + outputs=[output_audio, source_sample, candidates_list, generation_results], ) @@ -603,14 +619,6 @@ def setup_gradio(): outputs=save_yaml_output #console_output ) - save_settings_button.click(update_model_settings, - inputs=[ - autoregressive_model_dropdown, - whisper_model_dropdown, - ], - outputs=None - ) - if os.path.isfile('./config/generate.json'): ui.load(import_generate_settings, inputs=None, outputs=input_settings)