From 3acc54df22f0c46d59f3c1a6be43d922561f39f4 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 15 Jul 2024 19:59:48 -0500 Subject: [PATCH] allow loading a different model within the web ui (apparently I did not have the web UI in the documentation) --- README.md | 24 +++++++++++++- vall_e/config.py | 30 +++++++++++------ vall_e/data.py | 6 ++-- vall_e/engines/__init__.py | 2 +- vall_e/inference.py | 28 +++++++++------- vall_e/models/ar_nar.py | 11 ++----- vall_e/models/base.py | 19 ++++------- vall_e/models/nar.py | 9 +---- vall_e/webui.py | 67 +++++++++++++++++++++++++++++++------- 9 files changed, 129 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index fe85dea..683f7f3 100755 --- a/README.md +++ b/README.md @@ -20,7 +20,6 @@ Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install g I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`. - ## Pre-Trained Model > [!NOTE] @@ -196,11 +195,34 @@ And some experimental sampling flags you can use too (your mileage will ***defin + **!**NOTE**!**: This is incompatible with beam search sampling (for the meantime at least). * `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise. +### Web UI + +A Gradio-based web UI is accessible by running `python3 -m vall_e.webui`. You can, optionally, pass: + +* `--yaml=./path/to/your/config.yaml`: will load the targeted YAML +* `--listen 0.0.0.0:7860`: will set the web UI to listen to all IPs at port 7860. Replace the IP and Port to your preference. + +#### Inference + +Synthesizing speech is simple: + +* `Input Prompt`: The guiding text prompt. Each new line will be it's own generated audio to be stitched together at the end. +* `Audio Input`: The reference audio for the synthesis. Under Gradio, you can trim your clip accordingly, but leaving it as-is works fine. +* `Output`: The resultant audio. +* `Inference`: Button to start generating the audio. + +All the additional knobs have a description that can be correlated to the above CLI flags. + +#### Settings + +So far, this only allows you to load a different model without needing to restart. The previous model should seamlessly unload, and the new one will load in place. + ## To-Do * [x] train and release a serviceable model for finetuning against. * [ ] train and release a ***good*** zero-shot model. - this should, hopefully, just simply requires another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now. +* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning) * [x] ~~explore alternative setups, like a NAR-only model~~ - the current experiment of an AR length-predictor + NAR for the rest seems to fall apart... * [x] ~~explore better sampling techniques~~ diff --git a/vall_e/config.py b/vall_e/config.py index b57ba35..9d13c16 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -210,11 +210,10 @@ class ModelExperimentalSettings: # I really need to clean this up @dataclass() class Model: - name: str = "" # vanity name for the model + name: str = "ar+nar" # vanity name for the model version: int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings size: str | dict = "full" # preset string or explicitly defined dimensionality - resp_levels: int = 1 # RVQ-bin levels this model targets for outputs - prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt + resp_levels: int = 8 # RVQ-bin levels this model supports tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused) langs: int = 1 # defined languages (semi-unused) tones: int = 1 # defined tones (unsued) @@ -238,7 +237,10 @@ class Model: @property def max_levels(self): - return max(self.prom_levels, self.resp_levels) + # return RVQ level range + if self.experimental is not None and self.experimental.rvq_level_range: + return self.experimental.rvq_level_range[-1] + return self.resp_levels @property # required for fp8 as the lengths needs to be divisible by 8 @@ -626,7 +628,7 @@ class Inference: use_encodec: bool = True use_dac: bool = True - @cached_property + @property def dtype(self): if self.weight_dtype == "float16": return torch.float16 @@ -651,7 +653,7 @@ class Optimizations: optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation) bitsandbytes: bool = False # use bitsandbytes - dadaptation: bool = True # use dadaptation optimizer + dadaptation: bool = False # use dadaptation optimizer bitnet: bool = False # use bitnet fp8: bool = False # use fp8 @@ -671,7 +673,8 @@ class Config(BaseConfig): bitsandbytes: dict | list | None = None # deprecated optimizations: Optimizations = field(default_factory=lambda: Optimizations) - tokenizer: str = "./tokenizer.json" + tokenizer: str | None = None + tokenizer_path: str = "./tokenizer.json" sample_rate: int = 24_000 variable_sample_rate: bool = False # NOT recommended, as running directly 24Khz audio in the 44Khz DAC model will have detrimental quality loss @@ -760,14 +763,21 @@ class Config(BaseConfig): self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ] self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ] + # do cleanup for model in self.models: if not isinstance( model, dict ): continue + if "prom_levels" in model: + del model["prom_levels"] + + if "interleave" in model: + del model["interleave"] + if "audio_embedding_sums" not in model: continue - if not model["experimental"]: + if "experimental" not in model or not model["experimental"]: model["experimental"] = {} model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums") @@ -837,9 +847,9 @@ class Config(BaseConfig): try: from transformers import PreTrainedTokenizerFast - tokenizer_path = cfg.rel_path / cfg.tokenizer + tokenizer_path = cfg.rel_path / cfg.tokenizer_path if not tokenizer_path.exists(): - tokenizer_path = Path("./data/") / cfg.tokenizer + tokenizer_path = Path("./data/") / cfg.tokenizer_path cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path)) except Exception as e: cfg.tokenizer = NaiveTokenizer() diff --git a/vall_e/data.py b/vall_e/data.py index b2537c7..f4ea5e3 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -961,8 +961,8 @@ class Dataset(_Dataset): """ # trim to fit to requested prom/resps levels - proms = proms[:, :cfg.model.prom_levels] - resps = resps[:, :cfg.model.prom_levels] + proms = proms[:, :cfg.model.resp_levels] + resps = resps[:, :cfg.model.resp_levels] return dict( @@ -1466,7 +1466,7 @@ if __name__ == "__main__": if task not in cfg.dataset.tasks_list: continue - print(text, task, cfg.model.prom_levels) + print(text, task, cfg.model.resp_levels) print( proms.shape, resps.shape ) tokens = 0 diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index f686e3e..2d97cc2 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -37,7 +37,7 @@ def load_engines(training=True): optimizer = None lr_scheduler = None - inferencing = cfg.mode == "inferencing" or not model.config.training + inferencing = cfg.mode == "inferencing" or not model.config.training or not training backend = cfg.inference.backend if inferencing else cfg.trainer.backend dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype amp = cfg.inference.amp if inferencing else cfg.trainer.amp diff --git a/vall_e/inference.py b/vall_e/inference.py index 3bbff95..6b9a7e2 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -8,10 +8,10 @@ from einops import rearrange from pathlib import Path from .emb import g2p, qnt -from .emb.qnt import trim, trim_random +from .emb.qnt import trim, trim_random, unload_model from .utils import to_device, set_seed, wrapper as ml -from .config import cfg +from .config import cfg, Config from .models import get_models from .engines import load_engines, deepspeed_available from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize @@ -22,11 +22,15 @@ if deepspeed_available: class TTS(): def __init__( self, config=None, device=None, amp=None, dtype=None ): self.loading = True - - self.input_sample_rate = 24000 - self.output_sample_rate = 24000 + self.load_config( config=config, device=device, amp=amp, dtype=dtype ) + self.load_model() + + self.loading = False + + def load_config( self, config=None, device=None, amp=None, dtype=None ): if config: + print("Loading YAML:", config) cfg.load_yaml( config ) try: @@ -52,20 +56,20 @@ class TTS(): self.device = device self.dtype = cfg.inference.dtype self.amp = amp + - self.symmap = None - + def load_model( self ): + load_engines.cache_clear() + unload_model() + self.engines = load_engines(training=False) for name, engine in self.engines.items(): if self.dtype != torch.int8: engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32) self.engines.eval() - - if self.symmap is None: - self.symmap = get_phone_symmap() - - self.loading = False + self.symmap = get_phone_symmap() + print("Loaded model") def encode_text( self, text, language="en" ): # already a tensor, return it diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 754194b..c3bf797 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -33,12 +33,6 @@ class AR_NAR(Base): def causal(self): return "ar" in self.capabilities - @property - def n_prom_levels(self) -> int: - if hasattr(self, "config") and self.config: - return self.config.prom_levels - return cfg.model.prom_levels - @property def n_resp_levels(self) -> int: if hasattr(self, "config") and self.config: @@ -197,7 +191,7 @@ class AR_NAR(Base): # is NAR if max_levels == 0: - max_levels = self.n_resp_levels - 1 + max_levels = self.n_max_levels - 1 # expand if given a raw 1D tensor for i, resp in enumerate(resps_list): @@ -373,7 +367,6 @@ def example_usage(): # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it) """ if "mamba" in cfg.model.arch_type: - cfg.model.prom_levels = 1 cfg.model.resp_levels = 1 """ # cfg.model.loss_factors = {} @@ -383,7 +376,7 @@ def example_usage(): def _load_quants(path) -> Tensor: qnt = np.load(path, allow_pickle=True)[()] - return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16) + return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16) qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}") diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 013cac7..3be45aa 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -123,7 +123,7 @@ class AudioEmbedding_Old(nn.Module): ): super().__init__() # array of embeddings - # proms are [0, prom_levels] + # proms are [0, resp_levels] # resp are split to where [0] is for the AR, and [1:] are reserved for NAR self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]) # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this) @@ -154,7 +154,7 @@ class AudioEmbedding(nn.Module): ): super().__init__() # array of embeddings - # proms are [0, prom_levels] + # proms are [0, resp_levels] # resp are split to where [0] is for the AR, and [1:] are reserved for NAR # + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens]) @@ -282,10 +282,6 @@ class Base(nn.Module): def causal(self) -> bool: raise NotImplementedError - @property - def n_prom_levels(self) -> int: - raise NotImplementedError - @property def n_resp_levels(self) -> int: raise NotImplementedError @@ -368,7 +364,6 @@ class Base(nn.Module): self.l_padding = l_padding - n_prom_tokens = n_audio_tokens arch_type = self.config.arch_type if self.config is not None else "llama" self.arch_type = arch_type @@ -397,14 +392,14 @@ class Base(nn.Module): self.len_emb = None if self.version == 1: # legacy - n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom - self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) + n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom + self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model) self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic) elif self.version < 5: # [1024] * 8 self.proms_emb = AudioEmbedding_Old( - [n_prom_tokens] * self.n_prom_levels, d_model, - levels=self.n_prom_levels if self.version > 3 else None, + [n_audio_tokens] * self.n_resp_levels, d_model, + levels=self.n_resp_levels if self.version > 3 else None, ) # [1024 + STOP] + [1024] * 8 self.resps_emb = AudioEmbedding_Old( @@ -413,7 +408,7 @@ class Base(nn.Module): ) else: self.proms_emb = AudioEmbedding( - [n_prom_tokens] * self.n_prom_levels, d_model, + [n_audio_tokens] * self.n_resp_levels, d_model, sums=audio_embedding_sums, external_mode=audio_embedding_mode, ) diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 6098778..cbd91dd 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -31,12 +31,6 @@ class NAR(Base): def causal(self): return "len" in self.capabilities - @property - def n_prom_levels(self) -> int: - if hasattr(self, "config") and self.config: - return self.config.prom_levels - return cfg.model.prom_levels - @property def n_resp_levels(self) -> int: if hasattr(self, "config") and self.config: @@ -309,7 +303,6 @@ def example_usage(): # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it) """ if "mamba" in cfg.model.arch_type: - cfg.model.prom_levels = 1 cfg.model.resp_levels = 1 """ # cfg.model.loss_factors = {} @@ -319,7 +312,7 @@ def example_usage(): def _load_quants(path) -> Tensor: qnt = np.load(path, allow_pickle=True)[()] - return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16) + return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16) qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}") diff --git a/vall_e/webui.py b/vall_e/webui.py index b2e2bc8..d83f8cb 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -19,6 +19,7 @@ tts = None layout = {} layout["inference"] = {} layout["training"] = {} +layout["settings"] = {} for k in layout.keys(): layout[k]["inputs"] = { "progress": None } @@ -37,14 +38,42 @@ def gradio_wrapper(inputs): return decorated class timer: - def __enter__(self): - self.start = perf_counter() - return self + def __init__(self, msg="Elapsed time:"): + self.msg = msg - def __exit__(self, type, value, traceback): - print(f'[{datetime.now().isoformat()}] Elapsed time: {(perf_counter() - self.start):.3f}s') + def __enter__(self): + self.start = perf_counter() + return self -def init_tts(restart=False): + def __exit__(self, type, value, traceback): + msg = f'{self.msg} {(perf_counter() - self.start):.3f}s' + + gr.Info(msg) + print(f'[{datetime.now().isoformat()}] {msg}') + +# returns a list of models, assuming the models are placed under ./training/ or ./models/ +def get_model_paths( paths=[Path("./training/"), Path("./models/")] ): + yamls = [] + + for path in paths: + if not path.exists(): + continue + + for yaml in path.glob("**/*.yaml"): + if "/logs/" in str(yaml): + continue + + yamls.append( yaml ) + + return yamls + +# +def load_model( yaml ): + gr.Info(f"Loading: {yaml}") + init_tts( yaml=Path(yaml), restart=True ) + gr.Info(f"Loaded model") + +def init_tts(yaml=None, restart=False): global tts if tts is not None: @@ -53,13 +82,13 @@ def init_tts(restart=False): del tts parser = argparse.ArgumentParser(allow_abbrev=False) - parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too + parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--amp", action="store_true") parser.add_argument("--dtype", type=str, default="auto") args, unknown = parser.parse_known_args() - tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp ) + tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp ) return tts @gradio_wrapper(inputs=layout["inference"]["inputs"].keys()) @@ -78,7 +107,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser.add_argument("--language", type=str, default="en") parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"]) parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second)) - parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"]) + parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"]) parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"]) parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"]) parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"]) @@ -99,7 +128,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): raise ValueError("No reference audio provided.") tts = init_tts() - with timer() as t: + + gr.Info("Inferencing...") + with timer("Inferenced in") as t: wav, sr = tts.inference( text=args.text, language=args.language, @@ -169,6 +200,7 @@ def get_random_prompt(): # setup args parser = argparse.ArgumentParser(allow_abbrev=False) +parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--listen", default=None, help="Path for Gradio to listen on") parser.add_argument("--share", action="store_true") parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ) @@ -208,7 +240,7 @@ with ui: with gr.Column(scale=7): with gr.Row(): layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.") - layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") + #layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.") with gr.Row(): layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)") @@ -249,6 +281,19 @@ with ui: ) """ + with gr.Tab("Settings"): + with gr.Row(): + with gr.Column(scale=7): + layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model") + with gr.Column(scale=1): + layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model") + + layout["settings"]["buttons"]["load"].click( + fn=load_model, + inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None], + outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None], + ) + if os.path.exists("README.md") and args.render_markdown: md = open("README.md", "r", encoding="utf-8").read() # remove HF's metadata