From 5ac119a6e7b1b379aff581f90a632f7204ffc2ca Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 9 Sep 2023 16:17:20 -0500
Subject: [PATCH] added light web UI (need to port the telemetry disabling
 bandaids from aivc)

---
 README.md               |   6 ++
 scripts/stitch_embs.py  |  28 ++++++++
 vall_e/__main__.py      |   6 +-
 vall_e/engines/base.py  |   8 ++-
 vall_e/inference.py     |  40 ++++++++----
 vall_e/utils/trainer.py |   2 +-
 vall_e/utils/utils.py   |   9 ---
 vall_e/webui.py         | 137 ++++++++++++++++++++++++++++++++++++++++
 8 files changed, 208 insertions(+), 28 deletions(-)
 create mode 100644 scripts/stitch_embs.py
 create mode 100644 vall_e/webui.py

diff --git a/README.md b/README.md
index 830a95f..e029b91 100755
--- a/README.md
+++ b/README.md
@@ -26,6 +26,12 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
 
 ## Try Me
 
+### Online
+
+A HuggingFace space hosting the code and models can be found [here](https://huggingface.co/spaces/ecker/vall-e).
+
+### Local
+
 To quickly try it out, you can choose between the following modes:
 
 * AR only: `python -m vall_e.models.ar yaml="./data/config.yaml"`
diff --git a/scripts/stitch_embs.py b/scripts/stitch_embs.py
new file mode 100644
index 0000000..d63ee37
--- /dev/null
+++ b/scripts/stitch_embs.py
@@ -0,0 +1,28 @@
+import torch
+
+action = None
+# copies the resp_embs from a given AR and NAR into an AR as a base to convert into an AR+NAR monolithic model
+if action == "merge_resp_embs":
+	src_ar = torch.load("./data/source-ar.pth", map_location="cpu")
+	src_nar = torch.load("./data/source-nar.pth", map_location="cpu")
+	# copies all weights from the AR since the AR is usually "better", might need to experiment more with using a NAR as the base
+	dst = torch.load("./data/source-ar.pth", map_location="cpu")
+
+	# copy resps_emb to layer 0 from AR
+	dst['module']['resps_emb.weight'][:0, :, :] = src_ar['module']['resps_emb.weight']
+	# copy resps_emb to remaining layers from NAR
+	dst['module']['resps_emb.weight'][1:, :-1, :] = src_nar['module']['resps_emb.weight']
+# copies an existing AR+NAR monolithic model's resp_emb onto an AR
+elif action == "copy_resps_emb":
+	src = torch.load("./data/source.pth", map_location="cpu")
+	dst = torch.load("./data/destination.pth", map_location="cpu")
+	dst['module']['resps_emb.weight'] = src['module']['resps_emb.weight']
+elif action == "extend_resps_emb":
+	dst = torch.load("./data/destination.pth", map_location="cpu")
+	dst['module']['resps_emb.weight'] = dst['module']['resps_emb.weight'].expand(4, -1, -1)
+	dst['module']['resps_emb.weight'][1:] = torch.randn(3, 1025, 1024)
+
+else:
+	raise Exception(f"invalid action: {action}")
+
+torch.save(dst, './data/fp32.pth')
\ No newline at end of file
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 834afe7..2709ace 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -26,10 +26,12 @@ def main():
 	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
 	parser.add_argument("--length-penalty", type=float, default=0.0)
 
-	parser.add_argument("--device", default="cuda")
+	parser.add_argument("--device", type=str, default=None)
+	parser.add_argument("--amp", action="store_true")
+	parser.add_argument("--dtype", type=str, default=None)
 	args = parser.parse_args()
 
-	tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
+	tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
 	tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )
 
 if __name__ == "__main__":
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 8eb883d..b630fb1 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -64,6 +64,8 @@ class Engine():
 		self.global_samples = 0
 		self.tokens_processed = 0
 
+		self._frozen_params = set()
+
 	def freeze(self, freeze_all=True):
 		# set to freeze
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
@@ -134,7 +136,7 @@ class Engine():
 		if not load_path.exists():
 			return
 
-		state = torch.load(load_path)
+		state = torch.load(load_path, map_location=torch.device(cfg.device))
 		self.global_steps = state['global_step']
 		self.micro_steps = state['micro_step']
 		self.global_samples = state['global_samples']
@@ -145,10 +147,10 @@ class Engine():
 		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
 
 		if load_optimizer_states:
