1
1
forked from mrq/tortoise-tts

revamped result formatting, added "kludgy" stop button

This commit is contained in:
mrq 2023-02-10 22:12:37 +00:00
parent 9e0fbff545
commit 4f903159ee
2 changed files with 49 additions and 20 deletions

View File

@ -36,6 +36,8 @@ from tortoise.utils.device import get_device, get_device_name, get_device_batch_
pbar = None
STOP_SIGNAL = False
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR')
MODELS = {
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
@ -49,6 +51,11 @@ MODELS = {
}
def tqdm_override(arr, verbose=False, progress=None, desc=None):
global STOP_SIGNAL
if STOP_SIGNAL:
STOP_SIGNAL = False
raise Exception("Kill signal detected")
if verbose and desc is not None:
print(desc)
@ -60,6 +67,7 @@ def download_models(specific_models=None):
"""
Call to download all the models that Tortoise uses.
"""
os.makedirs(MODELS_DIR, exist_ok=True)
def show_progress(block_num, block_size, total_size):

View File

@ -16,6 +16,8 @@ from datetime import datetime
from fastapi import FastAPI
import tortoise.api
from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
from tortoise.utils.text import split_and_recombine_text
@ -124,7 +126,7 @@ def generate(
start_time = time.time()
outdir = f"./results/{voice}/{int(start_time)}/"
outdir = f"./results/{voice}/"
os.makedirs(outdir, exist_ok=True)
audio_cache = {}
@ -140,6 +142,22 @@ def generate(
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
idx = 0
for i, file in enumerate(os.listdir(outdir)):
if file[-4:] == ".wav":
idx = idx + 1
def get_name(line=0, candidate=0, combined=False):
name = f"{idx}"
if len(texts) > 1:
name = f"{name}_{line}"
if candidates > 1:
name = f"{name}_{candidate}"
if combined:
return f"{idx}_combined"
return name
for line, cut_text in enumerate(texts):
if emotion == "Custom":
if prompt.strip() != "":
@ -154,13 +172,14 @@ def generate(
if isinstance(gen, list):
for j, g in enumerate(gen):
os.makedirs(f'{outdir}/candidate_{j}', exist_ok=True)
audio_cache[f"candidate_{j}/result_{line}.wav"] = {
name = get_name(line=line, candidate=j)
audio_cache[name] = {
'audio': g,
'text': cut_text,
}
else:
audio_cache[f"result_{line}.wav"] = {
name = get_name(line=line)
audio_cache[name] = {
'audio': gen,
'text': cut_text,
}
@ -173,7 +192,7 @@ def generate(
audio = volume_adjust(audio)
audio_cache[k]['audio'] = audio
torchaudio.save(f'{outdir}/{k}', audio, args.output_sample_rate)
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
output_voice = None
@ -182,30 +201,29 @@ def generate(
audio_clips = []
for line in range(len(texts)):
if isinstance(gen, list):
audio = audio_cache[f'candidate_{candidate}/result_{line}.wav']['audio']
name = get_name(line=line, candidate=candidate)
audio = audio_cache[name]['audio']
else:
audio = audio_cache[f'result_{line}.wav']['audio']
name = get_name(line=line)
audio = audio_cache[name]['audio']
audio_clips.append(audio)
name = get_name(combined=True)
audio = torch.cat(audio_clips, dim=-1)
torchaudio.save(f'{outdir}/combined_{candidate}.wav', audio, args.output_sample_rate)
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
audio = audio.squeeze(0).cpu()
audio_cache[f'combined_{candidate}.wav'] = {
audio_cache[name] = {
'audio': audio,
'text': cut_text,
}
if output_voice is None:
output_voice = f'{outdir}/combined_{candidate}.wav'
output_voice = f'{outdir}/{voice}_{name}.wav'
# output_voice = audio
else:
if isinstance(gen, list):
output_voice = f'{outdir}/candidate_0/result_0.wav'
#output_voice = gen[0]
else:
output_voice = f'{outdir}/result_0.wav'
#output_voice = gen
name = get_name()
output_voice = f'{outdir}/{voice}_{name}.wav'
info = {
'text': text,
@ -231,7 +249,7 @@ def generate(
'time': time.time()-start_time,
}
with open(f'{outdir}/input.json', 'w', encoding="utf-8") as f:
with open(f'{outdir}/input_{idx}.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(info, indent='\t') )
if voice is not None and conditioning_latents is not None:
@ -242,7 +260,7 @@ def generate(
for path in audio_cache:
info['text'] = audio_cache[path]['text']
metadata = music_tag.load_file(f"{outdir}/{path}")
metadata = music_tag.load_file(f"{outdir}/{voice}_{path}.wav")
metadata['lyrics'] = json.dumps(info)
metadata.save()
@ -389,6 +407,9 @@ def check_for_updates():
def reload_tts():
tts = setup_tortoise()
def cancel_generate():
tortoise.api.STOP_SIGNAL = True
def update_voices():
return gr.Dropdown.update(choices=sorted(os.listdir(get_voice_dir())) + ["microphone"])
@ -574,7 +595,7 @@ def setup_gradio():
usedSeed = gr.Textbox(label="Seed", placeholder="0", interactive=False)
submit = gr.Button(value="Generate")
#stop = gr.Button(value="Stop")
stop = gr.Button(value="Stop")
with gr.Tab("Utilities"):
with gr.Row():
with gr.Column():
@ -676,7 +697,7 @@ def setup_gradio():
if args.check_for_updates:
webui.load(check_for_updates)
#stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])
stop.click(fn=cancel_generate, inputs=None, outputs=None, cancels=[submit_event])
webui.queue(concurrency_count=args.concurrency_count)