rolling context finally (use the last N utterances as the prefix for the next generation), and an option to split the input text prompt by sentences instead of lines (or no splitting)

This commit is contained in:
mrq 2024-12-04 20:31:44 -06:00
parent 9dff68c0c5
commit 93d27be539
9 changed files with 154 additions and 55 deletions
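A rough sketch of what the two additions amount to at the call site (the names here are illustrative stand-ins, not the repo's actual API; the real plumbing lives in the CLI, data, and inference changes below):

```python
# Illustrative sketch only: `generate_line` stands in for the real AR/NAR inference call.
def synthesize(text, generate_line, split_by="sentences", context_history=2):
    # 1) split the prompt into pieces (by sentences via NLTK, by a delimiter, or not at all)
    if split_by is None:
        lines = [text]
    elif split_by == "sentences":
        import nltk                      # the real code falls back to "\n" if NLTK is unavailable
        lines = nltk.sent_tokenize(text)
    else:
        lines = text.split(split_by)

    # 2) rolling context: feed the last N generated lines back in as the prefix for the next one
    history, outputs = [], []
    for line in lines:
        prefix = history[-context_history:] if context_history > 0 else None
        out = generate_line(line, prefix=prefix)
        history.append((line, out))
        outputs.append(out)
    return outputs
```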

.gitignore vendored
View File

@ -6,4 +6,5 @@ __pycache__
/vall_e/version.py
/.cache
/voices
/wandb
/wandb
/.nltk

View File

@ -20,15 +20,14 @@ Besides a working PyTorch environment, the only hard requirement is [`espeak-ng`
Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.
I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
This repo is tested under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
## Pre-Trained Model
My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).
A script to setup a proper environment and download the weights can be invoked with `./scripts/setup.sh`. This will automatically create a `venv`, and download the `ar+nar-llama-8` weights and config file to the right place.
When inferencing, either through the web UI or CLI, if no model is passed, the default model will download automatically instead, and should automatically update.
Pre-trained weights can be acquired from
* [here](https://huggingface.co/ecker/vall-e) or automatically when either inferencing or running the web UI.
* `./scripts/setup.sh`, a script to set up a proper environment and download the weights. This will also automatically create a `venv`.
* when inferencing, either through the web UI or CLI: if no model is passed, the default model will be downloaded automatically and kept up to date.
## Documentation

View File

@ -46,6 +46,8 @@ However, at this point in time, the implementation is rather divorced from VALL
- KV caching both yields broken output and is quadratically slower, unless I'm doing something grossly wrong.
* [x] provide a pure NAR model that foregoes most of the inferencing slowdowns a regular AR+NAR model will provide.
* [ ] HF-ify the model
* [x] write a weights converter
* [ ] implement a pure llama_HF implementation
- this might be easily possible by subjugating the tokenizer to handle all the embeddings / classifiers
- this will pave the way to use the model under an easy marriage of `llama.cpp` and `encodec.cpp`
* [ ] replace the phonemizer with something that doesn't depend on espeak
@ -55,7 +57,7 @@ However, at this point in time, the implementation is rather divorced from VALL
- espeak is nice, but I can only really put my whole trust in it for phonemizing English.
- a small model trained to handle converting text to phonemes might work, but has its own problems (another model to carry around, only as accurate as the dataset it was trained against, requires training for each language... etc).
* [ ] smarter/clever inferencing, such as:
* [ ] "rolling" context, where the last generated sentence is the prefix for the next sentence.
* [x] "rolling" context, where the last generated sentence is the prefix for the next sentence.
* [ ] explore exotic features like:
* using a pure text vocab rather than IPA phonemes (as a transformer should be "smart" enough to map text tokens)
* interleaving by using summed embedding tokens:
@ -79,9 +81,11 @@ However, while this solution boasts being lightweight, there are some caveats fo
* `hf`-ifying it is possible, but it'd be a chore to set up the tokenizer properly
* it still seems like the phase of the moon matters with how it wants to cooperate
* some eval tests it seems fine, other times issues like word errors will crop up
* the `NAR-len` requires CFGs > 2-ish to cooperate
* the `NAR-len` requires CFGs > 2-ish to cooperate (or a prefix)
* this isn't *so* much of an issue, but it can lead to user error, and CFG incurs an additional forward pass per sampling step.
* guidance distillation would be nice, but distillation in general harms finetuning (and guidance distillation presumably does too)
* rolling context/prefix does solve this
* VALL-E Continuous (prefixing with the input prompt) could also fix this, but technically makes it one-shot and not zero-shot
## Notices and Citations

View File

@ -96,7 +96,7 @@ It is ***crucial*** to:
* without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem but I imagine this "interleaves" many different sequences between each step.
* use unfiltered/unprocessed logit scores:
* not that crucial, but helps stability.
* use a CFG strength of at least 2
* use a CFG strength of at least 2 (or a prefix)
It is not required to train a model from scratch to use this modality, as training from existing weights works just as well, if not better (as it can piggyback off the original model).
* additional training is still required to help confidence issues and to condition the model to not fall apart for longer durations.
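For reference, a minimal sketch of classifier-free guidance as it is typically applied to logits, with the rescale factor exposed as "CFG Rescale (Phi)" in the web UI; this is a generic illustration under those assumptions, not a copy of the repo's exact implementation:

```python
import torch

def apply_cfg(cond_logits: torch.Tensor, null_logits: torch.Tensor,
              cfg_strength: float = 3.0, cfg_rescale: float = 0.75) -> torch.Tensor:
    # classifier-free guidance: push the conditional logits away from the unconditional ones
    guided = null_logits + cfg_strength * (cond_logits - null_logits)
    if cfg_rescale > 0:
        # rescale trick: restore the conditional logits' scale, then interpolate by phi
        rescaled = guided * (cond_logits.std() / guided.std())
        guided = cfg_rescale * rescaled + (1.0 - cfg_rescale) * guided
    return guided
```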

View File

@ -17,6 +17,9 @@ def main():
parser.add_argument("--modality", type=str, default="auto")
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--split-text-by", type=str, default="\n")
parser.add_argument("--context-history", type=int, default=0)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--model", type=Path, default=None)
parser.add_argument("--lora", type=Path, default=None)
@ -81,6 +84,8 @@ def main():
tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
sampling_kwargs = dict(
split_text_by=args.split_text_by,
context_history=args.context_history,
max_steps=args.max_steps,
max_levels=args.max_levels,
max_duration=args.max_duration,

View File

@ -35,6 +35,34 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__)
# cringe
try:
import nltk
if not Path(".nltk").exists():
nltk.download('punkt_tab', download_dir="./.nltk/")
except Exception as e:
nltk = None
_logger.warning(f"Error while querying for NTLK: {str(e)}")
def sentence_split( s, split_by="sentences", quote_placeholder="<QUOTE>" ):
if split_by is None:
return [s]
# NLTK is not available, fall back to splitting by lines
if nltk is None:
split_by = "\n"
# split by delimiter instead
if split_by != "sentences":
return s.split(split_by)
# use NLTK to handle splitting by sentences, because I don't want to write my own parser to split by punctuation
# nltk does not split quotations all that nicely, so we coerce them into placeholders, then restore them afterwards
s = s.replace('"', quote_placeholder)
sentences = nltk.sent_tokenize(s)
return [ sentence.replace(quote_placeholder, '"') for sentence in sentences ]
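A quick illustration of the intended behavior (assuming the downloaded `punkt_tab` data is on NLTK's search path; exact sentence boundaries are up to NLTK):

```python
text = "This is the first sentence. This is the second!\nAnd a third on a new line."

sentence_split(text, split_by="sentences")
# roughly: ['This is the first sentence.', 'This is the second!', 'And a third on a new line.']

sentence_split(text, split_by="\n")   # plain delimiter split, no NLTK involved
# ['This is the first sentence. This is the second!', 'And a third on a new line.']

sentence_split(text, split_by=None)   # no splitting
# [text]
```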
@cache
def get_random_prompts( validation=False, min_length=0, tokenized=False ):
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range

View File

@ -19,7 +19,7 @@ from .config import cfg, Config
from .models import get_models
from .models.lora import enable_lora
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize, sentence_split
from .models import download_model, DEFAULT_MODEL_PATH
if deepspeed_available:
@ -195,7 +195,6 @@ class TTS():
modality="auto",
input_prompt_length = 0,
load_from_artifact = False,
seed = None,
out_path=None,
@ -203,7 +202,7 @@ class TTS():
use_lora=None,
**sampling_kwargs,
):
lines = text.split("\n")
lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
wavs = []
sr = None
@ -253,6 +252,11 @@ class TTS():
return text_list[0]
# stuff for rolling context
prefix_context = None
prefix_contexts = []
context_history = sampling_kwargs.get("context_history", 0)
for line in lines:
if out_path is None:
output_dir = Path("./data/results/")
@ -275,30 +279,14 @@ class TTS():
duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # don't need more than that
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # "max_duration" is max tokens
# pad the predicted duration by a multiplicative factor
len_list = [ int(l * duration_padding) for l in len_list ]
kwargs = {}
# nasty hardcode to load a reference file and have that as the input target
if load_from_artifact and load_from_artifact.exists():
artifact = np.load(load_from_artifact, allow_pickle=True)[()]
phns = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=self.device)
resp = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=self.device)
len_list = [ resp.shape[0] ]
kwargs["resps_list"] = [ resp[:, 0] ]
# kludge experiment
elif nar_len_prefix_length > 0:
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**(sampling_kwargs | {"max_duration": nar_len_prefix_length}),
)
kwargs["resps_list"] = [ resp if resp.dim() == 1 else resp[:, 0] for resp in resps_list ]
if prefix_context is not None:
kwargs["prefix_context"] = prefix_context
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"],
disable_tqdm=not tqdm,
@ -306,12 +294,17 @@ class TTS():
**(sampling_kwargs | kwargs),
)
elif model_ar is not None:
kwargs = {}
if prefix_context is not None:
kwargs["prefix_context"] = prefix_context
resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**sampling_kwargs,
**(sampling_kwargs | kwargs),
)
resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
disable_tqdm=not tqdm,
@ -321,8 +314,25 @@ class TTS():
else:
raise Exception("!")
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
# to-do: care about batching later
resps = resps_list[0]
# store current context to use as the initial input for later
if context_history > 0:
# add to history
prefix_contexts.append(( phns, resps, resps.shape[0] ))
# then generate the prefix based on how much history to provide
prefix_context = (
[ torch.concat( [ x[0] for x in prefix_contexts[-context_history:] ] ) ],
[ torch.concat( [ x[1] for x in prefix_contexts[-context_history:] ] ) ],
[ sum([ x[2] for x in prefix_contexts[-context_history:] ]) ]
)
# write to file
wav, sr = qnt.decode_to_file(resps, out_path, device=self.device)
# add utterances
wavs.append(wav)
# combine all utterances
return (torch.concat(wavs, dim=-1), sr)
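For clarity, the rolling-window bookkeeping above boils down to the following (a standalone sketch; the tuple layout of concatenated phonemes, concatenated response codes, and summed lengths mirrors what the samplers receive as `prefix_context`):

```python
import torch

def build_prefix_context(history, context_history):
    """history: list of (phonemes, resps, resp_len) tuples for previously generated lines."""
    window = history[-context_history:]
    return (
        [torch.concat([phns for phns, _, _ in window])],    # phoneme tokens of the prior lines
        [torch.concat([resps for _, resps, _ in window])],   # their generated audio codes
        [sum(resp_len for _, _, resp_len in window)],        # total number of prefix audio tokens
    )
```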

View File

@ -269,8 +269,15 @@ class AR_NAR(Base):
# force set CFG because too low / no CFG causes issues
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 3.0)
original_cfg_strength = cfg_strength
cfg_strength = max( cfg_strength, minimum_cfg_strength )
prefix_context = sampling_kwargs.get("prefix_context", None)
# we can get away with just providing a list of resps to prefix later, and it will magically get removed anyways when masking and scoring
if prefix_context is not None:
text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], text_list ) ]
prefix_resps_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_context[1] ]
# if we're denoising from an existing sequence
if start_noise > 0.0 and resps_list is not None:
# flatten if needed
@ -279,19 +286,6 @@ class AR_NAR(Base):
noise_p = math.cos( start_noise * math.pi * 0.5 )
# generate scoring mask (because the above mask will get masked off per the scores, so we do not need to mask beforehand)
scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
# deduce that this is a prefix
elif resps_list is not None:
# number of remaining tokens
tokens_to_mask = [ l - resps.shape[0] for resps, l in zip( resps_list, len_list ) ]
# pad with masked tokens
resps_list = [ torch.concat([ resps if resps.dim() == 1 else resps[:, 0], torch.tensor( [ self.stop_token ] * l, dtype=resps.dtype, device=resps.device ) ]) for resps, l in zip( resps_list, tokens_to_mask ) ]
# update scores to ignore the prefix
scores = [ torch.concat( [ torch.zeros((resps.shape[0],), dtype=torch.int16, device=device), torch.ones((l), dtype=torch.int16, device=device) ] ) for resps, l in zip( resps_list, tokens_to_mask ) ]
# set start noise
# only the first because we do not have variable noising at the moment
# *technically* the prefix can be a fixed portion for all inputs in a batch, rather than a fixed length
# this will set the starting noise_p with the right ratio
start_noise = 2 / math.pi * math.acos(resps_list[0].shape[0] / len_list[0])
else:
# fill with masked tokens (even though they get masked anyways)
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
@ -302,7 +296,8 @@ class AR_NAR(Base):
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
for timestep in iterator:
# update previous list of tokens
prev_list = resps_list
# ramp down over time
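As an aside, the demasking schedule driving this loop is the cosine ramp set up earlier in this file (`noise_p = cos(t·π/2)`); a tiny sketch of its shape:

```python
import math

def masking_ratio(timestep: float) -> float:
    # cosine schedule for NAR-len demasking: at t=0 essentially everything is masked (ratio 1.0),
    # ramping down to 0.0 as the timestep approaches 1
    return math.cos(timestep * math.pi * 0.5)

# e.g. masking_ratio(0.0) == 1.0, masking_ratio(0.5) ≈ 0.707, masking_ratio(1.0) ≈ 0.0
```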
@ -327,11 +322,19 @@ class AR_NAR(Base):
if sampling_cfg < minimum_cfg_strength * 0.5:
sampling_cfg = 0
if prefix_context is not None:
input_resps_list = [ torch.concat( [ prefix, resps ] ) for prefix, resps in zip( prefix_resps_list, resps_list ) ]
# originally requested no CFG, safe to ignore if we have a prefix
if original_cfg_strength == 0:
sampling_cfg = 0
else:
input_resps_list = resps_list
# setup inputs
inputs = super().inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
resps_list=input_resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=time_list,
@ -349,7 +352,7 @@ class AR_NAR(Base):
null_inputs = super().inputs(
text_list=null_text,
proms_list=null_prom,
resps_list=resps_list,
resps_list=input_resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=time_list,
@ -433,6 +436,17 @@ class AR_NAR(Base):
if max_levels == 0:
max_levels = self.n_max_levels - 1
# prefixed context provided
"""
prefix_context = sampling_kwargs.get("prefix_context", None)
if prefix_context is not None:
prefix_text, prefix_resps, _ = prefix_context
# to-do: check if we actually need to drop the middle "<eos><bos>"
text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ]
# feeding this into the NAR-len should automatically handle things
resps_list = [ resps for resps in prefix_resps ]
"""
"""
sampling_layer_skip_variables = {} if sampling_layer_skip else None
@ -468,9 +482,11 @@ class AR_NAR(Base):
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
iterator = trange( max_levels, desc="NAR", disable=disable_tqdm )
for n in iterator:
level = prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
iterator.close()
break
if cfg.lora is not None:
@ -578,7 +594,8 @@ class AR_NAR(Base):
task_list = [ "len" for _ in range(batch_size) ]
quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
for n in trange(10, desc="AR", disable=disable_tqdm):
iterator = trange(10, desc="AR", disable=disable_tqdm)
for n in iterator:
len_list = sequence_list
inputs = self.inputs(
@ -613,6 +630,7 @@ class AR_NAR(Base):
# stop token found
stopped |= r == stop_token
if stopped.all().item():
iterator.close()
break
# convert tokens into int
@ -660,11 +678,21 @@ class AR_NAR(Base):
sequence_list[i] = sequence_list[i][:, 0]
# start_slice[i] = sequence_list[i].shape[0]
# prefixed context provided
prefix_context = sampling_kwargs.get("prefix_context", None)
if prefix_context is not None:
prefix_text, prefix_resps, _ = prefix_context
# to-do: check if we actually need to drop the middle "<eos><bos>"
text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ]
# seeding the AR's sequence with the prefix's codes should automatically handle things
sequence_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_resps ]
null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
null_prom = [ None for _ in range(batch_size) ]
# get next in sequence
for n in trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
for n in iterator:
# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
@ -750,6 +778,7 @@ class AR_NAR(Base):
# stop token found
# stopped |= r == stop_token
if stopped.all().item():
iterator.close()
break
# to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
@ -785,7 +814,12 @@ class AR_NAR(Base):
refined_list = [ logit.argmax(dim=-1) for logit in logits ]
# to-do: compare scores
# set the "refined" list as the output
sequence_list = refined_list
sequence_list = refined_list
# slice out prefix
if prefix_context is not None:
prefix_text, prefix_resps, prefix_lens = prefix_context
sequence_list = [ resps[l:] for resps, l in zip(sequence_list, prefix_lens) ]
return sequence_list
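Put together, the prefix handling in this sampler amounts to roughly the following (a sketch; the `<eos>`/`<bos>` boundary handling follows the to-do comments above rather than a verified token layout):

```python
import torch

def splice_prefix(prefix_text, text, prefix_resps):
    # join the prefix's phonemes with the new line's phonemes, dropping the prefix's
    # trailing token and the new line's leading token so only one <eos>/<bos> boundary remains
    joined_text = torch.concat([prefix_text[:-1], text[1:]])
    # seed the output sequence with the prefix's audio codes (first RVQ level only)
    seed_resps = prefix_resps if prefix_resps.dim() == 1 else prefix_resps[:, 0]
    return joined_text, seed_resps

def strip_prefix(sequence, prefix_len):
    # once decoding finishes, drop the prefix tokens so only the new line's audio remains
    return sequence[prefix_len:]
```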
@ -837,6 +871,8 @@ class AR_NAR(Base):
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
)
# is NAR
@ -849,6 +885,8 @@ class AR_NAR(Base):
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
@ -861,6 +899,8 @@ class AR_NAR(Base):
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,
)

View File

@ -204,6 +204,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--modality", type=str, default=kwargs["modality"])
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
@ -257,11 +259,18 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
if kwargs.pop("refine-on-stop", False):
args.refine_on_stop = True
if args.split_text_by == "lines":
args.split_text_by = "\n"
elif args.split_text_by == "none":
args.split_text_by = None
tts = init_tts()
gr.Info(f"Inferencing... (Modality: {tts.modality(args.modality.lower())})")
sampling_kwargs = dict(
split_text_by=args.split_text_by,
context_history=args.context_history,
max_steps=args.max_steps,
max_levels=args.max_levels,
max_duration=args.max_duration,
@ -437,6 +446,9 @@ with ui:
layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).")
layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Row():
layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="Splits the text into pieces.", value="sentences")
layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
with gr.Tab("Sampler Settings"):
with gr.Row():
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")