-			self.optimizer.load_state_dict(state['optimizer'])
+			self.optimizer.load_state_dict(state['optimizer'], map_location=torch.device(cfg.device))
 
 		if load_lr_scheduler_states:
-			self.lr_scheduler.load_state_dict(state['lr_scheduler'])
+			self.lr_scheduler.load_state_dict(state['lr_scheduler'], map_location=torch.device(cfg.device))
 
 	def eval(self):
 		return self.module.eval()
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 707c889..293e705 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -16,10 +16,9 @@ from .train import load_engines
 from .data import get_phone_symmap, _load_quants
 
 class TTS():
-	def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
+	def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
 		self.loading = True
-		self.device = device
-
+
 		self.input_sample_rate = 24000
 		self.output_sample_rate = 24000
 
@@ -32,9 +31,24 @@ class TTS():
 			except Exception as e:
 				pass
 
+		if amp is None:
+			amp = cfg.inference.amp
+		if dtype is None:
+			dtype = cfg.inference.dtype
+		if device is None:
+			device = cfg.device
+
 		cfg.mode = "inferencing"
-		cfg.trainer.load_module_only = True
-
+		cfg.device = device
+		cfg.trainer.load_state_dict = True
+		cfg.trainer.backend = "local"
+		cfg.trainer.weight_dtype = dtype
+		cfg.inference.weight_dtype = dtype
+
+		self.device = device
+		self.dtype = cfg.inference.dtype
+		self.amp = amp
+
 		self.symmap = None
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
@@ -50,7 +64,7 @@ class TTS():
 				if "module" in state:
 					state = state['module']
 				self.ar.load_state_dict(state)
-				self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+				self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 				self.nar = self.ar
 			elif name.startswith("ar"):
 				self.ar = model
@@ -60,7 +74,7 @@ class TTS():
 				if "module" in state:
 					state = state['module']
 				self.ar.load_state_dict(state)
-				self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+				self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 			elif name.startswith("nar"):
 				self.nar = model
 				state = torch.load(self.nar_ckpt)
@@ -69,7 +83,7 @@ class TTS():
 				if "module" in state:
 					state = state['module']
 				self.nar.load_state_dict(state)
-				self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+				self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 		else:
 			self.load_models()
 
@@ -85,12 +99,12 @@ class TTS():
 		engines = load_engines()
 		for name, engine in engines.items():
 			if name[:6] == "ar+nar":
-				self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+				self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 				self.nar = self.ar
 			elif name[:2] == "ar":
-				self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+				self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 			elif name[:3] == "nar":
-				self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+				self.nar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 
 	def encode_text( self, text, language="en" ):
 		# already a tensor, return it
@@ -135,12 +149,12 @@ class TTS():
 		prom = to_device(prom, self.device).to(torch.int16)
 		phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
 
-		with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+		with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
 			resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
 			resps_list = [r.unsqueeze(-1) for r in resps_list]
 			resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
 
-		wav, sr = qnt.decode_to_file(resps_list[0], out_path)
+		wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
 		return (wav, sr)
 
 
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index d87df60..e16c004 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -101,7 +101,7 @@ def load_engines(invert=False):
 
 		if cfg.trainer.load_state_dict or not model._cfg.training:
 			load_path = cfg.ckpt_dir / name / "fp32.pth"
-			state = torch.load(load_path)
+			state = torch.load(load_path, map_location=torch.device(cfg.device))
 			# exporting the model from the zero_to_fp32.py exports the actual module's dict
 			# exporting with vall_e.export exports the state dict under .module
 			if "module" in state:
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index 86214d4..988f595 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -17,21 +17,12 @@ from torch import Tensor, nn
 from tqdm.auto import tqdm
 from typing import Callable, TypeVar, overload
 
-try:
-	from deepspeed.runtime.utils import empty_cache
-except Exception as e:
-	print(str(e))
-	def empty_cache():
-		...
-
 T = TypeVar("T")
 
 def do_gc():
 	gc.collect()
 	torch.cuda.empty_cache()
 
-	empty_cache()
-
 def flatten_dict(d):
 	records = pd.json_normalize(d).to_dict(orient="records")
 	return records[0] if records else {}
