1
1
forked from mrq/tortoise-tts

Added candidate selection for outputs; hid the output elements (except for the main one) so that only one progress bar is shown

This commit is contained in:
mrq 2023-02-11 16:34:47 +00:00
parent a7330164ab
commit 6d06bcce05

118
webui.py
View File

@ -43,7 +43,7 @@ def generate(
repetition_penalty, repetition_penalty,
cond_free_k, cond_free_k,
experimental_checkboxes, experimental_checkboxes,
progress=gr.Progress(track_tqdm=True) progress=None
): ):
global args global args
global tts global tts
@ -150,6 +150,10 @@ def generate(
if idx: if idx:
idx = idx + 1 idx = idx + 1
# reserve this index by writing a placeholder file, in case you somehow manage to generate concurrently
with open(f'{outdir}/input_{idx}.json', 'w', encoding="utf-8") as f:
f.write(" ")
def get_name(line=0, candidate=0, combined=False): def get_name(line=0, candidate=0, combined=False):
name = f"{idx}" name = f"{idx}"
if combined: if combined:
@ -204,8 +208,8 @@ def generate(
output_voice = None output_voice = None
output_voices = [] output_voices = []
if len(texts) > 1: for candidate in range(candidates):
for candidate in range(candidates): if len(texts) > 1:
audio_clips = [] audio_clips = []
for line in range(len(texts)): for line in range(len(texts)):
name = get_name(line=line, candidate=candidate) name = get_name(line=line, candidate=candidate)
@ -226,8 +230,7 @@ def generate(
output_voices.append(f'{outdir}/{voice}_{name}.wav') output_voices.append(f'{outdir}/{voice}_{name}.wav')
if output_voice is None: if output_voice is None:
output_voice = f'{outdir}/{voice}_{name}.wav' output_voice = f'{outdir}/{voice}_{name}.wav'
else: else:
for candidate in range(candidates):
name = get_name(candidate=candidate) name = get_name(candidate=candidate)
output_voices.append(f'{outdir}/{voice}_{name}.wav') output_voices.append(f'{outdir}/{voice}_{name}.wav')
@ -266,14 +269,12 @@ def generate(
# we could do this on the pieces before they get stitched up anyways to save some compute # we could do this on the pieces before they get stitched up anyways to save some compute
# but the stitching would need to read back from disk, defeating the point of caching the waveform # but the stitching would need to read back from disk, defeating the point of caching the waveform
for path in progress.tqdm(audio_cache, desc="Running voicefix..."): for path in progress.tqdm(audio_cache, desc="Running voicefix..."):
print("VoiceFix starting")
voicefixer.restore( voicefixer.restore(
input=f'{outdir}/{voice}_{k}.wav', input=f'{outdir}/{voice}_{k}.wav',
output=f'{outdir}/{voice}_{k}.wav', output=f'{outdir}/{voice}_{k}.wav',
#cuda=False, #cuda=False,
#mode=mode, #mode=mode,
) )
print("VoiceFix finished")
if args.embed_output_metadata: if args.embed_output_metadata:
for path in progress.tqdm(audio_cache, desc="Embedding metadata..."): for path in progress.tqdm(audio_cache, desc="Embedding metadata..."):
@ -289,25 +290,22 @@ def generate(
if sample_voice is not None: if sample_voice is not None:
sample_voice = (tts.input_sample_rate, sample_voice.numpy()) sample_voice = (tts.input_sample_rate, sample_voice.numpy())
if output_voice is None and len(output_voices):
output_voice = output_voices[0]
print(f"Generation took {info['time']} seconds, saved to '{output_voice}'\n") print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
info['seed'] = settings['use_deterministic_seed'] info['seed'] = settings['use_deterministic_seed']
del info['latents'] del info['latents']
with open(f'./config/generate.json', 'w', encoding="utf-8") as f: with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(info, indent='\t') ) f.write(json.dumps(info, indent='\t') )
results = [ stats = [
[ seed, "{:.3f}".format(info['time']) ] [ seed, "{:.3f}".format(info['time']) ]
] ]
return ( return (
sample_voice, sample_voice,
output_voice, output_voices,
results, stats,
) )
def update_presets(value): def update_presets(value):
@ -438,6 +436,7 @@ def check_for_updates():
return False return False
def reload_tts(): def reload_tts():
del tts
tts = setup_tortoise() tts = setup_tortoise()
def cancel_generate(): def cancel_generate():
@ -577,6 +576,8 @@ def setup_gradio():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
text = gr.Textbox(lines=4, label="Prompt") text = gr.Textbox(lines=4, label="Prompt")
with gr.Row():
with gr.Column():
delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n") delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
emotion = gr.Radio( emotion = gr.Radio(
@ -640,12 +641,15 @@ def setup_gradio():
], ],
) )
with gr.Column(): with gr.Column():
selected_voice = gr.Audio(label="Source Sample")
output_audio = gr.Audio(label="Output")
generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"])
submit = gr.Button(value="Generate") submit = gr.Button(value="Generate")
stop = gr.Button(value="Stop") stop = gr.Button(value="Stop")
generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"], visible=False)
source_sample = gr.Audio(label="Source Sample", visible=False)
output_audio = gr.Audio(label="Output")
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False)
output_pick = gr.Button(value="Select Candidate", visible=False)
with gr.Tab("History"): with gr.Tab("History"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -810,11 +814,83 @@ def setup_gradio():
experimental_checkboxes, experimental_checkboxes,
] ]
submit_event = submit.click(generate, # YUCK
inputs=input_settings, def run_generation(
outputs=[selected_voice, output_audio, generation_results], text,
delimiter,
emotion,
prompt,
voice,
mic_audio,
seed,
candidates,
num_autoregressive_samples,
diffusion_iterations,
temperature,
diffusion_sampler,
breathing_room,
cvvp_weight,
top_p,
diffusion_temperature,
length_penalty,
repetition_penalty,
cond_free_k,
experimental_checkboxes,
progress=gr.Progress(track_tqdm=True)
):
try:
sample, outputs, stats = generate(
text,
delimiter,
emotion,
prompt,
voice,
mic_audio,
seed,
candidates,
num_autoregressive_samples,
diffusion_iterations,
temperature,
diffusion_sampler,
breathing_room,
cvvp_weight,
top_p,
diffusion_temperature,
length_penalty,
repetition_penalty,
cond_free_k,
experimental_checkboxes,
progress
)
except Exception as e:
raise gr.Error(e)
return (
outputs[0],
gr.update(value=sample, visible=sample is not None),
gr.update(choices=outputs, visible=len(outputs) > 1, interactive=True),
gr.update(visible=len(outputs) > 1),
gr.update(value=stats, visible=True),
)
output_pick.click(
lambda x: x,
inputs=candidates_list,
outputs=output_audio,
) )
submit.click(
lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)),
outputs=[source_sample, candidates_list, output_pick, generation_results],
)
submit_event = submit.click(run_generation,
inputs=input_settings,
outputs=[output_audio, source_sample, candidates_list, output_pick, generation_results],
)
copy_button.click(import_generate_settings, copy_button.click(import_generate_settings,
inputs=audio_in, # JSON elements cannot be used as inputs inputs=audio_in, # JSON elements cannot be used as inputs
outputs=input_settings outputs=input_settings