added light web UI (need to port the telemetry disabling bandaids from aivc)

mrq 2023-09-09 16:17:20 -05:00
parent 10c34c5b98
commit 5ac119a6e7
8 changed files with 208 additions and 28 deletions

View File

@ -26,6 +26,12 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.
## Try Me
### Online
A HuggingFace space hosting the code and models can be found [here](https://huggingface.co/spaces/ecker/vall-e).
### Local
To quickly try it out, you can choose between the following modes:
* AR only: `python -m vall_e.models.ar yaml="./data/config.yaml"`
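Presumably, the web UI added by this commit can be launched locally the same way, e.g. `python -m vall_e.webui` (see `vall_e/webui.py` below for its `--listen` and `--share` flags).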

scripts/stitch_embs.py Normal file (+28 lines)
View File

@ -0,0 +1,28 @@
import torch

action = None

# copies the resps_emb from a given AR and NAR into an AR as a base to convert into an AR+NAR monolithic model
if action == "merge_resp_embs":
    src_ar = torch.load("./data/source-ar.pth", map_location="cpu")
    src_nar = torch.load("./data/source-nar.pth", map_location="cpu")
    # copies all weights from the AR since the AR is usually "better"; might need to experiment more with using a NAR as the base
    dst = torch.load("./data/source-ar.pth", map_location="cpu")

    # copy resps_emb for quant level 0 from the AR
    # ([:1] rather than the original [:0], which is an empty slice and copies nothing;
    # assumes dst's resps_emb was already extended to multiple levels, as in "extend_resps_emb" below)
    dst['module']['resps_emb.weight'][:1, :, :] = src_ar['module']['resps_emb.weight']
    # copy resps_emb for the remaining levels from the NAR
    # (the NAR's vocabulary presumably lacks the AR's stop token, hence the :-1)
    dst['module']['resps_emb.weight'][1:, :-1, :] = src_nar['module']['resps_emb.weight']
# copies an existing AR+NAR monolithic model's resps_emb onto an AR
elif action == "copy_resps_emb":
    src = torch.load("./data/source.pth", map_location="cpu")
    dst = torch.load("./data/destination.pth", map_location="cpu")

    dst['module']['resps_emb.weight'] = src['module']['resps_emb.weight']
# extends an AR's single-level resps_emb to four levels, randomly initializing the new ones
elif action == "extend_resps_emb":
    dst = torch.load("./data/destination.pth", map_location="cpu")

    # expand() only returns a stride-0 view, so clone() it before the in-place write below
    dst['module']['resps_emb.weight'] = dst['module']['resps_emb.weight'].expand(4, -1, -1).clone()
    dst['module']['resps_emb.weight'][1:] = torch.randn(3, 1025, 1024)
else:
    raise Exception(f"invalid action: {action}")

torch.save(dst, './data/fp32.pth')
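Worth noting for the `extend_resps_emb` branch above: `expand()` only produces a stride-0 view, so it must be materialized before any in-place write. A minimal self-contained sketch, using the script's hardcoded dimensions:

```python
import torch

emb = torch.randn(1, 1025, 1024)       # a single-level resps_emb, as in an AR checkpoint
view = emb.expand(4, -1, -1)           # a view: all four levels share one storage, no copy
copy = view.clone()                    # materialize so each level gets its own storage
copy[1:] = torch.randn(3, 1025, 1024)  # safe now; writing into the raw view would raise
print(copy.shape)                      # torch.Size([4, 1025, 1024])
```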

View File

@ -26,10 +26,12 @@ def main():
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--device", default="cuda")
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=None)
args = parser.parse_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )
if __name__ == "__main__":
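For reference, the same wiring can presumably be driven from Python directly; a minimal sketch with hypothetical paths and values, using only keyword arguments that appear in the calls above:

```python
from vall_e.inference import TTS

# hypothetical paths/values; mirrors the CLI flags above
tts = TTS( config="./data/config.yaml", device="cuda", dtype="float16", amp=False )
wav, sr = tts.inference(
    text="Hello world.",
    references=["./reference.wav"],  # assumed reference format
    out_path="./output.wav",
    max_ar_steps=450,
    ar_temp=0.95,
    nar_temp=0.25,
    top_p=1.0,
    top_k=0,
    repetition_penalty=1.0,
    repetition_penalty_decay=0.0,
    length_penalty=0.0,
)
```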

View File

@ -64,6 +64,8 @@ class Engine():
         self.global_samples = 0
         self.tokens_processed = 0

         self._frozen_params = set()

     def freeze(self, freeze_all=True):
         # set to freeze
         if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
@ -134,7 +136,7 @@ class Engine():
         if not load_path.exists():
             return

-        state = torch.load(load_path)
+        state = torch.load(load_path, map_location=torch.device(cfg.device))

         self.global_steps = state['global_step']
         self.micro_steps = state['micro_step']
         self.global_samples = state['global_samples']
@ -145,10 +147,10 @@ class Engine():
         load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state

         if load_optimizer_states:
             # the state dict was already mapped onto cfg.device by the torch.load above;
             # Optimizer.load_state_dict() itself takes no map_location argument
             self.optimizer.load_state_dict(state['optimizer'])

         if load_lr_scheduler_states:
             self.lr_scheduler.load_state_dict(state['lr_scheduler'])

     def eval(self):
         return self.module.eval()
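For reference, a minimal sketch of the `map_location` behavior these loaders now rely on (toy tensor, placeholder path): every tensor in the checkpoint is remapped onto the requested device during deserialization, so no per-tensor `.to()` calls are needed afterwards.

```python
import torch

torch.save({"w": torch.randn(2, 2)}, "/tmp/ckpt.pth")  # placeholder checkpoint
state = torch.load("/tmp/ckpt.pth", map_location=torch.device("cpu"))
print(state["w"].device)  # cpu, regardless of the device it was saved from
```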

View File

@ -16,10 +16,9 @@ from .train import load_engines
from .data import get_phone_symmap, _load_quants
 class TTS():
-    def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
+    def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
         self.loading = True
-        self.device = device
         self.input_sample_rate = 24000
         self.output_sample_rate = 24000
@ -32,9 +31,24 @@ class TTS():
         except Exception as e:
             pass

+        if amp is None:
+            amp = cfg.inference.amp
+        if dtype is None:
+            dtype = cfg.inference.dtype
+        if device is None:
+            device = cfg.device
+
+        cfg.mode = "inferencing"
+        cfg.trainer.load_module_only = True
+        cfg.device = device
+        cfg.trainer.load_state_dict = True
+        cfg.trainer.backend = "local"
+        cfg.trainer.weight_dtype = dtype
+        cfg.inference.weight_dtype = dtype
+
+        self.device = device
+        self.dtype = cfg.inference.dtype
+        self.amp = amp

         self.symmap = None
         if ar_ckpt and nar_ckpt:
             self.ar_ckpt = ar_ckpt
@ -50,7 +64,7 @@ class TTS():
if "module" in state:
state = state['module']
self.ar.load_state_dict(state)
self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.nar = self.ar
elif name.startswith("ar"):
self.ar = model
@ -60,7 +74,7 @@ class TTS():
if "module" in state:
state = state['module']
self.ar.load_state_dict(state)
self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
elif name.startswith("nar"):
self.nar = model
state = torch.load(self.nar_ckpt)
@ -69,7 +83,7 @@ class TTS():
if "module" in state:
state = state['module']
self.nar.load_state_dict(state)
self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
else:
self.load_models()
@ -85,12 +99,12 @@ class TTS():
         engines = load_engines()
         for name, engine in engines.items():
             if name[:6] == "ar+nar":
-                self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+                self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
                 self.nar = self.ar
             elif name[:2] == "ar":
-                self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+                self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
             elif name[:3] == "nar":
-                self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+                self.nar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

     def encode_text( self, text, language="en" ):
         # already a tensor, return it
@ -135,12 +149,12 @@ class TTS():
         prom = to_device(prom, self.device).to(torch.int16)
         phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

-        with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+        with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
             resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
             resps_list = [r.unsqueeze(-1) for r in resps_list]
             resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)

-        wav, sr = qnt.decode_to_file(resps_list[0], out_path)
+        wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
         return (wav, sr)
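A minimal sketch of the dtype/AMP split used above, assuming a CUDA device: with AMP the weights stay fp32 and autocast downcasts per-op, while without AMP the module itself is cast to the target dtype (hence the `self.dtype if not self.amp else torch.float32` pattern):

```python
import torch

model = torch.nn.Linear(16, 16).to("cuda", dtype=torch.float32)  # toy stand-in for the AR/NAR
x = torch.randn(1, 16, device="cuda")

# AMP path: fp32 weights, per-op fp16 compute managed by autocast
with torch.autocast("cuda", dtype=torch.float16, enabled=True):
    y = model(x)
print(y.dtype)  # torch.float16 inside the autocast region
```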

View File

@ -101,7 +101,7 @@ def load_engines(invert=False):
         if cfg.trainer.load_state_dict or not model._cfg.training:
             load_path = cfg.ckpt_dir / name / "fp32.pth"
-            state = torch.load(load_path)
+            state = torch.load(load_path, map_location=torch.device(cfg.device))
             # exporting the model with zero_to_fp32.py exports the actual module's dict
             # exporting with vall_e.export exports the state dict under .module
             if "module" in state:

View File

@ -17,21 +17,12 @@ from torch import Tensor, nn
from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload

try:
    from deepspeed.runtime.utils import empty_cache
except Exception as e:
    print(str(e))
    # fall back to a no-op so environments without deepspeed (e.g. the web UI / a HuggingFace Space) still work
    def empty_cache():
        ...

T = TypeVar("T")

def do_gc():
    gc.collect()
    torch.cuda.empty_cache()
    empty_cache()

def flatten_dict(d):
    records = pd.json_normalize(d).to_dict(orient="records")
    return records[0] if records else {}

vall_e/webui.py Normal file (+137 lines)
View File

@ -0,0 +1,137 @@
import os
import re
import argparse
import tempfile
import functools

import gradio as gr

from pathlib import Path

from .inference import TTS

tts = None

layout = {}
layout["inference"] = {}
layout["inference"]["inputs"] = {
    "progress": None
}
layout["inference"]["outputs"] = {}
layout["inference"]["buttons"] = {}

# there's got to be a better way to go about this
def gradio_wrapper(inputs):
    def decorated(fun):
        @functools.wraps(fun)
        def wrapped_function(*args, **kwargs):
            # remap the positional args Gradio passes in to the layout's input names
            for i, key in enumerate(inputs):
                kwargs[key] = args[i]
            return fun(**kwargs)
        return wrapped_function
    return decorated

def init_tts(restart=False):
    global tts
    if tts is not None:
        if not restart:
            return tts
        del tts

    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--ar-ckpt", type=Path, default=None)
    parser.add_argument("--nar-ckpt", type=Path, default=None)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default="float32")
    args, unknown = parser.parse_known_args()

    tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
    return tts

@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--text", type=str, default=kwargs["text"])
    parser.add_argument("--references", type=str, default=kwargs["reference"])
    parser.add_argument("--max-ar-steps", type=int, default=kwargs["steps"])
    parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
    parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
    parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
    parser.add_argument("--length-penalty", type=float, default=0.0)
    args, unknown = parser.parse_known_args()

    tmp = tempfile.NamedTemporaryFile(suffix='.wav')

    tts = init_tts()
    wav, sr = tts.inference(
        text=args.text,
        references=[args.references.split(";")],
        out_path=tmp.name,
        max_ar_steps=args.max_ar_steps,
        ar_temp=args.ar_temp,
        nar_temp=args.nar_temp,
        top_p=args.top_p,
        top_k=args.top_k,
        repetition_penalty=args.repetition_penalty,
        repetition_penalty_decay=args.repetition_penalty_decay,
        length_penalty=args.length_penalty
    )

    wav = wav.squeeze(0).cpu().numpy()
    return (sr, wav)

ui = gr.Blocks()
with ui:
    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column():
                layout["inference"]["inputs"]["text"] = gr.Textbox(lines=4, value="Your prompt here", label="Input Prompt")
        with gr.Row():
            with gr.Column():
                layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath")
            with gr.Column():
                layout["inference"]["inputs"]["steps"] = gr.Slider(value=450, minimum=2, maximum=1024, step=1, label="Steps")
                layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)")
                layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)")
            with gr.Column():
                layout["inference"]["buttons"]["start"] = gr.Button(value="Inference")
                # layout["inference"]["stop"] = gr.Button(value="Stop")
                layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")

        layout["inference"]["buttons"]["start"].click(
            fn=do_inference,
            inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
            outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
        )

parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--listen", default=None, help="Host(:port)(/path) for Gradio to listen on")
parser.add_argument("--share", action="store_true")
args, unknown = parser.parse_known_args()

args.listen_host = None
args.listen_port = None
args.listen_path = None

if args.listen:
    try:
        match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]

        args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
        args.listen_port = match[1] if match[1] != "" else None
        args.listen_path = match[2] if match[2] != "" else "/"
    except Exception as e:
        pass

if args.listen_port is not None:
    args.listen_port = int(args.listen_port)
    if args.listen_port == 0:
        args.listen_port = None

ui.queue(max_size=8)
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port)
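Two self-contained sketches of the glue above. First, `gradio_wrapper`: Gradio passes component values positionally, and the decorator re-keys them by the layout's input names (the decorator is restated verbatim so the snippet runs standalone; `fake_inference` is a hypothetical callback):

```python
import functools

def gradio_wrapper(inputs):  # restated from webui.py above
    def decorated(fun):
        @functools.wraps(fun)
        def wrapped_function(*args, **kwargs):
            for i, key in enumerate(inputs):
                kwargs[key] = args[i]
            return fun(**kwargs)
        return wrapped_function
    return decorated

@gradio_wrapper(inputs=["text", "steps"])
def fake_inference(*args, **kwargs):
    return f"{kwargs['text']} @ {kwargs['steps']} steps"

print(fake_inference("hello world", 450))  # hello world @ 450 steps
```

Second, the `--listen` parsing, which splits an optional `host:port` from an optional `/path` and falls back to `127.0.0.1` and `/`:

```python
import re

for listen in ("0.0.0.0:8080", "0.0.0.0:8080/ui", "/ui"):
    host, port, path = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", listen)[0]
    print(host or "127.0.0.1", port or None, path or "/")
# 0.0.0.0 8080 /
# 0.0.0.0 8080 /ui
# 127.0.0.1 None /ui
```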