diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 4c3ea0a..d355c5e 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -13,6 +13,7 @@ def main():
 	parser.add_argument("text")
 	parser.add_argument("references", type=path_list, default=None)
 	parser.add_argument("--language", type=str, default="en")
+	parser.add_argument("--task", type=str, default="tts")
 	parser.add_argument("--out-path", type=Path, default=None)
 	parser.add_argument("--yaml", type=Path, default=None)
@@ -53,6 +54,7 @@ def main():
 		text=args.text,
 		references=args.references,
 		language=args.language,
+		task=args.task,
 		out_path=args.out_path,
 		input_prompt_length=args.input_prompt_length,
 		max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
diff --git a/vall_e/data.py b/vall_e/data.py
index db12759..fa4c5de 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1053,11 +1053,13 @@ class Dataset(_Dataset):
 				task,
 			]
 
-		# Base TTS ( => )
+		# Base STT ( => )
 		elif task == "stt":
 			# easier to just keep it instead of wrangling around trying to remove it
 			# it might also help to provide a guidance prompt but who knows right now
-			proms = self.sample_prompts(spkr_name, ignore=path)
+			proms = [
+				task
+			]
 
 		# noise suppression (? => )
 		# speech removal (? => )
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 8820531..f42e230 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -128,6 +128,7 @@ class TTS():
 		text,
 		references,
 		language="en",
+		task="tts",
 		#
 		max_ar_steps=6 * cfg.dataset.frames_per_second,
 		max_nar_levels=7,
@@ -181,6 +182,41 @@ class TTS():
 
 		set_seed(seed)
 
+		if task == "stt":
+			# speech-to-text: feed the reference audio through the AR model and decode the text tokens it emits
+			resp = self.encode_audio( references )
+			lang = self.encode_lang( language )
+
+			resp = to_device(resp, device=self.device, dtype=torch.int16)
+			lang = to_device(lang, device=self.device, dtype=torch.uint8)
+
+			with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
+				if model_ar is not None:
+					text_list = model_ar(
+						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps,
+						sampling_temperature=ar_temp,
+						sampling_min_temperature=min_ar_temp,
+						sampling_top_p=top_p, sampling_top_k=top_k,
+						sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+						sampling_length_penalty=length_penalty,
+						sampling_beam_width=beam_width,
+						sampling_mirostat_tau=mirostat_tau,
+						sampling_mirostat_eta=mirostat_eta,
+						sampling_dry_multiplier=dry_multiplier,
+						sampling_dry_base=dry_base,
+						sampling_dry_allowed_length=dry_allowed_length,
+
+						disable_tqdm=not tqdm,
+					)
+				else:
+					raise Exception("!")
+
+			text_list = [ cfg.tokenizer.decode( text ) for text in text_list ]
+			print( text_list )
+
+			return text_list[0]
+
+
 		for line in lines:
 			if out_path is None:
 				output_dir = Path("./data/results/")
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 1a6323b..574bc42 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -357,11 +357,14 @@ def example_usage():
 	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
 	from ..engines import Engine, Engines
 	from ..utils import wrapper as ml
+	from ..utils import setup_logging
 
 	import numpy as np
 	import re
 
+	setup_logging()
 	device = "cuda"
+
 	# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
 	"""
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index a253848..8bb98d2 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -120,7 +120,7 @@ class TqdmLoggingHandler(logging.Handler):
 			self.handleError(record)
 
 @global_leader_only
-def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
+def setup_logging(log_dir: str | Path | None = None, log_level="info"):
 	handlers = []
 
 	#stdout_handler = StreamHandler()
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 0ca05e0..b6004df 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -13,7 +13,7 @@ from pathlib import Path
 from .inference import TTS, cfg
 from .train import train
-from .utils import get_devices
+from .utils import get_devices, setup_logging
 
 tts = None
@@ -338,6 +338,8 @@ with ui:
 		gr.Markdown(md)
 
 def start( lock=True ):
+	setup_logging()
+
 	ui.queue(max_size=8)
 	ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
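
For a quick end-to-end check of the new task plumbing, a minimal sketch of driving the STT path through the TTS wrapper follows. The config= keyword and both file paths are illustrative assumptions, not part of this patch; only the text/references/language/task parameters of inference() are taken from the diff above.

	# a minimal sketch, assuming TTS() accepts a config= YAML path as the CLI's
	# --yaml flag suggests; both paths below are hypothetical placeholders
	from pathlib import Path
	from vall_e.inference import TTS

	tts = TTS( config=Path("./training/config.yaml") )

	# with task="stt" the reference audio is transcribed rather than used as a
	# voice-cloning prompt, and inference() returns the decoded text string
	transcription = tts.inference(
		text="",
		references=[ Path("./reference.wav") ],
		language="en",
		task="stt",
	)
	print( transcription )

The CLI equivalent would be something like: python3 -m vall_e "" ./reference.wav --task stt --yaml ./training/config.yaml, exercising the --task flag added in __main__.py above (the positional text argument is unused for STT).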