From 1a26f789a59ecbbec0679ef07ee5fe8dd0e458d7 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 12 Jan 2025 21:52:49 -0600
Subject: [PATCH] added option to playback audio directly, removed no-phonemize
 option since I swear it worked in testing but it doesn't actually work

---
 setup.py                |  1 +
 vall_e/__main__.py      |  2 ++
 vall_e/inference.py     | 16 +++++++++++++++-
 vall_e/models/ar_nar.py |  5 ++++-
 vall_e/webui.py         | 11 ++++++++++-
 5 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index 33294a4..da9d127 100755
--- a/setup.py
+++ b/setup.py
@@ -80,6 +80,7 @@ setup(
         "gradio",
         "nltk", # for parsing text inputs down to pieces
         "langdetect", # for detecting the language of a text
+        "sounddevice", # for raw playback
     ],
     extras_require = {
         "all": [
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index ab79961..f4f22f6 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -74,6 +74,7 @@ def main():
     parser.add_argument("--amp", action="store_true")
     parser.add_argument("--dtype", type=str, default=None)
     parser.add_argument("--attention", type=str, default=None)
+    parser.add_argument("--play", action="store_true")
     args = parser.parse_args()
 
     config = None
@@ -122,6 +123,7 @@ def main():
         task=args.task,
         modality=args.modality,
         out_path=args.out_path,
+        play=args.play,
 
         input_prompt_length=args.input_prompt_length,
         load_from_artifact=args.load_from_artifact,
diff --git a/vall_e/inference.py b/vall_e/inference.py
index ec03635..1f2e3c0 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -29,6 +29,11 @@ from .models import download_model, DEFAULT_MODEL_PATH
 if deepspeed_available:
     import deepspeed
 
+try:
+    import sounddevice as sd
+except Exception as e:
+    sd = None
+
 class TTS():
     def __init__( self, config=None, lora=None, device=None, amp=None, dtype=None, attention=None ):
         self.loading = True
@@ -110,7 +115,7 @@ class TTS():
             return torch.tensor( tokens )
 
         if not phonemize:
-            return torch.tensor( text_tokenize( content ) )
+            return torch.tensor( text_tokenize( text ) )
 
         return torch.tensor( tokenize( g2p.encode(text, language=language) ) )
 
@@ -352,8 +357,12 @@ class TTS():
         text_language=None,
         task="tts",
         out_path=None,
+        play=False,
         **sampling_kwargs,
     ):
+        if sd is None:
+            play = False
+
         input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
         modality = sampling_kwargs.pop("modality", "auto")
         seed = sampling_kwargs.pop("seed", None)
@@ -560,6 +569,11 @@ class TTS():
             # add utterances
             wavs.append(wav)
 
+            if play:
+                sd.play(wav.cpu().numpy()[0], sr)
+                sd.wait()
+
+
         # combine all utterances
         return (torch.concat(wavs, dim=-1), sr)
 
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 4c206e9..05aa0b1 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -716,7 +716,10 @@
             text_list = [ sequence_list[i] if task in ["phn"] else text_list[i] for i, task in enumerate(task_list) ]
             raw_text_list = [ sequence_list[i] if task in ["un-phn"] else raw_text_list[i] for i, task in enumerate(task_list) ]
         else:
-            text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
+            if raw_text_list is not None:
+                raw_text_list = [ sequence_list[i] if task in text_task else raw_text_list[i] for i, task in enumerate(task_list) ]
+            else:
+                text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
             resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
 
         quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
diff --git a/vall_e/webui.py b/vall_e/webui.py
index a5dbfcc..ba83886 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -218,6 +218,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--language", type=str, default=kwargs["language"])
     parser.add_argument("--text-language", type=str, default=kwargs["text-language"])
     parser.add_argument("--no-phonemize", action="store_true")
+    parser.add_argument("--play", action="store_true")
     parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
     parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
    parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
@@ -274,7 +275,10 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         args.refine_on_stop = True
 
     if kwargs.pop("no-phonemize", False):
-        args.no_phonemize = False
+        args.no_phonemize = True
+
+    if kwargs.pop("play", False):
+        args.play = True
 
     if args.split_text_by == "lines":
         args.split_text_by = "\n"
@@ -324,6 +328,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         language=args.language,
         text_language=args.text_language,
         task=args.task,
+        play=args.play,
         modality=args.modality.lower(),
         references=args.references.split(";") if args.references is not None else [],
         **sampling_kwargs,
@@ -472,7 +477,11 @@ with ui:
                 with gr.Row():
                     layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="How to split the text into utterances.", value="sentences")
                     layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
+                """
+                with gr.Row():
                     layout["inference_tts"]["inputs"]["no-phonemize"] = gr.Checkbox(label="No Phonemize", info="Use raw text rather than phonemize the text as the input prompt.")
+                    layout["inference_tts"]["inputs"]["play"] = gr.Checkbox(label="Auto Play", info="Auto play on generation (using sounddevice).")
+                """
             with gr.Tab("Sampler Settings"):
                 with gr.Row():
                     layout["inference_tts"]["inputs"]["ar-temperature"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR/NAR-len)", info="Adjusts the probabilities in the AR/NAR-len. (0 to greedy* sample)")
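For context, the playback path this patch adds to inference.py can be exercised in isolation. The sketch below is illustrative only: maybe_play() and the sine-wave harness are hypothetical stand-ins, and it assumes the decoded waveform is a (1, num_samples) float tensor as it is in inference.py; only the import guard and the sd.play()/sd.wait() calls mirror the patch.

    # Standalone sketch of the guarded playback pattern above.
    # maybe_play() and the sine-wave harness below are hypothetical,
    # not code from this patch.
    import torch

    try:
        import sounddevice as sd
    except Exception:
        # as in inference.py: playback is silently disabled when sounddevice
        # (or a usable audio backend) is unavailable
        sd = None

    def maybe_play(wav: torch.Tensor, sr: int, play: bool = True) -> None:
        if sd is None or not play:
            return
        # sounddevice expects a numpy array; [0] drops the channel dimension
        # of the (1, num_samples) waveform
        sd.play(wav.cpu().numpy()[0], sr)
        sd.wait()  # block until the utterance finishes, as the patch does

    if __name__ == "__main__":
        sr = 24_000
        t = torch.arange(sr) / sr
        wav = torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)  # one second of A4
        maybe_play(wav, sr)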
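From the command line, the new flag rides along with the existing arguments. Assuming the usual positional form of python3 -m vall_e (text, reference audio, output path; taken from the surrounding parser and possibly different per checkout):

    python3 -m vall_e "Hello world." ./reference.wav ./out.wav --play

Because inference() forces play back to False whenever the import guard left sd as None, the flag degrades gracefully on headless machines rather than raising mid-synthesis.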