forked from mrq/ai-voice-cloning
I finally figured out how to fix gr.Dropdown.change, so a lot of dumb UI decisions are fixed and make sense
This commit is contained in:
parent
7d1936adad
commit
bbc2d26289
56
src/utils.py
56
src/utils.py
|
@ -38,11 +38,13 @@ MODELS['dvae.pth'] = "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/370
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
tts = None
|
tts = None
|
||||||
|
tts_loading = False
|
||||||
webui = None
|
webui = None
|
||||||
voicefixer = None
|
voicefixer = None
|
||||||
whisper_model = None
|
whisper_model = None
|
||||||
training_process = None
|
training_process = None
|
||||||
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
text,
|
text,
|
||||||
delimiter,
|
delimiter,
|
||||||
|
@ -70,15 +72,17 @@ def generate(
|
||||||
global args
|
global args
|
||||||
global tts
|
global tts
|
||||||
|
|
||||||
if not tts:
|
|
||||||
# should check if it's loading or unloaded, and load it if it's unloaded
|
|
||||||
raise Exception("TTS is uninitialized or still initializing...")
|
|
||||||
|
|
||||||
do_gc()
|
|
||||||
|
|
||||||
unload_whisper()
|
unload_whisper()
|
||||||
unload_voicefixer()
|
unload_voicefixer()
|
||||||
|
|
||||||
|
if not tts:
|
||||||
|
# should check if it's loading or unloaded, and load it if it's unloaded
|
||||||
|
if tts_loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
load_tts()
|
||||||
|
|
||||||
|
do_gc()
|
||||||
|
|
||||||
if voice != "microphone":
|
if voice != "microphone":
|
||||||
voices = [voice]
|
voices = [voice]
|
||||||
else:
|
else:
|
||||||
|
@ -366,12 +370,14 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
|
||||||
global tts
|
global tts
|
||||||
global args
|
global args
|
||||||
|
|
||||||
if not tts:
|
|
||||||
raise Exception("TTS is uninitialized or still initializing...")
|
|
||||||
|
|
||||||
unload_whisper()
|
unload_whisper()
|
||||||
unload_voicefixer()
|
unload_voicefixer()
|
||||||
|
|
||||||
|
if not tts:
|
||||||
|
if tts_loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
load_tts()
|
||||||
|
|
||||||
voice_samples, conditioning_latents = load_voice(voice, load_latents=False)
|
voice_samples, conditioning_latents = load_voice(voice, load_latents=False)
|
||||||
|
|
||||||
if voice_samples is None:
|
if voice_samples is None:
|
||||||
|
@ -953,10 +959,13 @@ def reset_generation_settings():
|
||||||
f.write(json.dumps({}, indent='\t') )
|
f.write(json.dumps({}, indent='\t') )
|
||||||
return import_generate_settings()
|
return import_generate_settings()
|
||||||
|
|
||||||
def read_generate_settings(file, read_latents=True, read_json=True):
|
def read_generate_settings(file, read_latents=True):
|
||||||
j = None
|
j = None
|
||||||
latents = None
|
latents = None
|
||||||
|
|
||||||
|
if isinstance(file, list) and len(file) == 1:
|
||||||
|
file = file[0]
|
||||||
|
|
||||||
if file is not None:
|
if file is not None:
|
||||||
if hasattr(file, 'name'):
|
if hasattr(file, 'name'):
|
||||||
file = file.name
|
file = file.name
|
||||||
|
@ -981,24 +990,33 @@ def read_generate_settings(file, read_latents=True, read_json=True):
|
||||||
if "time" in j:
|
if "time" in j:
|
||||||
j["time"] = "{:.3f}".format(j["time"])
|
j["time"] = "{:.3f}".format(j["time"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
j,
|
j,
|
||||||
latents,
|
latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_tts(restart=False):
|
def load_tts( restart=False, model=None ):
|
||||||
global args
|
global args
|
||||||
global tts
|
global tts
|
||||||
|
|
||||||
if restart:
|
if restart:
|
||||||
unload_tts()
|
unload_tts()
|
||||||
|
|
||||||
|
|
||||||
|
if model:
|
||||||
|
args.autoregressive_model = model
|
||||||
|
|
||||||
print(f"Loading TorToiSe... (using model: {args.autoregressive_model})")
|
print(f"Loading TorToiSe... (using model: {args.autoregressive_model})")
|
||||||
|
|
||||||
|
tts_loading = True
|
||||||
try:
|
try:
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
tts = TextToSpeech(minor_optimizations=not args.low_vram)
|
||||||
load_autoregressive_model(args.autoregressive_model)
|
load_autoregressive_model(args.autoregressive_model)
|
||||||
|
tts_loading = False
|
||||||
|
|
||||||
get_model_path('dvae.pth')
|
get_model_path('dvae.pth')
|
||||||
print("Loaded TorToiSe, ready for generation.")
|
print("Loaded TorToiSe, ready for generation.")
|
||||||
|
@ -1015,17 +1033,24 @@ def unload_tts():
|
||||||
tts = None
|
tts = None
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
def reload_tts():
|
def reload_tts( model=None ):
|
||||||
setup_tortoise(restart=True)
|
load_tts( restart=True, model=model )
|
||||||
|
|
||||||
def update_autoregressive_model(autoregressive_model_path):
|
def update_autoregressive_model(autoregressive_model_path):
|
||||||
|
if not autoregressive_model_path or not os.path.exists(autoregressive_model_path):
|
||||||
|
return
|
||||||
|
|
||||||
args.autoregressive_model = autoregressive_model_path
|
args.autoregressive_model = autoregressive_model_path
|
||||||
save_args_settings()
|
save_args_settings()
|
||||||
print(f'Stored autoregressive model to settings: {autoregressive_model_path}')
|
print(f'Stored autoregressive model to settings: {autoregressive_model_path}')
|
||||||
|
|
||||||
global tts
|
global tts
|
||||||
if not tts:
|
if not tts:
|
||||||
raise Exception("TTS is uninitialized or still initializing...")
|
if tts_loading:
|
||||||
|
raise Exception("TTS is still initializing...")
|
||||||
|
|
||||||
|
load_tts( model=autoregressive_model_path )
|
||||||
|
return # redundant to proceed onward
|
||||||
|
|
||||||
print(f"Loading model: {autoregressive_model_path}")
|
print(f"Loading model: {autoregressive_model_path}")
|
||||||
|
|
||||||
|
@ -1099,6 +1124,9 @@ def unload_whisper():
|
||||||
do_gc()
|
do_gc()
|
||||||
|
|
||||||
def update_whisper_model(name, progress=None):
|
def update_whisper_model(name, progress=None):
|
||||||
|
if not name:
|
||||||
|
return
|
||||||
|
|
||||||
global whisper_model
|
global whisper_model
|
||||||
if whisper_model:
|
if whisper_model:
|
||||||
unload_whisper()
|
unload_whisper()
|
||||||
|
|
82
src/webui.py
82
src/webui.py
|
@ -82,7 +82,6 @@ def run_generation(
|
||||||
outputs[0],
|
outputs[0],
|
||||||
gr.update(value=sample, visible=sample is not None),
|
gr.update(value=sample, visible=sample is not None),
|
||||||
gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True),
|
gr.update(choices=outputs, value=outputs[0], visible=len(outputs) > 1, interactive=True),
|
||||||
gr.update(visible=len(outputs) > 1),
|
|
||||||
gr.update(value=stats, visible=True),
|
gr.update(value=stats, visible=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -170,6 +169,8 @@ def read_generate_settings_proxy(file, saveAs='.temp'):
|
||||||
|
|
||||||
latents = f'{outdir}/cond_latents.pth'
|
latents = f'{outdir}/cond_latents.pth'
|
||||||
|
|
||||||
|
print(j, latents)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
gr.update(value=j, visible=j is not None),
|
gr.update(value=j, visible=j is not None),
|
||||||
gr.update(visible=j is not None),
|
gr.update(visible=j is not None),
|
||||||
|
@ -244,12 +245,6 @@ def update_voices():
|
||||||
def history_copy_settings( voice, file ):
|
def history_copy_settings( voice, file ):
|
||||||
return import_generate_settings( f"./results/{voice}/{file}" )
|
return import_generate_settings( f"./results/{voice}/{file}" )
|
||||||
|
|
||||||
def update_model_settings( autoregressive_model, whisper_model ):
|
|
||||||
update_autoregressive_model(autoregressive_model)
|
|
||||||
update_whisper_model(whisper_model)
|
|
||||||
|
|
||||||
save_args_settings()
|
|
||||||
|
|
||||||
def setup_gradio():
|
def setup_gradio():
|
||||||
global args
|
global args
|
||||||
global ui
|
global ui
|
||||||
|
@ -318,19 +313,30 @@ def setup_gradio():
|
||||||
generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"], visible=False)
|
generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"], visible=False)
|
||||||
source_sample = gr.Audio(label="Source Sample", visible=False)
|
source_sample = gr.Audio(label="Source Sample", visible=False)
|
||||||
output_audio = gr.Audio(label="Output")
|
output_audio = gr.Audio(label="Output")
|
||||||
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False)
|
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False, choices=[""], value="")
|
||||||
output_pick = gr.Button(value="Select Candidate", visible=False)
|
# output_pick = gr.Button(value="Select Candidate", visible=False)
|
||||||
|
|
||||||
|
def change_candidate( val ):
|
||||||
|
if not val:
|
||||||
|
return
|
||||||
|
print(val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
candidates_list.change(
|
||||||
|
fn=change_candidate,
|
||||||
|
inputs=candidates_list,
|
||||||
|
outputs=output_audio,
|
||||||
|
)
|
||||||
with gr.Tab("History"):
|
with gr.Tab("History"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
|
history_info = gr.Dataframe(label="Results", headers=list(history_headers.keys()))
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
history_voices = gr.Dropdown(choices=get_voice_list("./results/"), label="Voice", type="value")
|
result_voices = get_voice_list("./results/")
|
||||||
history_view_results_button = gr.Button(value="View Files")
|
history_voices = gr.Dropdown(choices=result_voices, label="Voice", type="value", value=result_voices[0])
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True)
|
history_results_list = gr.Dropdown(label="Results",type="value", interactive=True, value="")
|
||||||
history_view_result_button = gr.Button(value="View File")
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
history_audio = gr.Audio()
|
history_audio = gr.Audio()
|
||||||
history_copy_settings_button = gr.Button(value="Copy Settings")
|
history_copy_settings_button = gr.Button(value="Copy Settings")
|
||||||
|
@ -407,10 +413,10 @@ def setup_gradio():
|
||||||
gr.Checkbox(label="Low VRAM", value=args.low_vram),
|
gr.Checkbox(label="Low VRAM", value=args.low_vram),
|
||||||
gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata),
|
gr.Checkbox(label="Embed Output Metadata", value=args.embed_output_metadata),
|
||||||
gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean),
|
gr.Checkbox(label="Slimmer Computed Latents", value=args.latents_lean_and_mean),
|
||||||
gr.Checkbox(label="Voice Fixer", value=args.voice_fixer),
|
gr.Checkbox(label="Use Voice Fixer on Generated Output", value=args.voice_fixer),
|
||||||
gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda),
|
gr.Checkbox(label="Use CUDA for Voice Fixer", value=args.voice_fixer_use_cuda),
|
||||||
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
|
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
|
||||||
gr.Checkbox(label="Defer TTS Load", value=args.defer_tts_load),
|
gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load),
|
||||||
gr.Textbox(label="Device Override", value=args.device_override),
|
gr.Textbox(label="Device Override", value=args.device_override),
|
||||||
]
|
]
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
@ -421,12 +427,28 @@ def setup_gradio():
|
||||||
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
|
gr.Slider(label="Ouptut Volume", minimum=0, maximum=2, value=args.output_volume),
|
||||||
]
|
]
|
||||||
|
|
||||||
autoregressive_model_dropdown = gr.Dropdown(get_autoregressive_models(), label="Autoregressive Model", value=args.autoregressive_model)
|
autoregressive_models = get_autoregressive_models()
|
||||||
|
autoregressive_model_dropdown = gr.Dropdown(choices=autoregressive_models, label="Autoregressive Model", value=args.autoregressive_model if args.autoregressive_model else autoregressive_models[0])
|
||||||
whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
|
whisper_model_dropdown = gr.Dropdown(["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large"], label="Whisper Model", value=args.whisper_model)
|
||||||
save_settings_button = gr.Button(value="Save Settings")
|
|
||||||
|
|
||||||
|
autoregressive_model_dropdown.change(
|
||||||
|
fn=update_autoregressive_model,
|
||||||
|
inputs=autoregressive_model_dropdown,
|
||||||
|
outputs=None
|
||||||
|
)
|
||||||
|
whisper_model_dropdown.change(
|
||||||
|
fn=update_whisper_model,
|
||||||
|
inputs=whisper_model_dropdown,
|
||||||
|
outputs=None
|
||||||
|
)
|
||||||
|
|
||||||
gr.Button(value="Check for Updates").click(check_for_updates)
|
gr.Button(value="Check for Updates").click(check_for_updates)
|
||||||
gr.Button(value="(Re)Load TTS").click(reload_tts)
|
gr.Button(value="(Re)Load TTS").click(
|
||||||
|
reload_tts,
|
||||||
|
inputs=autoregressive_model_dropdown,
|
||||||
|
outputs=None
|
||||||
|
)
|
||||||
|
|
||||||
for i in exec_inputs:
|
for i in exec_inputs:
|
||||||
i.change( fn=update_args, inputs=exec_inputs )
|
i.change( fn=update_args, inputs=exec_inputs )
|
||||||
|
@ -457,7 +479,7 @@ def setup_gradio():
|
||||||
experimental_checkboxes,
|
experimental_checkboxes,
|
||||||
]
|
]
|
||||||
|
|
||||||
history_view_results_button.click(
|
history_voices.change(
|
||||||
fn=history_view_results,
|
fn=history_view_results,
|
||||||
inputs=history_voices,
|
inputs=history_voices,
|
||||||
outputs=[
|
outputs=[
|
||||||
|
@ -465,7 +487,7 @@ def setup_gradio():
|
||||||
history_results_list,
|
history_results_list,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
history_view_result_button.click(
|
history_results_list.change(
|
||||||
fn=lambda voice, file: f"./results/{voice}/{file}",
|
fn=lambda voice, file: f"./results/{voice}/{file}",
|
||||||
inputs=[
|
inputs=[
|
||||||
history_voices,
|
history_voices,
|
||||||
|
@ -531,20 +553,14 @@ def setup_gradio():
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
output_pick.click(
|
|
||||||
lambda x: x,
|
|
||||||
inputs=candidates_list,
|
|
||||||
outputs=output_audio,
|
|
||||||
)
|
|
||||||
|
|
||||||
submit.click(
|
submit.click(
|
||||||
lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)),
|
lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)),
|
||||||
outputs=[source_sample, candidates_list, output_pick, generation_results],
|
outputs=[source_sample, candidates_list, generation_results],
|
||||||
)
|
)
|
||||||
|
|
||||||
submit_event = submit.click(run_generation,
|
submit_event = submit.click(run_generation,
|
||||||
inputs=input_settings,
|
inputs=input_settings,
|
||||||
outputs=[output_audio, source_sample, candidates_list, output_pick, generation_results],
|
outputs=[output_audio, source_sample, candidates_list, generation_results],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -603,14 +619,6 @@ def setup_gradio():
|
||||||
outputs=save_yaml_output #console_output
|
outputs=save_yaml_output #console_output
|
||||||
)
|
)
|
||||||
|
|
||||||
save_settings_button.click(update_model_settings,
|
|
||||||
inputs=[
|
|
||||||
autoregressive_model_dropdown,
|
|
||||||
whisper_model_dropdown,
|
|
||||||
],
|
|
||||||
outputs=None
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.path.isfile('./config/generate.json'):
|
if os.path.isfile('./config/generate.json'):
|
||||||
ui.load(import_generate_settings, inputs=None, outputs=input_settings)
|
ui.load(import_generate_settings, inputs=None, outputs=input_settings)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user