forked from mrq/tortoise-tts
revamped result formatting, added "kludgy" stop button
This commit is contained in:
parent
9e0fbff545
commit
4f903159ee
|
@ -36,6 +36,8 @@ from tortoise.utils.device import get_device, get_device_name, get_device_batch_
|
||||||
|
|
||||||
pbar = None
|
pbar = None
|
||||||
|
|
||||||
|
STOP_SIGNAL = False
|
||||||
|
|
||||||
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR')
|
MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR')
|
||||||
MODELS = {
|
MODELS = {
|
||||||
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
|
'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
|
||||||
|
@ -49,6 +51,11 @@ MODELS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
def tqdm_override(arr, verbose=False, progress=None, desc=None):
|
def tqdm_override(arr, verbose=False, progress=None, desc=None):
|
||||||
|
global STOP_SIGNAL
|
||||||
|
if STOP_SIGNAL:
|
||||||
|
STOP_SIGNAL = False
|
||||||
|
raise Exception("Kill signal detected")
|
||||||
|
|
||||||
if verbose and desc is not None:
|
if verbose and desc is not None:
|
||||||
print(desc)
|
print(desc)
|
||||||
|
|
||||||
|
@ -60,6 +67,7 @@ def download_models(specific_models=None):
|
||||||
"""
|
"""
|
||||||
Call to download all the models that Tortoise uses.
|
Call to download all the models that Tortoise uses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
os.makedirs(MODELS_DIR, exist_ok=True)
|
os.makedirs(MODELS_DIR, exist_ok=True)
|
||||||
|
|
||||||
def show_progress(block_num, block_size, total_size):
|
def show_progress(block_num, block_size, total_size):
|
||||||
|
|
61
webui.py
61
webui.py
|
@ -16,6 +16,8 @@ from datetime import datetime
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
import tortoise.api
|
||||||
|
|
||||||
from tortoise.api import TextToSpeech
|
from tortoise.api import TextToSpeech
|
||||||
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
|
||||||
from tortoise.utils.text import split_and_recombine_text
|
from tortoise.utils.text import split_and_recombine_text
|
||||||
|
@ -124,7 +126,7 @@ def generate(
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
outdir = f"./results/{voice}/{int(start_time)}/"
|
outdir = f"./results/{voice}/"
|
||||||
os.makedirs(outdir, exist_ok=True)
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
audio_cache = {}
|
audio_cache = {}
|
||||||
|
@ -140,6 +142,22 @@ def generate(
|
||||||
|
|
||||||
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
|
volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
for i, file in enumerate(os.listdir(outdir)):
|
||||||
|
if file[-4:] == ".wav":
|
||||||
|
idx = idx + 1
|
||||||
|
|
||||||
|
def get_name(line=0, candidate=0, combined=False):
|
||||||
|
name = f"{idx}"
|
||||||
|
if len(texts) > 1:
|
||||||
|
name = f"{name}_{line}"
|
||||||
|
if candidates > 1:
|
||||||
|
name = f"{name}_{candidate}"
|
||||||
|
|
||||||
|
if combined:
|
||||||
|
return f"{idx}_combined"
|
||||||
|
return name
|
||||||
|
|
||||||
for line, cut_text in enumerate(texts):
|
for line, cut_text in enumerate(texts):
|
||||||
if emotion == "Custom":
|
if emotion == "Custom":
|
||||||
if prompt.strip() != "":
|
if prompt.strip() != "":
|
||||||
|
@ -154,13 +172,14 @@ def generate(
|
||||||
|
|
||||||
if isinstance(gen, list):
|
if isinstance(gen, list):
|
||||||
for j, g in enumerate(gen):
|
for j, g in enumerate(gen):
|
||||||
os.makedirs(f'{outdir}/candidate_{j}', exist_ok=True)
|
name = get_name(line=line, candidate=j)
|
||||||
audio_cache[f"candidate_{j}/result_{line}.wav"] = {
|
audio_cache[name] = {
|
||||||
'audio': g,
|
'audio': g,
|
||||||
'text': cut_text,
|
'text': cut_text,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
audio_cache[f"result_{line}.wav"] = {
|
name = get_name(line=line)
|
||||||
|
audio_cache[name] = {
|
||||||
'audio': gen,
|
'audio': gen,
|
||||||
'text': cut_text,
|
'text': cut_text,
|
||||||
}
|
}
|
||||||
|
@ -173,7 +192,7 @@ def generate(
|
||||||
audio = volume_adjust(audio)
|
audio = volume_adjust(audio)
|
||||||
|
|
||||||
audio_cache[k]['audio'] = audio
|
audio_cache[k]['audio'] = audio
|
||||||
torchaudio.save(f'{outdir}/{k}', audio, args.output_sample_rate)
|
torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
|
|
||||||
output_voice = None
|
output_voice = None
|
||||||
|
@ -182,30 +201,29 @@ def generate(
|
||||||
audio_clips = []
|
audio_clips = []
|
||||||
for line in range(len(texts)):
|
for line in range(len(texts)):
|
||||||
if isinstance(gen, list):
|
if isinstance(gen, list):
|
||||||
audio = audio_cache[f'candidate_{candidate}/result_{line}.wav']['audio']
|
name = get_name(line=line, candidate=candidate)
|
||||||
|
audio = audio_cache[name]['audio']
|
||||||
else:
|
else:
|
||||||
audio = audio_cache[f'result_{line}.wav']['audio']
|
name = get_name(line=line)
|
||||||
|
audio = audio_cache[name]['audio']
|
||||||
audio_clips.append(audio)
|
audio_clips.append(audio)
|
||||||
|
|
||||||
|
name = get_name(combined=True)
|
||||||
audio = torch.cat(audio_clips, dim=-1)
|
audio = torch.cat(audio_clips, dim=-1)
|
||||||
torchaudio.save(f'{outdir}/combined_{candidate}.wav', audio, args.output_sample_rate)
|
torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
|
||||||
|
|
||||||
audio = audio.squeeze(0).cpu()
|
audio = audio.squeeze(0).cpu()
|
||||||
audio_cache[f'combined_{candidate}.wav'] = {
|
audio_cache[name] = {
|
||||||
'audio': audio,
|
'audio': audio,
|
||||||
'text': cut_text,
|
'text': cut_text,
|
||||||
}
|
}
|
||||||
|
|
||||||
if output_voice is None:
|
if output_voice is None:
|
||||||
output_voice = f'{outdir}/combined_{candidate}.wav'
|
output_voice = f'{outdir}/{voice}_{name}.wav'
|
||||||
# output_voice = audio
|
# output_voice = audio
|
||||||
else:
|
else:
|
||||||
if isinstance(gen, list):
|
name = get_name()
|
||||||
output_voice = f'{outdir}/candidate_0/result_0.wav'
|
output_voice = f'{outdir}/{voice}_{name}.wav'
|
||||||
#output_voice = gen[0]
|
|
||||||
else:
|
|
||||||
output_voice = f'{outdir}/result_0.wav'
|
|
||||||
#output_voice = gen
|
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
'text': text,
|
'text': text,
|
||||||
|
@ -231,7 +249,7 @@ def generate(
|
||||||
'time': time.time()-start_time,
|
'time': time.time()-start_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(f'{outdir}/input.json', 'w', encoding="utf-8") as f:
|
with open(f'{outdir}/input_{idx}.json', 'w', encoding="utf-8") as f:
|
||||||
f.write(json.dumps(info, indent='\t') )
|
f.write(json.dumps(info, indent='\t') )
|
||||||
|
|
||||||
if voice is not None and conditioning_latents is not None:
|
if voice is not None and conditioning_latents is not None:
|
||||||
|
@ -242,7 +260,7 @@ def generate(
|
||||||
for path in audio_cache:
|
for path in audio_cache:
|
||||||
info['text'] = audio_cache[path]['text']
|
info['text'] = audio_cache[path]['text']
|
||||||
|
|
||||||
metadata = music_tag.load_file(f"{outdir}/{path}")
|
metadata = music_tag.load_file(f"{outdir}/{voice}_{path}.wav")
|
||||||
metadata['lyrics'] = json.dumps(info)
|
metadata['lyrics'] = json.dumps(info)
|
||||||
metadata.save()
|
metadata.save()
|
||||||
|
|
||||||
|
@ -389,6 +407,9 @@ def check_for_updates():
|
||||||
def reload_tts():
|
def reload_tts():
|
||||||
tts = setup_tortoise()
|
tts = setup_tortoise()
|
||||||
|
|
||||||
|
def cancel_generate():
|
||||||
|
tortoise.api.STOP_SIGNAL = True
|
||||||
|
|
||||||
def update_voices():
|
def update_voices():
|
||||||
return gr.Dropdown.update(choices=sorted(os.listdir(get_voice_dir())) + ["microphone"])
|
return gr.Dropdown.update(choices=sorted(os.listdir(get_voice_dir())) + ["microphone"])
|
||||||
|
|
||||||
|
@ -574,7 +595,7 @@ def setup_gradio():
|
||||||
usedSeed = gr.Textbox(label="Seed", placeholder="0", interactive=False)
|
usedSeed = gr.Textbox(label="Seed", placeholder="0", interactive=False)
|
||||||
|
|
||||||
submit = gr.Button(value="Generate")
|
submit = gr.Button(value="Generate")
|
||||||
#stop = gr.Button(value="Stop")
|
stop = gr.Button(value="Stop")
|
||||||
with gr.Tab("Utilities"):
|
with gr.Tab("Utilities"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
@ -676,7 +697,7 @@ def setup_gradio():
|
||||||
if args.check_for_updates:
|
if args.check_for_updates:
|
||||||
webui.load(check_for_updates)
|
webui.load(check_for_updates)
|
||||||
|
|
||||||
#stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])
|
stop.click(fn=cancel_generate, inputs=None, outputs=None, cancels=[submit_event])
|
||||||
|
|
||||||
|
|
||||||
webui.queue(concurrency_count=args.concurrency_count)
|
webui.queue(concurrency_count=args.concurrency_count)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user