allow loading a different model within the web UI (apparently I did not have the web UI in the documentation)

mrq 2024-07-15 19:59:48 -05:00
parent 7b210d9738
commit 3acc54df22
9 changed files with 129 additions and 67 deletions

View File

@ -20,7 +20,6 @@ Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install g
I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
## Pre-Trained Model
> [!NOTE]
@ -196,11 +195,34 @@ And some experimental sampling flags you can use too (your mileage will ***defin
+ **Note**: This is incompatible with beam search sampling (for the time being, at least).
* `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise.
### Web UI
A Gradio-based web UI is accessible by running `python3 -m vall_e.webui`. You can, optionally, pass:
* `--yaml=./path/to/your/config.yaml`: will load the targeted YAML
* `--listen 0.0.0.0:7860`: will set the web UI to listen on all IPs at port 7860. Replace the IP and port to your preference.
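These flags are consumed by a small `argparse` block in `vall_e.webui`; the following is a condensed sketch of the relevant parser calls as they appear in this commit's diff (trimmed for brevity, not the full launcher):

```python
# Condensed sketch of the flag handling in vall_e.webui, reproducing only the
# flags documented above (the real launcher defines several more).
import argparse
import os
from pathlib import Path

parser = argparse.ArgumentParser(allow_abbrev=False)
# --yaml can also be supplied through the VALLE_YAML environment variable (e.g. on a HuggingFace Space)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None))
# --listen takes an address such as 0.0.0.0:7860 for Gradio to bind to
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
args, unknown = parser.parse_known_args()
```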
#### Inference
Synthesizing speech is simple:
* `Input Prompt`: The guiding text prompt. Each new line will be its own generated audio, stitched together at the end.
* `Audio Input`: The reference audio for the synthesis. Under Gradio, you can trim your clip accordingly, but leaving it as-is works fine.
* `Output`: The resultant audio.
* `Inference`: Button to start generating the audio.
All the additional knobs have descriptions that map to the CLI flags above.
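For reference, the web UI funnels each knob value back into the matching CLI flag name: `do_inference` (shown later in this commit) rebuilds an `argparse` namespace from the Gradio inputs. A hedged sketch of that mapping follows; `knobs_to_args` is a hypothetical stand-in for the wrapper, and only a few knobs are reproduced:

```python
# Hedged sketch of how the web UI knobs are funneled back into the CLI flag names,
# mirroring the do_inference() wrapper later in this diff (only a few knobs shown).
import argparse

def knobs_to_args(kwargs: dict) -> argparse.Namespace:
    parser = argparse.ArgumentParser(allow_abbrev=False)
    # each Gradio knob value becomes the default of the matching CLI flag
    parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
    parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
    parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
    # nothing is parsed from the command line; the defaults carry the knob values through
    return parser.parse_args([])

# example input: {"ar-temp": 0.95, "nar-temp": 0.5, "input-prompt-length": 3.0}
```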
#### Settings
So far, this only allows you to load a different model without needing to restart the web UI. The previous model should seamlessly unload, and the new one will load in its place.
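Mechanically, the model picker just scans for config YAMLs and re-initializes the TTS object against the selected one. A minimal, self-contained sketch of the discovery step, mirroring the `get_model_paths` helper added in this commit:

```python
# Self-contained sketch of how the Settings tab discovers loadable models: it globs for
# config YAMLs under ./training/ and ./models/ (skipping training logs), and the chosen
# path is then handed back to the web UI's loader to rebuild the TTS object in place.
from pathlib import Path

def get_model_paths(paths=(Path("./training/"), Path("./models/"))):
    yamls = []
    for path in paths:
        if not path.exists():
            continue
        for yaml in path.glob("**/*.yaml"):
            if "/logs/" in str(yaml):  # ignore YAMLs written out under training log directories
                continue
            yamls.append(yaml)
    return yamls

if __name__ == "__main__":
    # these paths populate the "Model" dropdown in the Settings tab
    print(get_model_paths())
```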
## To-Do
* [x] train and release a serviceable model for finetuning against.
* [ ] train and release a ***good*** zero-shot model.
- this should, hopefully, just require another epoch or two for `ar+nar-llama-8`, as the foundation seems rather robust now.
* [ ] well-integrated training through the Web UI (without the kludge from ai-voice-cloning)
* [x] ~~explore alternative setups, like a NAR-only model~~
- the current experiment of an AR length-predictor + NAR for the rest seems to fall apart...
* [x] ~~explore better sampling techniques~~

View File

@ -210,11 +210,10 @@ class ModelExperimentalSettings:
# I really need to clean this up
@dataclass()
class Model:
name: str = "" # vanity name for the model
name: str = "ar+nar" # vanity name for the model
version: int = 5 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding, 3+ = additional embeddings
size: str | dict = "full" # preset string or explicitly defined dimensionality
resp_levels: int = 1 # RVQ-bin levels this model targets for outputs
prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
resp_levels: int = 8 # RVQ-bin levels this model supports
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") (unused)
langs: int = 1 # defined languages (semi-unused)
tones: int = 1 # defined tones (unused)
@ -238,7 +237,10 @@ class Model:
@property
def max_levels(self):
return max(self.prom_levels, self.resp_levels)
# return RVQ level range
if self.experimental is not None and self.experimental.rvq_level_range:
return self.experimental.rvq_level_range[-1]
return self.resp_levels
@property
# required for fp8 as the lengths need to be divisible by 8
@ -626,7 +628,7 @@ class Inference:
use_encodec: bool = True
use_dac: bool = True
@cached_property
@property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
@ -651,7 +653,7 @@ class Optimizations:
optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
bitsandbytes: bool = False # use bitsandbytes
dadaptation: bool = True # use dadaptation optimizer
dadaptation: bool = False # use dadaptation optimizer
bitnet: bool = False # use bitnet
fp8: bool = False # use fp8
@ -671,7 +673,8 @@ class Config(BaseConfig):
bitsandbytes: dict | list | None = None # deprecated
optimizations: Optimizations = field(default_factory=lambda: Optimizations)
tokenizer: str = "./tokenizer.json"
tokenizer: str | None = None
tokenizer_path: str = "./tokenizer.json"
sample_rate: int = 24_000
variable_sample_rate: bool = False # NOT recommended, as running 24kHz audio directly through the 44kHz DAC model will cause a detrimental quality loss
@ -760,14 +763,21 @@ class Config(BaseConfig):
self.dataset.validation = [ Path(dir) for dir in self.dataset.validation ]
self.dataset.noise = [ Path(dir) for dir in self.dataset.noise ]
# do cleanup
for model in self.models:
if not isinstance( model, dict ):
continue
if "prom_levels" in model:
del model["prom_levels"]
if "interleave" in model:
del model["interleave"]
if "audio_embedding_sums" not in model:
continue
if not model["experimental"]:
if "experimental" not in model or not model["experimental"]:
model["experimental"] = {}
model["experimental"]["audio_embedding_sums"] = model.pop("audio_embedding_sums")
@ -837,9 +847,9 @@ class Config(BaseConfig):
try:
from transformers import PreTrainedTokenizerFast
tokenizer_path = cfg.rel_path / cfg.tokenizer
tokenizer_path = cfg.rel_path / cfg.tokenizer_path
if not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer
tokenizer_path = Path("./data/") / cfg.tokenizer_path
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
except Exception as e:
cfg.tokenizer = NaiveTokenizer()
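The net effect of splitting `tokenizer` and `tokenizer_path` is that the configured filename and the instantiated tokenizer no longer share a field. A hedged, standalone sketch of the resolution order (config-relative first, then `./data/`, then the naive fallback); `resolve_tokenizer` is a hypothetical helper, not part of the codebase:

```python
# Hedged sketch of the tokenizer resolution after the rename: cfg.tokenizer_path holds the
# configured filename, while cfg.tokenizer ends up holding the instantiated tokenizer object.
from pathlib import Path

def resolve_tokenizer(rel_path: Path, tokenizer_path: str = "./tokenizer.json"):
    try:
        from transformers import PreTrainedTokenizerFast
        candidate = rel_path / tokenizer_path
        if not candidate.exists():
            # fall back to the copy bundled under ./data/
            candidate = Path("./data/") / tokenizer_path
        return PreTrainedTokenizerFast(tokenizer_file=str(candidate))
    except Exception:
        # the real config falls back to its NaiveTokenizer here
        return None
```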

View File

@ -961,8 +961,8 @@ class Dataset(_Dataset):
"""
# trim to fit to requested prom/resps levels
proms = proms[:, :cfg.model.prom_levels]
resps = resps[:, :cfg.model.prom_levels]
proms = proms[:, :cfg.model.resp_levels]
resps = resps[:, :cfg.model.resp_levels]
return dict(
@ -1466,7 +1466,7 @@ if __name__ == "__main__":
if task not in cfg.dataset.tasks_list:
continue
print(text, task, cfg.model.prom_levels)
print(text, task, cfg.model.resp_levels)
print( proms.shape, resps.shape )
tokens = 0

View File

@ -37,7 +37,7 @@ def load_engines(training=True):
optimizer = None
lr_scheduler = None
inferencing = cfg.mode == "inferencing" or not model.config.training
inferencing = cfg.mode == "inferencing" or not model.config.training or not training
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
amp = cfg.inference.amp if inferencing else cfg.trainer.amp

View File

@ -8,10 +8,10 @@ from einops import rearrange
from pathlib import Path
from .emb import g2p, qnt
from .emb.qnt import trim, trim_random
from .emb.qnt import trim, trim_random, unload_model
from .utils import to_device, set_seed, wrapper as ml
from .config import cfg
from .config import cfg, Config
from .models import get_models
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
@ -23,10 +23,14 @@ class TTS():
def __init__( self, config=None, device=None, amp=None, dtype=None ):
self.loading = True
self.input_sample_rate = 24000
self.output_sample_rate = 24000
self.load_config( config=config, device=device, amp=amp, dtype=dtype )
self.load_model()
self.loading = False
def load_config( self, config=None, device=None, amp=None, dtype=None ):
if config:
print("Loading YAML:", config)
cfg.load_yaml( config )
try:
@ -53,7 +57,10 @@ class TTS():
self.dtype = cfg.inference.dtype
self.amp = amp
self.symmap = None
def load_model( self ):
load_engines.cache_clear()
unload_model()
self.engines = load_engines(training=False)
for name, engine in self.engines.items():
@ -61,11 +68,8 @@ class TTS():
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.engines.eval()
if self.symmap is None:
self.symmap = get_phone_symmap()
self.loading = False
print("Loaded model")
def encode_text( self, text, language="en" ):
# already a tensor, return it

View File

@ -33,12 +33,6 @@ class AR_NAR(Base):
def causal(self):
return "ar" in self.capabilities
@property
def n_prom_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.prom_levels
return cfg.model.prom_levels
@property
def n_resp_levels(self) -> int:
if hasattr(self, "config") and self.config:
@ -197,7 +191,7 @@ class AR_NAR(Base):
# is NAR
if max_levels == 0:
max_levels = self.n_resp_levels - 1
max_levels = self.n_max_levels - 1
# expand if given a raw 1D tensor
for i, resp in enumerate(resps_list):
@ -373,7 +367,6 @@ def example_usage():
# mamba seems to ONLY be used as an AR (any NAR attempt lobotomizes it)
"""
if "mamba" in cfg.model.arch_type:
cfg.model.prom_levels = 1
cfg.model.resp_levels = 1
"""
# cfg.model.loss_factors = {}
@ -383,7 +376,7 @@ def example_usage():
def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")

View File

@ -123,7 +123,7 @@ class AudioEmbedding_Old(nn.Module):
):
super().__init__()
# array of embeddings
# proms are [0, prom_levels]
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
@ -154,7 +154,7 @@ class AudioEmbedding(nn.Module):
):
super().__init__()
# array of embeddings
# proms are [0, prom_levels]
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
@ -282,10 +282,6 @@ class Base(nn.Module):
def causal(self) -> bool:
raise NotImplementedError
@property
def n_prom_levels(self) -> int:
raise NotImplementedError
@property
def n_resp_levels(self) -> int:
raise NotImplementedError
@ -368,7 +364,6 @@ class Base(nn.Module):
self.l_padding = l_padding
n_prom_tokens = n_audio_tokens
arch_type = self.config.arch_type if self.config is not None else "llama"
self.arch_type = arch_type
@ -397,14 +392,14 @@ class Base(nn.Module):
self.len_emb = None
if self.version == 1: # legacy
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
elif self.version < 5:
# [1024] * 8
self.proms_emb = AudioEmbedding_Old(
[n_prom_tokens] * self.n_prom_levels, d_model,
levels=self.n_prom_levels if self.version > 3 else None,
[n_audio_tokens] * self.n_resp_levels, d_model,
levels=self.n_resp_levels if self.version > 3 else None,
)
# [1024 + STOP] + [1024] * 8
self.resps_emb = AudioEmbedding_Old(
@ -413,7 +408,7 @@ class Base(nn.Module):
)
else:
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
[n_audio_tokens] * self.n_resp_levels, d_model,
sums=audio_embedding_sums,
external_mode=audio_embedding_mode,
)

View File

@ -31,12 +31,6 @@ class NAR(Base):
def causal(self):
return "len" in self.capabilities
@property
def n_prom_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.prom_levels
return cfg.model.prom_levels
@property
def n_resp_levels(self) -> int:
if hasattr(self, "config") and self.config:
@ -309,7 +303,6 @@ def example_usage():
# mamba seems to ONLY be used as an AR (any NAR attempt lobotomizes it)
"""
if "mamba" in cfg.model.arch_type:
cfg.model.prom_levels = 1
cfg.model.resp_levels = 1
"""
# cfg.model.loss_factors = {}
@ -319,7 +312,7 @@ def example_usage():
def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")

View File

@ -19,6 +19,7 @@ tts = None
layout = {}
layout["inference"] = {}
layout["training"] = {}
layout["settings"] = {}
for k in layout.keys():
layout[k]["inputs"] = { "progress": None }
@ -37,14 +38,42 @@ def gradio_wrapper(inputs):
return decorated
class timer:
def __init__(self, msg="Elapsed time:"):
self.msg = msg
def __enter__(self):
self.start = perf_counter()
return self
def __exit__(self, type, value, traceback):
print(f'[{datetime.now().isoformat()}] Elapsed time: {(perf_counter() - self.start):.3f}s')
msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
gr.Info(msg)
print(f'[{datetime.now().isoformat()}] {msg}')
# returns a list of model YAML paths, assuming the models are placed under ./training/ or ./models/
def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
yamls = []
for path in paths:
if not path.exists():
continue
for yaml in path.glob("**/*.yaml"):
if "/logs/" in str(yaml):
continue
yamls.append( yaml )
return yamls
#
def load_model( yaml ):
gr.Info(f"Loading: {yaml}")
init_tts( yaml=Path(yaml), restart=True )
gr.Info(f"Loaded model")
def init_tts(restart=False):
def init_tts(yaml=None, restart=False):
global tts
if tts is not None:
@ -53,13 +82,13 @@ def init_tts(restart=False):
del tts
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default="auto")
args, unknown = parser.parse_known_args()
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
return tts
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
@ -78,7 +107,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-nar-levels", type=int, default=kwargs["max-nar-levels"])
parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
parser.add_argument("--nar-temp", type=float, default=kwargs["nar-temp"])
parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
@ -99,7 +128,9 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
raise ValueError("No reference audio provided.")
tts = init_tts()
with timer() as t:
gr.Info("Inferencing...")
with timer("Inferenced in") as t:
wav, sr = tts.inference(
text=args.text,
language=args.language,
@ -169,6 +200,7 @@ def get_random_prompt():
# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="VALLE_YAML" in os.environ)
@ -208,7 +240,7 @@ with ui:
with gr.Column(scale=7):
with gr.Row():
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
#layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
@ -249,6 +281,19 @@ with ui:
)
"""
with gr.Tab("Settings"):
with gr.Row():
with gr.Column(scale=7):
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
with gr.Column(scale=1):
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
layout["settings"]["buttons"]["load"].click(
fn=load_model,
inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
)
if os.path.exists("README.md") and args.render_markdown:
md = open("README.md", "r", encoding="utf-8").read()
# remove HF's metadata