I finally figured out how to fix gr.Dropdown.change, so a lot of dumb UI decisions are fixed and now make sense

This commit is contained in:
mrq 2023-02-21 03:00:45 +00:00
parent 7d1936adad
commit bbc2d26289
2 changed files with 88 additions and 52 deletions

View File

@ -38,11 +38,13 @@ MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/370
args = None args = None
tts = None tts = None
tts_loading = False
webui = None webui = None
voicefixer = None voicefixer = None
whisper_model = None whisper_model = None
training_process = None training_process = None
def generate( def generate(
text, text,
delimiter, delimiter,
@ -70,15 +72,17 @@ def generate(
global args global args
global tts global tts
if not tts:
# should check if it's loading or unloaded, and load it if it's unloaded
raise Exception("TTS is uninitialized or still initializing...")
do_gc()
unload_whisper() unload_whisper()
unload_voicefixer() unload_voicefixer()
if not tts:
# should check if it's loading or unloaded, and load it if it's unloaded
if tts_loading:
raise Exception("TTS is still initializing...")
load_tts()
do_gc()
if voice != "microphone": if voice != "microphone":
voices = [voice] voices = [voice]
else: else:
@ -365,13 +369,15 @@ def cancel_generate():
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)): def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
global tts global tts
global args global args
if not tts:
raise Exception("TTS is uninitialized or still initializing...")
unload_whisper() unload_whisper()
unload_voicefixer() unload_voicefixer()
if not tts:
if tts_loading:
raise Exception("TTS is still initializing...")
load_tts()
voice_samples, conditioning_latents = load_voice(voice, load_latents=False) voice_samples, conditioning_latents = load_voice(voice, load_latents=False)
if voice_samples is None: if voice_samples is None:
@ -953,10 +959,13 @@ def reset_generation_settings():
f.write(json.dumps({}, indent='\t') ) f.write(json.dumps({}, indent='\t') )
return import_generate_settings() return import_generate_settings()
def read_generate_settings(file, read_latents=True, read_json=True): def read_generate_settings(file, read_latents=True):
j = None j = None
latents = None latents = None
if isinstance(file, list) and len(file) == 1:
file = file[0]
if file is not None: if file is not None:
if hasattr(file, 'name'): if hasattr(file, 'name'):
file = file.name file = file.name
@ -981,24 +990,33 @@ def read_generate_settings(file, read_latents=True, read_json=True):
if "time" in j: if "time" in j:
j["time"] = "{:.3f}".format(j["time"]) j["time"] = "{:.3f}".format(j["time"])
return ( return (
j, j,
latents, latents,
) )
def load_tts(restart=False): def load_tts( restart=False, model=None ):
global args global args
global tts global tts
if restart: if restart:
unload_tts() unload_tts()
if model:
args.autoregressive_model = model
print(f"Loading TorToiSe... (using model: {args.autoregressive_model})") print(f"Loading TorToiSe... (using model: {args.autoregressive_model})")
tts_loading = True
try: try:
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model) tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model)
except Exception as e: except Exception as e:
tts = TextToSpeech(minor_optimizations=not args.low_vram) tts = TextToSpeech(minor_optimizations=not args.low_vram)
load_autoregressive_model(args.autoregressive_model) load_autoregressive_model(args.autoregressive_model)
tts_loading = False
get_model_path('dvae.pth') get_model_path('dvae.pth')
print("Loaded TorToiSe, ready for generation.") print("Loaded TorToiSe, ready for generation.")
@ -1015,17 +1033,24 @@ def unload_tts():
tts = None tts = None
do_gc() do_gc()
def reload_tts(): def reload_tts( model=None ):
setup_tortoise(restart=True) load_tts( restart=True, model=model )
def update_autoregressive_model(autoregressive_model_path): def update_autoregressive_model(autoregressive_model_path):
if not autoregressive_model_path or not os.path.exists(autoregressive_model_path):
return
args.autoregressive_model = autoregressive_model_path args.autoregressive_model = autoregressive_model_path
save_args_settings() save_args_settings()
print(f'Stored autoregressive model to settings: {autoregressive_model_path}') print(f'Stored autoregressive model to settings: {autoregressive_model_path}')
global tts global tts
if not tts: if not tts:
raise Exception("TTS is uninitialized or still initializing...") if tts_loading:
raise Exception("TTS is still initializing...")
load_tts( model=autoregressive_model_path )
return # redundant to proceed onward
print(f"Loading model: {autoregressive_model_path}") print(f"Loading model: {autoregressive_model_path}")
@ -1099,6 +1124,9 @@ def unload_whisper():
do_gc() do_gc()
def update_whisper_model(name, progress=None): def update_whisper_model(name, progress=None):
if not name:
return
global whisper_model global whisper_model
if whisper_model: if whisper_model:
unload_whisper() unload_whisper()

View File