diff --git a/vall_e/webui.py b/vall_e/webui.py
new file mode 100644
index 0000000..b0bb4bd
--- /dev/null
+++ b/vall_e/webui.py
@@ -0,0 +1,137 @@
+import os
+import re
+import argparse
+import tempfile
+import functools
+
+import gradio as gr
+
+from pathlib import Path
+
+from .inference import TTS
+
+tts = None
+
+layout = {}
+layout["inference"] = {}
+layout["inference"]["inputs"] = {
+	"progress": None
+}
+layout["inference"]["outputs"] = {}
+layout["inference"]["buttons"] = {}
+
+# there's got to be a better way to go about this
+def gradio_wrapper(inputs):
+	def decorated(fun):
+		@functools.wraps(fun)
+		def wrapped_function(*args, **kwargs):
+			for i, key in enumerate(inputs):
+				kwargs[key] = args[i]
+			return fun(**kwargs)
+		return wrapped_function
+	return decorated
+
+def init_tts(restart=False):
+	global tts
+
+	if tts is not None:
+		if not restart:
+			return tts
+		del tts
+
+	parser = argparse.ArgumentParser(allow_abbrev=False)
+	parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
+	parser.add_argument("--ar-ckpt", type=Path, default=None)
+	parser.add_argument("--nar-ckpt", type=Path, default=None)
+	parser.add_argument("--device", type=str, default="cpu")
+	parser.add_argument("--amp", action="store_true")
+	parser.add_argument("--dtype", type=str, default="float32")
+	args, unknown = parser.parse_known_args()
+
+	tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
+	return tts
+
+@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
+def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
+	parser = argparse.ArgumentParser(allow_abbrev=False)
+	parser.add_argument("--text", type=str, default=kwargs["text"])
+	parser.add_argument("--references", type=str, default=kwargs["reference"])
+	parser.add_argument("--max-ar-steps", type=int, default=kwargs["steps"])
+	parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
+	parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
+	parser.add_argument("--top-p", type=float, default=1.0)
+	parser.add_argument("--top-k", type=int, default=0)
+	parser.add_argument("--repetition-penalty", type=float, default=1.0)
+	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
+	parser.add_argument("--length-penalty", type=float, default=0.0)
+	args, unknown = parser.parse_known_args()
+
+	tmp = tempfile.NamedTemporaryFile(suffix='.wav')
+
+	tts = init_tts()
+	wav, sr = tts.inference(
+		text=args.text,
+		references=[args.references.split(";")],
+		out_path=tmp.name,
+		max_ar_steps=args.max_ar_steps,
+		ar_temp=args.ar_temp,
+		nar_temp=args.nar_temp,
+		top_p=args.top_p,
+		top_k=args.top_k,
+		repetition_penalty=args.repetition_penalty,
+		repetition_penalty_decay=args.repetition_penalty_decay,
+		length_penalty=args.length_penalty
+	)
+
+	wav = wav.squeeze(0).cpu().numpy()
+	return (sr, wav)
+
+ui = gr.Blocks()
+with ui:
+	with gr.Tab("Inference"):
+		with gr.Row():
+			with gr.Column():
+				layout["inference"]["inputs"]["text"] = gr.Textbox(lines=4, value="Your prompt here", label="Input Prompt")
+		with gr.Row():
+			with gr.Column():
+				layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath")
+			with gr.Column():
+				layout["inference"]["inputs"]["steps"] = gr.Slider(value=450, minimum=2, maximum=1024, step=1, label="Steps")
+				layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)")
+				layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)")
+			with gr.Column():
+				layout["inference"]["buttons"]["start"] = gr.Button(value="Inference")
+				# layout["inference"]["stop"] = gr.Button(value="Stop")
+				layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
+
+	layout["inference"]["buttons"]["start"].click(
+		fn=do_inference,
+		inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
+		outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
+	)
+
+parser = argparse.ArgumentParser(allow_abbrev=False)
+parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
+parser.add_argument("--share", action="store_true")
+args, unknown = parser.parse_known_args()
+
+args.listen_host = None
+args.listen_port = None
+args.listen_path = None
+if args.listen:
+	try:
+		match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
+
+		args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
+		args.listen_port = match[1] if match[1] != "" else None
+		args.listen_path = match[2] if match[2] != "" else "/"
+	except Exception as e:
+		pass
+
+if args.listen_port is not None:
+	args.listen_port = int(args.listen_port)
+	if args.listen_port == 0:
+		args.listen_port = None
+
+ui.queue(max_size=8)
+ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port)
\ No newline at end of file