diff --git a/tortoise/api.py b/tortoise/api.py
index 60dc56b..b7c799a 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -36,6 +36,8 @@ from tortoise.utils.device import get_device, get_device_name, get_device_batch_size
 
 pbar = None
 
+STOP_SIGNAL = False
+
 MODELS_DIR = os.environ.get('TORTOISE_MODELS_DIR')
 MODELS = {
     'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth',
@@ -49,6 +51,11 @@ MODELS = {
 }
 
 def tqdm_override(arr, verbose=False, progress=None, desc=None):
+    global STOP_SIGNAL
+    if STOP_SIGNAL:
+        STOP_SIGNAL = False
+        raise Exception("Kill signal detected")
+
     if verbose and desc is not None:
         print(desc)
 
@@ -60,6 +67,7 @@ def download_models(specific_models=None):
     """
     Call to download all the models that Tortoise uses.
    """
+    os.makedirs(MODELS_DIR, exist_ok=True)
 
     def show_progress(block_num, block_size, total_size):
diff --git a/webui.py b/webui.py
index 773db7f..a4a4666 100755
--- a/webui.py
+++ b/webui.py
@@ -16,6 +16,8 @@ from datetime import datetime
 
 from fastapi import FastAPI
 
+import tortoise.api
+
 from tortoise.api import TextToSpeech
 from tortoise.utils.audio import load_audio, load_voice, load_voices, get_voice_dir
 from tortoise.utils.text import split_and_recombine_text
@@ -124,7 +126,7 @@ def generate(
 
     start_time = time.time()
 
-    outdir = f"./results/{voice}/{int(start_time)}/"
+    outdir = f"./results/{voice}/"
     os.makedirs(outdir, exist_ok=True)
 
     audio_cache = {}
@@ -140,6 +142,22 @@ def generate(
 
     volume_adjust = torchaudio.transforms.Vol(gain=args.output_volume, gain_type="amplitude") if args.output_volume != 1 else None
 
+    idx = 0
+    for i, file in enumerate(os.listdir(outdir)):
+        if file[-4:] == ".wav":
+            idx = idx + 1
+
+    def get_name(line=0, candidate=0, combined=False):
+        name = f"{idx}"
+        if len(texts) > 1:
+            name = f"{name}_{line}"
+        if candidates > 1:
+            name = f"{name}_{candidate}"
+
+        if combined:
+            return f"{idx}_combined"
+        return name
+
     for line, cut_text in enumerate(texts):
         if emotion == "Custom":
             if prompt.strip() != "":
@@ -154,13 +172,14 @@ def generate(
 
         if isinstance(gen, list):
             for j, g in enumerate(gen):
-                os.makedirs(f'{outdir}/candidate_{j}', exist_ok=True)
-                audio_cache[f"candidate_{j}/result_{line}.wav"] = {
+                name = get_name(line=line, candidate=j)
+                audio_cache[name] = {
                     'audio': g,
                     'text': cut_text,
                 }
         else:
-            audio_cache[f"result_{line}.wav"] = {
+            name = get_name(line=line)
+            audio_cache[name] = {
                 'audio': gen,
                 'text': cut_text,
             }
@@ -173,7 +192,7 @@ def generate(
                 audio = volume_adjust(audio)
 
         audio_cache[k]['audio'] = audio
-        torchaudio.save(f'{outdir}/{k}', audio, args.output_sample_rate)
+        torchaudio.save(f'{outdir}/{voice}_{k}.wav', audio, args.output_sample_rate)
 
     output_voice = None
 
@@ -182,30 +201,29 @@ def generate(
             audio_clips = []
             for line in range(len(texts)):
                 if isinstance(gen, list):
-                    audio = audio_cache[f'candidate_{candidate}/result_{line}.wav']['audio']
+                    name = get_name(line=line, candidate=candidate)
+                    audio = audio_cache[name]['audio']
                 else:
-                    audio = audio_cache[f'result_{line}.wav']['audio']
+                    name = get_name(line=line)
+                    audio = audio_cache[name]['audio']
                 audio_clips.append(audio)
 
+            name = get_name(combined=True)
             audio = torch.cat(audio_clips, dim=-1)
-            torchaudio.save(f'{outdir}/combined_{candidate}.wav', audio, args.output_sample_rate)
+            torchaudio.save(f'{outdir}/{voice}_{name}.wav', audio, args.output_sample_rate)
 
             audio = audio.squeeze(0).cpu()
-            audio_cache[f'combined_{candidate}.wav'] = {
+            audio_cache[name] = {
                 'audio': audio,
                 'text': cut_text,
             }
 
             if output_voice is None:
-                output_voice = f'{outdir}/combined_{candidate}.wav'
+                output_voice = f'{outdir}/{voice}_{name}.wav'
                 # output_voice = audio
     else:
-        if isinstance(gen, list):
-            output_voice = f'{outdir}/candidate_0/result_0.wav'
-            #output_voice = gen[0]
-        else:
-            output_voice = f'{outdir}/result_0.wav'
-            #output_voice = gen
+        name = get_name()
+        output_voice = f'{outdir}/{voice}_{name}.wav'
 
     info = {
         'text': text,
@@ -231,7 +249,7 @@ def generate(
         'time': time.time()-start_time,
     }
 
-    with open(f'{outdir}/input.json', 'w', encoding="utf-8") as f:
+    with open(f'{outdir}/input_{idx}.json', 'w', encoding="utf-8") as f:
         f.write(json.dumps(info, indent='\t') )
 
     if voice is not None and conditioning_latents is not None:
@@ -242,7 +260,7 @@ def generate(
 
     for path in audio_cache:
         info['text'] = audio_cache[path]['text']
-        metadata = music_tag.load_file(f"{outdir}/{path}")
+        metadata = music_tag.load_file(f"{outdir}/{voice}_{path}.wav")
         metadata['lyrics'] = json.dumps(info)
         metadata.save()
@@ -389,6 +407,9 @@ def check_for_updates():
 def reload_tts():
     tts = setup_tortoise()
 
+def cancel_generate():
+    tortoise.api.STOP_SIGNAL = True
+
 def update_voices():
     return gr.Dropdown.update(choices=sorted(os.listdir(get_voice_dir())) + ["microphone"])
 
@@ -574,7 +595,7 @@ def setup_gradio():
                     usedSeed = gr.Textbox(label="Seed", placeholder="0", interactive=False)
 
                 submit = gr.Button(value="Generate")
-                #stop = gr.Button(value="Stop")
+                stop = gr.Button(value="Stop")
         with gr.Tab("Utilities"):
             with gr.Row():
                 with gr.Column():
@@ -676,7 +697,7 @@ def setup_gradio():
     if args.check_for_updates:
         webui.load(check_for_updates)
 
-    #stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event])
+    stop.click(fn=cancel_generate, inputs=None, outputs=None, cancels=[submit_event])
 
     webui.queue(concurrency_count=args.concurrency_count)
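
Review note, not part of the patch: the Stop button works by cooperative cancellation rather than by killing the worker. cancel_generate() flips a module-level flag in tortoise.api, and tqdm_override() (which wraps every long-running loop in the pipeline) polls that flag on entry and raises, unwinding the in-flight generation. Below is a minimal, self-contained sketch of the pattern; the names mirror the diff, but the generate() body is a stand-in for the real per-line TTS calls.

    STOP_SIGNAL = False

    def tqdm_override(arr, verbose=False, progress=None, desc=None):
        # Poll the flag once per wrapped loop: cancellation lands at the next
        # progress-tracked stage rather than mid-step.
        global STOP_SIGNAL
        if STOP_SIGNAL:
            STOP_SIGNAL = False  # reset so the next Generate starts clean
            raise Exception("Kill signal detected")
        if verbose and desc is not None:
            print(desc)
        return arr  # the real function wraps arr in a tqdm/gradio progress bar

    def generate(texts):
        # Stand-in for the expensive per-line TTS work in webui.py's generate().
        return [t.upper() for t in tqdm_override(texts, verbose=True, desc="Generating")]

    def cancel_generate():
        # What the Stop button calls; safe to invoke from the UI thread.
        global STOP_SIGNAL
        STOP_SIGNAL = True

Because the check sits only at the top of tqdm_override(), a Stop issued mid-stage takes effect at the next stage boundary; resetting the flag before raising means a subsequent run starts clean. The cancels=[submit_event] on stop.click handles the Gradio side, while the flag unwinds the Python worker.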