@ -82,7 +82,6 @@ def run_generation(
outputs[0], outputs[0],
gr.update(value=sample, visible=sample is not None), gr.update(value=sample, visible=sample is not None),
gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True), gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True),
gr.update(visible=len(outputs) > 1),
gr.update(value=stats, visible=True), gr.update(value=stats, visible=True),
) )
@ -170,6 +169,8 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
latents = f'{outdir}/cond_latents.pth' latents = f'{outdir}/cond_latents.pth'
print(j, latents)
return ( return (
gr.update(value=j, visible=j is not None), gr.update(value=j, visible=j is not None),
gr.update(visible=j is not None), gr.update(visible=j is not None),
@ -244,12 +245,6 @@ def update_voices():
def history_copy_settings( voice, file ): def history_copy_settings( voice, file ):
return import_generate_settings( f"./results/{voice}/{file}" ) return import_generate_settings( f"./results/{voice}/{file}" )
def update_model_settings( autoregressive_model, whisper_model ):
update_autoregressive_model(autoregressive_model)
update_whisper_model(whisper_model)
save_args_settings()
def setup_gradio(): def setup_gradio():
global args global args
global ui global ui
@ -318,19 +313,30 @@ def setup_gradio():
generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"], visible=False) generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"], visible=False)
source_sample = gr.Audio(label="Source Sample", visible=False) source_sample = gr.Audio(label="Source Sample", visible=False)
output_audio = gr.Audio(label="Output") output_audio = gr.Audio(label="Output")
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False) candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False, choices=[""], value="")
output_pick = gr.Button(value="Select Candidate", visible=False) # output_pick = gr.Button(value="Select Candidate", visible=False)
def change_candidate( val ):
if not val:
return
print(val)
return val
candidates_list.change(
fn=change_candidate,
inputs=candidates_list,
outputs=output_audio,
)
with gr.Tab("History"): with gr.Tab("History"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys())) history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
history_voices = gr.Dropdown(choices=get_voice_list("./results/"), label="Voice", type="value") result_voices = get_voice_list("./results/")
history_view_results_button = gr.Button(value="View Files") history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0])
with gr.Column(): with gr.Column():
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True) history_results_list = gr.Dropdown(label="Results",type="value", interactive=True, value="")
history_view_result_button = gr.Button(value="View File")
with gr.Column(): with gr.Column():
history_audio = gr.Audio() history_audio = gr.Audio()
history_copy_settings_button = gr.Button(value="Copy Settings") history_copy_settings_button = gr.Button(value="Copy Settings")
@ -407,10 +413,10 @@ def setup_gradio():
gr.Checkbox(label="Low VRAM", value=args.low_vram), gr.Checkbox(label="Low VRAM", value=args.low_vram),
gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata), gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata),
gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean), gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean),
gr.Checkbox(label="Voice Fixer", value=args.voice_fixer), gr.Checkbox(label="Use Voice Fixer on Generated Output", value=args.voice_fixer),
gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda), gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda),
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents), gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load), gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load),
gr.Textbox(label="Device Override", value=args.device_override), gr.Textbox(label="Device Override", value=args.device_override),
] ]
with gr.Column(): with gr.Column():
@ -421,12 +427,28 @@ def setup_gradio():
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume), gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
] ]
autoregressive_model_dropdown = gr.Dropdown(get_autoregressive_models(), label="Autoregressive Model", value=args.autoregressive_model) autoregressive_models = get_autoregressive_models()
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model) whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
save_settings_button = gr.Button(value="Save Settings")
autoregressive_model_dropdown.change(
fn=update_autoregressive_model,
inputs=autoregressive_model_dropdown,
outputs=None
)
whisper_model_dropdown.change(
fn=update_whisper_model,
inputs=whisper_model_dropdown,
outputs=None
)
gr.Button(value="Check for Updates").click(check_for_updates) gr.Button(value="Check for Updates").click(check_for_updates)
gr.Button(value="(Re)Load TTS").click(reload_tts) gr.Button(value="(Re)Load TTS").click(
reload_tts,
inputs=autoregressive_model_dropdown,
outputs=None
)
for i in exec_inputs: for i in exec_inputs:
i.change( fn=update_args, inputs=exec_inputs ) i.change( fn=update_args, inputs=exec_inputs )
@ -457,7 +479,7 @@ def setup_gradio():
experimental_checkboxes, experimental_checkboxes,
] ]
history_view_results_button.click( history_voices.change(
fn=history_view_results, fn=history_view_results,
inputs=history_voices, inputs=history_voices,
outputs=[ outputs=[
@ -465,7 +487,7 @@ def setup_gradio():
history_results_list, history_results_list,
] ]
) )
history_view_result_button.click( history_results_list.change(
fn=lambda voice, file: f"./results/{voice}/{file}", fn=lambda voice, file: f"./results/{voice}/{file}",
inputs=[ inputs=[
history_voices, history_voices,
@ -531,20 +553,14 @@ def setup_gradio():
] ]
) )
output_pick.click(
lambda x: x,
inputs=candidates_list,
outputs=output_audio,
)
submit.click( submit.click(
lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)), lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)),
outputs=[source_sample, candidates_list, output_pick, generation_results], outputs=[source_sample, candidates_list, generation_results],
) )
submit_event = submit.click(run_generation, submit_event = submit.click(run_generation,
inputs=input_settings, inputs=input_settings,
outputs=[output_audio, source_sample, candidates_list, output_pick, generation_results], outputs=[output_audio, source_sample, candidates_list, generation_results],
) )
@ -603,14 +619,6 @@ def setup_gradio():
outputs=save_yaml_output #console_output outputs=save_yaml_output #console_output
) )
save_settings_button.click(update_model_settings,
inputs=[
autoregressive_model_dropdown,
whisper_model_dropdown,
],
outputs=None
)
if os.path.isfile('./config/generate.json'): if os.path.isfile('./config/generate.json'):
ui.load(import_generate_settings, inputs=None, outputs=input_settings) ui.load(import_generate_settings, inputs=None, outputs=input_settings)