added light web UI (need to port the telemetry disabling bandaids from aivc)
parent 10c34c5b98
commit 5ac119a6e7
@@ -26,6 +26,12 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.

## Try Me

### Online

A HuggingFace space hosting the code and models can be found [here](https://huggingface.co/spaces/ecker/vall-e).

### Local

To quickly try it out, you can choose between the following modes:

* AR only: `python -m vall_e.models.ar yaml="./data/config.yaml"`
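For programmatic use, the same inference path this commit reworks can also be driven from Python. Below is a minimal, illustrative sketch based on the `TTS` signature and keyword names visible in the diffs further down; the config path, prompt text, and audio file paths are placeholders, and the sampling values mirror the defaults this commit ships in the CLI and web UI.

```python
# illustrative sketch only; paths and prompt text are placeholders
from vall_e.inference import TTS

tts = TTS( config="./data/config.yaml" )  # device/dtype/amp fall back to the YAML's cfg values
wav, sr = tts.inference(
    text="Hello world.",
    references=["./reference.wav"],   # placeholder reference clip
    out_path="./output.wav",          # placeholder output path
    max_ar_steps=450,
    ar_temp=0.95,
    nar_temp=0.25,
    top_p=1.0,
    top_k=0,
    repetition_penalty=1.0,
    repetition_penalty_decay=0.0,
    length_penalty=0.0,
)
```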
scripts/stitch_embs.py (new file, 28 lines)

@@ -0,0 +1,28 @@
import torch

# set to one of: "merge_resp_embs", "copy_resps_emb", "extend_resps_emb"
action = None

# copies the resp_embs from a given AR and NAR into an AR as a base to convert into an AR+NAR monolithic model
if action == "merge_resp_embs":
    src_ar = torch.load("./data/source-ar.pth", map_location="cpu")
    src_nar = torch.load("./data/source-nar.pth", map_location="cpu")
    # copies all weights from the AR since the AR is usually "better", might need to experiment more with using a NAR as the base
    dst = torch.load("./data/source-ar.pth", map_location="cpu")

    # copy resps_emb for layer 0 from the AR
    dst['module']['resps_emb.weight'][:1, :, :] = src_ar['module']['resps_emb.weight']
    # copy resps_emb for the remaining layers from the NAR
    dst['module']['resps_emb.weight'][1:, :-1, :] = src_nar['module']['resps_emb.weight']
# copies an existing AR+NAR monolithic model's resps_emb onto an AR
elif action == "copy_resps_emb":
    src = torch.load("./data/source.pth", map_location="cpu")
    dst = torch.load("./data/destination.pth", map_location="cpu")
    dst['module']['resps_emb.weight'] = src['module']['resps_emb.weight']
# extends a single-level resps_emb to four levels, randomly initializing the new ones
elif action == "extend_resps_emb":
    dst = torch.load("./data/destination.pth", map_location="cpu")
    # clone() so the expanded view owns its memory and can be written to in-place
    dst['module']['resps_emb.weight'] = dst['module']['resps_emb.weight'].expand(4, -1, -1).clone()
    dst['module']['resps_emb.weight'][1:] = torch.randn(3, 1025, 1024)
else:
    raise Exception(f"invalid action: {action}")

torch.save(dst, './data/fp32.pth')
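After running the `merge_resp_embs` branch above, a quick sanity check of the written checkpoint might look like the sketch below. The `./data/fp32.pth` path and the `['module']['resps_emb.weight']` key come from the script itself; the expected `(4, 1025, 1024)` shape is an assumption inferred from the `extend_resps_emb` branch.

```python
import torch

# illustrative sanity check for the checkpoint written by scripts/stitch_embs.py
ckpt = torch.load("./data/fp32.pth", map_location="cpu")
emb = ckpt["module"]["resps_emb.weight"]

print(emb.shape)  # expected to be (levels, tokens, dim), e.g. (4, 1025, 1024)
assert emb.dim() == 3, "resps_emb should keep its (levels, tokens, dim) layout"
```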
@@ -26,10 +26,12 @@ def main():
    parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
    parser.add_argument("--length-penalty", type=float, default=0.0)

    parser.add_argument("--device", default="cuda")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default=None)
    args = parser.parse_args()

    tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
    tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
    tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty )

if __name__ == "__main__":
@@ -64,6 +64,8 @@ class Engine():
        self.global_samples = 0
        self.tokens_processed = 0

        self._frozen_params = set()

    def freeze(self, freeze_all=True):
        # set to freeze
        if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
@@ -134,7 +136,7 @@ class Engine():
        if not load_path.exists():
            return

        state = torch.load(load_path)
        state = torch.load(load_path, map_location=torch.device(cfg.device))
        self.global_steps = state['global_step']
        self.micro_steps = state['micro_step']
        self.global_samples = state['global_samples']
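The `map_location` argument added here is what lets a checkpoint serialized on a GPU machine be restored on a CPU-only host (or on a different device than it was saved from). A standalone illustration, with a placeholder file name:

```python
import torch

# tensors saved from CUDA are remapped onto the requested device at load time;
# without map_location, loading on a machine with no GPU raises an error
state = torch.load("checkpoint.pth", map_location=torch.device("cpu"))
```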
@@ -145,10 +147,10 @@ class Engine():
        load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state

        if load_optimizer_states:
            self.optimizer.load_state_dict(state['optimizer'])
            self.optimizer.load_state_dict(state['optimizer'], map_location=torch.device(cfg.device))

        if load_lr_scheduler_states:
            self.lr_scheduler.load_state_dict(state['lr_scheduler'])
            self.lr_scheduler.load_state_dict(state['lr_scheduler'], map_location=torch.device(cfg.device))

    def eval(self):
        return self.module.eval()
@@ -16,9 +16,8 @@ from .train import load_engines
from .data import get_phone_symmap, _load_quants

class TTS():
    def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
    def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
        self.loading = True
        self.device = device

        self.input_sample_rate = 24000
        self.output_sample_rate = 24000
@@ -32,8 +31,23 @@ class TTS():
        except Exception as e:
            pass

        if amp is None:
            amp = cfg.inference.amp
        if dtype is None:
            dtype = cfg.inference.dtype
        if device is None:
            device = cfg.device

        cfg.mode = "inferencing"
        cfg.trainer.load_module_only = True
        cfg.device = device
        cfg.trainer.load_state_dict = True
        cfg.trainer.backend = "local"
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype

        self.device = device
        self.dtype = cfg.inference.dtype
        self.amp = amp

        self.symmap = None
        if ar_ckpt and nar_ckpt:
@@ -50,7 +64,7 @@ class TTS():
            if "module" in state:
                state = state['module']
            self.ar.load_state_dict(state)
            self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
            self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
            self.nar = self.ar
        elif name.startswith("ar"):
            self.ar = model
@@ -60,7 +74,7 @@ class TTS():
            if "module" in state:
                state = state['module']
            self.ar.load_state_dict(state)
            self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
            self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
        elif name.startswith("nar"):
            self.nar = model
            state = torch.load(self.nar_ckpt)
@@ -69,7 +83,7 @@ class TTS():
            if "module" in state:
                state = state['module']
            self.nar.load_state_dict(state)
            self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
            self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
        else:
            self.load_models()
@@ -85,12 +99,12 @@ class TTS():
        engines = load_engines()
        for name, engine in engines.items():
            if name[:6] == "ar+nar":
                self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
                self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
                self.nar = self.ar
            elif name[:2] == "ar":
                self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
                self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
            elif name[:3] == "nar":
                self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
                self.nar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

    def encode_text( self, text, language="en" ):
        # already a tensor, return it
@@ -135,12 +149,12 @@ class TTS():
        prom = to_device(prom, self.device).to(torch.int16)
        phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

        with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
        with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
            resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)
            resps_list = [r.unsqueeze(-1) for r in resps_list]
            resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty)

        wav, sr = qnt.decode_to_file(resps_list[0], out_path)
        wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)

        return (wav, sr)
@@ -101,7 +101,7 @@ def load_engines(invert=False):

        if cfg.trainer.load_state_dict or not model._cfg.training:
            load_path = cfg.ckpt_dir / name / "fp32.pth"
            state = torch.load(load_path)
            state = torch.load(load_path, map_location=torch.device(cfg.device))
            # exporting the model from the zero_to_fp32.py exports the actual module's dict
            # exporting with vall_e.export exports the state dict under .module
            if "module" in state:
@@ -17,21 +17,12 @@ from torch import Tensor, nn
from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload

try:
    from deepspeed.runtime.utils import empty_cache
except Exception as e:
    print(str(e))
    def empty_cache():
        ...

T = TypeVar("T")

def do_gc():
    gc.collect()
    torch.cuda.empty_cache()

    empty_cache()

def flatten_dict(d):
    records = pd.json_normalize(d).to_dict(orient="records")
    return records[0] if records else {}
vall_e/webui.py (new file, 137 lines)

@@ -0,0 +1,137 @@
import os
import re
import argparse
import tempfile
import functools

import gradio as gr

from pathlib import Path

from .inference import TTS

tts = None

layout = {}
layout["inference"] = {}
layout["inference"]["inputs"] = {
    "progress": None
}
layout["inference"]["outputs"] = {}
layout["inference"]["buttons"] = {}

# there's got to be a better way to go about this
def gradio_wrapper(inputs):
    def decorated(fun):
        @functools.wraps(fun)
        def wrapped_function(*args, **kwargs):
            for i, key in enumerate(inputs):
                kwargs[key] = args[i]
            return fun(**kwargs)
        return wrapped_function
    return decorated

def init_tts(restart=False):
    global tts

    if tts is not None:
        if not restart:
            return tts
        del tts

    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--ar-ckpt", type=Path, default=None)
    parser.add_argument("--nar-ckpt", type=Path, default=None)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default="float32")
    args, unknown = parser.parse_known_args()

    tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
    return tts

@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--text", type=str, default=kwargs["text"])
    parser.add_argument("--references", type=str, default=kwargs["reference"])
    parser.add_argument("--max-ar-steps", type=int, default=kwargs["steps"])
    parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
    parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
    parser.add_argument("--top-p", type=float, default=1.0)
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
    parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
    parser.add_argument("--length-penalty", type=float, default=0.0)
    args, unknown = parser.parse_known_args()

    tmp = tempfile.NamedTemporaryFile(suffix='.wav')

    tts = init_tts()
    wav, sr = tts.inference(
        text=args.text,
        references=[args.references.split(";")],
        out_path=tmp.name,
        max_ar_steps=args.max_ar_steps,
        ar_temp=args.ar_temp,
        nar_temp=args.nar_temp,
        top_p=args.top_p,
        top_k=args.top_k,
        repetition_penalty=args.repetition_penalty,
        repetition_penalty_decay=args.repetition_penalty_decay,
        length_penalty=args.length_penalty
    )

    wav = wav.squeeze(0).cpu().numpy()
    return (sr, wav)

ui = gr.Blocks()
with ui:
    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column():
                layout["inference"]["inputs"]["text"] = gr.Textbox(lines=4, value="Your prompt here", label="Input Prompt")
        with gr.Row():
            with gr.Column():
                layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", source="upload", type="filepath")
            with gr.Column():
                layout["inference"]["inputs"]["steps"] = gr.Slider(value=450, minimum=2, maximum=1024, step=1, label="Steps")
                layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)")
                layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (NAR)")
            with gr.Column():
                layout["inference"]["buttons"]["start"] = gr.Button(value="Inference")
                # layout["inference"]["stop"] = gr.Button(value="Stop")
                layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")

    layout["inference"]["buttons"]["start"].click(
        fn=do_inference,
        inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
        outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
    )

parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
args, unknown = parser.parse_known_args()

args.listen_host = None
args.listen_port = None
args.listen_path = None
if args.listen:
    try:
        match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]

        args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
        args.listen_port = match[1] if match[1] != "" else None
        args.listen_path = match[2] if match[2] != "" else "/"
    except Exception as e:
        pass

    if args.listen_port is not None:
        args.listen_port = int(args.listen_port)
        if args.listen_port == 0:
            args.listen_port = None

ui.queue(max_size=8)
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port)
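Since the launch code sits at module level, the UI above would presumably be started as a module, e.g. `python -m vall_e.webui --listen 127.0.0.1:7860` (the exact invocation is an assumption; the `host:port/path` string is split by the regex above). `--share` requests a public Gradio link, and the `VALLE_YAML` environment variable, or a `--yaml` flag picked up by `init_tts`, points the UI at the model config, which is what lets it run inside a HuggingFace Space.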