forked from mrq/tortoise-tts
revamped result formatting, added "kludgy" stop button
This commit is contained in:
parent
8f789d17b9
commit
8641cc9906
|
@ -36,6 +36,8 @@ from tortoise.utils.device import get_device, get_device_name, get_device_batch_
|
|||
|
||||
pbar = None
|
||||
|
||||
STOP_SIGNAL = False
|
||||
|
||||
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR')
|
||||
MODELS = {
|
||||
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
|
||||
|
@ -49,6 +51,11 @@ MODELS = {
|
|||
}
|
||||
|
||||
def tqdm_override(arr, verbose=False, progress=None, desc=None):
|
||||
global STOP_SIGNAL
|
||||
if STOP_SIGNAL:
|
||||
STOP_SIGNAL = False
|
||||
raise Exception("Kill signal detected")
|
||||
|
||||
if verbose and desc is not None:
|
||||
print(desc)
|
||||
|
||||
|
@ -60,6 +67,7 @@ def download_models(specific_models=None):
|
|||
"""
|
||||
Call to download all the models that Tortoise uses.
|
||||
"""
|
||||
|
||||
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||
|
||||
def show_progress(block_num, block_size, total_size):
|
||||
|
|
61
webui.py
61
webui.py
|
@ -16,6 +16,8 @@ from datetime import datetime
|
|||
|
||||
from fastapi import FastAPI
|
||||
|
||||
import tortoise.api
|
||||
|
||||
from tortoise.api import TextToSpeech
|
||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
||||
from tortoise.utils.text import split_and_recombine_text
|
||||
|
@ -124,7 +126,7 @@ def generate(
|
|||
|
||||
start_time = time.time()
|
||||
|
||||
outdir = f"./results/{voice}/{int(start_time)}/"
|
||||
outdir = f"./results/{voice}/"
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
audio_cache = {}
|
||||
|
@ -140,6 +142,22 @@ def generate(
|
|||
|
||||
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
|
||||
|
||||
idx = 0
|
||||
for i, file in enumerate(os.listdir(outdir)):
|
||||
if file[-4:] == ".wav":
|
||||
idx = idx + 1
|
||||
|
||||
def get_name(line=0, candidate=0, combined=False):
    """Build the base name (no directory, voice prefix or extension) for a result.

    Reads `idx`, `texts` and `candidates` from the enclosing scope.

    - combined=True: returns "<idx>_combined"; `line` and `candidate`
      are not part of the name.
    - otherwise: starts from "<idx>", appends "_<line>" when `texts` holds
      more than one entry, then "_<candidate>" when `candidates` is
      greater than 1.
    """
    # The combined name never depends on line/candidate, so settle it first.
    if combined:
        return f"{idx}_combined"

    parts = [f"{idx}"]
    if len(texts) > 1:
        parts.append(f"{line}")
    if candidates > 1:
        parts.append(f"{candidate}")
    return "_".join(parts)
|
||||
|
||||
for line, cut_text in enumerate(texts):
|
||||
if emotion == "Custom":
|
||||
if prompt.strip() != "":
|
||||
|
@ -154,13 +172,14 @@ def generate(
|
|||
|
||||
if isinstance(gen, list):
|
||||
for j, g in enumerate(gen):
|
||||
os.makedirs(f'{outdir}/candidate_{j}', exist_ok=True)
|
||||
audio_cache[f"candidate_{j}/result_{line}.wav"] = {
|
||||
name = get_name(line=line, candidate=j)
|
||||
audio_cache[name] = {
|
||||
'audio': g,
|
||||
'text': cut_text,
|
||||
}
|
||||
else:
|
||||
audio_cache[f"result_{line}.wav"] = {
|
||||
name = get_name(line=line)
|
||||
audio_cache[name] = {
|
||||
'audio': gen,
|
||||
'text': cut_text,
|
||||
}
|
||||
|
@ -173,7 +192,7 @@ def generate(
|
|||
audio = volume_adjust(audio)
|
||||
|
||||
audio_cache[k]['audio'] = audio
|
||||
torchaudio.save(f'{outdir}/{k}', audio, args.output_sample_rate)
|
||||
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
||||
|
||||
|
||||
output_voice = None
|
||||
|
@ -182,30 +201,29 @@ def generate(
|
|||
audio_clips = []
|
||||
for line in range(len(texts)):
|
||||
if isinstance(gen, list):
|
||||
audio = audio_cache[f'candidate_{candidate}/result_{line}.wav']['audio']
|
||||
name = get_name(line=line, candidate=candidate)
|
||||
audio = audio_cache[name]['audio']
|
||||
else:
|
||||
audio = audio_cache[f'result_{line}.wav']['audio']
|
||||
name = get_name(line=line)
|
||||
audio = audio_cache[name]['audio']
|
||||
audio_clips.append(audio)
|
||||
|
||||
name = get_name(combined=True)
|
||||
audio = torch.cat(audio_clips, dim=-1)
|
||||
torchaudio.save(f'{outdir}/combined_{candidate}.wav', audio, args.output_sample_rate)
|
||||
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
|
||||
|
||||
audio = audio.squeeze(0).cpu()
|
||||
audio_cache[f'combined_{candidate}.wav'] = {
|
||||
audio_cache[name] = {
|
||||
'audio': audio,
|
||||
'text': cut_text,
|
||||
}
|
||||
|
||||
if output_voice is None:
|
||||
output_voice = f'{outdir}/combined_{candidate}.wav'
|
||||
output_voice = f'{outdir}/{voice}_{name}.wav'
|
||||
# output_voice = audio
|
||||
else:
|
||||
if isinstance(gen, list):
|
||||
output_voice = f'{outdir}/candidate_0/result_0.wav'
|
||||
#output_voice = gen[0]
|
||||
else:
|
||||
output_voice = f'{outdir}/result_0.wav'
|
||||
#output_voice = gen
|
||||
name = get_name()
|
||||
output_voice = f'{outdir}/{voice}_{name}.wav'
|
||||
|
||||
info = {
|
||||
'text': text,
|
||||
|
@ -231,7 +249,7 @@ def generate(
|
|||
'time': time.time()-start_time,
|
||||
}
|
||||
|
||||
with open(f'{outdir}/input.json', 'w', encoding="utf-8") as f:
|
||||
with open(f'{outdir}/input_{idx}.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps(info, indent='\t') )
|
||||
|
||||
if voice is not None and conditioning_latents is not None:
|
||||
|
@ -242,7 +260,7 @@ def generate(
|
|||
for path in audio_cache:
|
||||
info['text'] = audio_cache[path]['text']
|
||||
|
||||
metadata = music_tag.load_file(f"{outdir}/{path}")
|
||||
metadata = music_tag.load_file(f"{outdir}/{voice}_{path}.wav")
|
||||
metadata['lyrics'] = json.dumps(info)
|
||||
metadata.save()
|
||||
|
||||
|
@ -389,6 +407,9 @@ def check_for_updates():
|
|||
def reload_tts():
    """Rebuild the TTS engine with `setup_tortoise()` and keep the result.

    The new instance is stored in the module-level name `tts`. Without the
    `global` declaration the assignment would only bind a function-local
    variable that is thrown away on return, making the reload a no-op for
    every other function in this module.

    NOTE(review): assumes the rest of the module reads the engine from a
    module-level `tts` -- confirm against the code that creates it.
    """
    global tts
    tts = setup_tortoise()
|
||||
|
||||
def cancel_generate():
    """Request cancellation by setting `tortoise.api.STOP_SIGNAL` to True.

    Takes no arguments and returns None; the only effect is flipping the
    module-level flag on `tortoise.api`.
    """
    tortoise.api.STOP_SIGNAL = True
|
||||
|
||||
def update_voices():
    """Return a Dropdown update listing the current voice choices.

    The choices are the entries of the directory returned by
    `get_voice_dir()`, sorted, followed by the literal "microphone" entry.
    """
    voice_choices = sorted(os.listdir(get_voice_dir()))
    # "microphone" is always offered last, after the sorted directory entries.
    voice_choices.append("microphone")
    return gr.Dropdown.update(choices=voice_choices)
|
||||
|
||||
|
@ -574,7 +595,7 @@ def setup_gradio():
|
|||
usedSeed = gr.Textbox(label="Seed", placeholder="0", interactive=False)
|
||||
|
||||
submit = gr.Button(value="Generate")
|
||||
#stop = gr.Button(value="Stop")
|
||||
stop = gr.Button(value="Stop")
|
||||
with gr.Tab("Utilities"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
|
@ -676,7 +697,7 @@ def setup_gradio():
|
|||
if args.check_for_updates:
|
||||
webui.load(check_for_updates)
|
||||
|
||||
#stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])
|
||||
stop.click(fn=cancel_generate, inputs=None, outputs=None, cancels=[submit_event])
|
||||
|
||||
|
||||
webui.queue(concurrency_count=args.concurrency_count)
|
||||
|
|
Loading…
Reference in New Issue
Block a user