forked from mrq/tortoise-tts
Added candidate selection for outputs, hide output elements (except for the main one) to only show one progress bar
This commit is contained in:
parent
a7330164ab
commit
6d06bcce05
118
webui.py
118
webui.py
|
@ -43,7 +43,7 @@ def generate(
|
||||||
repetition_penalty,
|
repetition_penalty,
|
||||||
cond_free_k,
|
cond_free_k,
|
||||||
experimental_checkboxes,
|
experimental_checkboxes,
|
||||||
progress=gr.Progress(track_tqdm=True)
|
progress=None
|
||||||
):
|
):
|
||||||
global args
|
global args
|
||||||
global tts
|
global tts
|
||||||
|
@ -150,6 +150,10 @@ def generate(
|
||||||
if idx:
|
if idx:
|
||||||
idx = idx + 1
|
idx = idx + 1
|
||||||
|
|
||||||
|
# reserve, if for whatever reason you manage to concurrently generate
|
||||||
|
with open(f'{outdir}/input_{idx}.json', 'w', encoding="utf-8") as f:
|
||||||
|
f.write(" ")
|
||||||
|
|
||||||
def get_name(line=0, candidate=0, combined=False):
|
def get_name(line=0, candidate=0, combined=False):
|
||||||
name = f"{idx}"
|
name = f"{idx}"
|
||||||
if combined:
|
if combined:
|
||||||
|
@ -204,8 +208,8 @@ def generate(
|
||||||
|
|
||||||
output_voice = None
|
output_voice = None
|
||||||
output_voices = []
|
output_voices = []
|
||||||
if len(texts) > 1:
|
for candidate in range(candidates):
|
||||||
for candidate in range(candidates):
|
if len(texts) > 1:
|
||||||
audio_clips = []
|
audio_clips = []
|
||||||
for line in range(len(texts)):
|
for line in range(len(texts)):
|
||||||
name = get_name(line=line, candidate=candidate)
|
name = get_name(line=line, candidate=candidate)
|
||||||
|
@ -226,8 +230,7 @@ def generate(
|
||||||
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
||||||
if output_voice is None:
|
if output_voice is None:
|
||||||
output_voice = f'{outdir}/{voice}_{name}.wav'
|
output_voice = f'{outdir}/{voice}_{name}.wav'
|
||||||
else:
|
else:
|
||||||
for candidate in range(candidates):
|
|
||||||
name = get_name(candidate=candidate)
|
name = get_name(candidate=candidate)
|
||||||
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
output_voices.append(f'{outdir}/{voice}_{name}.wav')
|
||||||
|
|
||||||
|
@ -266,14 +269,12 @@ def generate(
|
||||||
# we could do this on the pieces before they get stiched up anyways to save some compute
|
# we could do this on the pieces before they get stiched up anyways to save some compute
|
||||||
# but the stitching would need to read back from disk, defeating the point of caching the waveform
|
# but the stitching would need to read back from disk, defeating the point of caching the waveform
|
||||||
for path in progress.tqdm(audio_cache, desc="Running voicefix..."):
|
for path in progress.tqdm(audio_cache, desc="Running voicefix..."):
|
||||||
print("VoiceFix starting")
|
|
||||||
voicefixer.restore(
|
voicefixer.restore(
|
||||||
input=f'{outdir}/{voice}_{k}.wav',
|
input=f'{outdir}/{voice}_{k}.wav',
|
||||||
output=f'{outdir}/{voice}_{k}.wav',
|
output=f'{outdir}/{voice}_{k}.wav',
|
||||||
#cuda=False,
|
#cuda=False,
|
||||||
#mode=mode,
|
#mode=mode,
|
||||||
)
|
)
|
||||||
print("VoiceFix finished")
|
|
||||||
|
|
||||||
if args.embed_output_metadata:
|
if args.embed_output_metadata:
|
||||||
for path in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
for path in progress.tqdm(audio_cache, desc="Embedding metadata..."):
|
||||||
|
@ -290,24 +291,21 @@ def generate(
|
||||||
if sample_voice is not None:
|
if sample_voice is not None:
|
||||||
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
|
sample_voice = (tts.input_sample_rate, sample_voice.numpy())
|
||||||
|
|
||||||
if output_voice is None and len(output_voices):
|
print(f"Generation took {info['time']} seconds, saved to '{output_voices[0]}'\n")
|
||||||
output_voice = output_voices[0]
|
|
||||||
|
|
||||||
print(f"Generation took {info['time']} seconds, saved to '{output_voice}'\n")
|
|
||||||
|
|
||||||
info['seed'] = settings['use_deterministic_seed']
|
info['seed'] = settings['use_deterministic_seed']
|
||||||
del info['latents']
|
del info['latents']
|
||||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(info, indent='\t') )
|
f.write(json.dumps(info, indent='\t') )
|
||||||
|
|
||||||
results = [
|
stats = [
|
||||||
[ seed, "{:.3f}".format(info['time']) ]
|
[ seed, "{:.3f}".format(info['time']) ]
|
||||||
]
|
]
|
||||||
|
|
||||||
return (
|
return (
|
||||||
sample_voice,
|
sample_voice,
|
||||||
output_voice,
|
output_voices,
|
||||||
results,
|
stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_presets(value):
|
def update_presets(value):
|
||||||
|
@ -438,6 +436,7 @@ def check_for_updates():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def reload_tts():
|
def reload_tts():
|
||||||
|
del tts
|
||||||
tts = setup_tortoise()
|
tts = setup_tortoise()
|
||||||
|
|
||||||
def cancel_generate():
|
def cancel_generate():
|
||||||
|
@ -577,6 +576,8 @@ def setup_gradio():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
text = gr.Textbox(lines=4, label="Prompt")
|
text = gr.Textbox(lines=4, label="Prompt")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
|
delimiter = gr.Textbox(lines=1, label="Line Delimiter", placeholder="\\n")
|
||||||
|
|
||||||
emotion = gr.Radio(
|
emotion = gr.Radio(
|
||||||
|
@ -640,12 +641,15 @@ def setup_gradio():
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
selected_voice = gr.Audio(label="Source Sample")
|
|
||||||
output_audio = gr.Audio(label="Output")
|
|
||||||
generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"])
|
|
||||||
|
|
||||||
submit = gr.Button(value="Generate")
|
submit = gr.Button(value="Generate")
|
||||||
stop = gr.Button(value="Stop")
|
stop = gr.Button(value="Stop")
|
||||||
|
|
||||||
|
generation_results = gr.Dataframe(label="Results", headers=["Seed", "Time"], visible=False)
|
||||||
|
source_sample = gr.Audio(label="Source Sample", visible=False)
|
||||||
|
output_audio = gr.Audio(label="Output")
|
||||||
|
candidates_list = gr.Dropdown(label="Candidates", type="value", visible=False)
|
||||||
|
output_pick = gr.Button(value="Select Candidate", visible=False)
|
||||||
|
|
||||||
with gr.Tab("History"):
|
with gr.Tab("History"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
@ -810,11 +814,83 @@ def setup_gradio():
|
||||||
experimental_checkboxes,
|
experimental_checkboxes,
|
||||||
]
|
]
|
||||||
|
|
||||||
submit_event = submit.click(generate,
|
# YUCK
|
||||||
inputs=input_settings,
|
def run_generation(
|
||||||
outputs=[selected_voice, output_audio, generation_results],
|
text,
|
||||||
|
delimiter,
|
||||||
|
emotion,
|
||||||
|
prompt,
|
||||||
|
voice,
|
||||||
|
mic_audio,
|
||||||
|
seed,
|
||||||
|
candidates,
|
||||||
|
num_autoregressive_samples,
|
||||||
|
diffusion_iterations,
|
||||||
|
temperature,
|
||||||
|
diffusion_sampler,
|
||||||
|
breathing_room,
|
||||||
|
cvvp_weight,
|
||||||
|
top_p,
|
||||||
|
diffusion_temperature,
|
||||||
|
length_penalty,
|
||||||
|
repetition_penalty,
|
||||||
|
cond_free_k,
|
||||||
|
experimental_checkboxes,
|
||||||
|
progress=gr.Progress(track_tqdm=True)
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
sample, outputs, stats = generate(
|
||||||
|
text,
|
||||||
|
delimiter,
|
||||||
|
emotion,
|
||||||
|
prompt,
|
||||||
|
voice,
|
||||||
|
mic_audio,
|
||||||
|
seed,
|
||||||
|
candidates,
|
||||||
|
num_autoregressive_samples,
|
||||||
|
diffusion_iterations,
|
||||||
|
temperature,
|
||||||
|
diffusion_sampler,
|
||||||
|
breathing_room,
|
||||||
|
cvvp_weight,
|
||||||
|
top_p,
|
||||||
|
diffusion_temperature,
|
||||||
|
length_penalty,
|
||||||
|
repetition_penalty,
|
||||||
|
cond_free_k,
|
||||||
|
experimental_checkboxes,
|
||||||
|
progress
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise gr.Error(e)
|
||||||
|
|
||||||
|
|
||||||
|
return (
|
||||||
|
outputs[0],
|
||||||
|
gr.update(value=sample, visible=sample is not None),
|
||||||
|
gr.update(choices=outputs, visible=len(outputs) > 1, interactive=True),
|
||||||
|
gr.update(visible=len(outputs) > 1),
|
||||||
|
gr.update(value=stats, visible=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
output_pick.click(
|
||||||
|
lambda x: x,
|
||||||
|
inputs=candidates_list,
|
||||||
|
outputs=output_audio,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
submit.click(
|
||||||
|
lambda: (gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)),
|
||||||
|
outputs=[source_sample, candidates_list, output_pick, generation_results],
|
||||||
|
)
|
||||||
|
|
||||||
|
submit_event = submit.click(run_generation,
|
||||||
|
inputs=input_settings,
|
||||||
|
outputs=[output_audio, source_sample, candidates_list, output_pick, generation_results],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
copy_button.click(import_generate_settings,
|
copy_button.click(import_generate_settings,
|
||||||
inputs=audio_in, # JSON elements cannot be used as inputs
|
inputs=audio_in, # JSON elements cannot be used as inputs
|
||||||
outputs=input_settings
|
outputs=input_settings
|
||||||
|
|
Loading…
Reference in New Issue
Block a user