rolling context finally (use the last N utterances as the prefix for the next generation); option to split the input text prompt by sentences instead of lines (or no splitting)
parent 9dff68c0c5 · commit 93d27be539
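At a high level, the new behaviour looks roughly like this (a runnable sketch of the idea only; the helpers here are trivial stand-ins, not the real `vall_e` API, and the actual plumbing is in the `inference.py` / `ar_nar.py` hunks below):

```python
# sketch only: phonemize / generate / make_prefix are placeholder stand-ins
def phonemize(line):             return list(line)                  # pretend phonemes
def generate(line, prefix=None): return list(range(len(line)))      # pretend audio codes
def make_prefix(entries):        return entries                     # pretend prefix assembly

text            = "First sentence. Second sentence. Third sentence."
lines           = text.split(". ")   # the real code uses sentence_split() / NLTK, or "\n", or no split
context_history = 2                  # how many prior utterances to roll into the prefix; 0 disables it

history = []                         # rolling context: (phonemes, codes, length) per generated utterance
for line in lines:
    prefix = make_prefix(history[-context_history:]) if context_history > 0 else None
    codes  = generate(line, prefix=prefix)
    history.append((phonemize(line), codes, len(codes)))
```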
.gitignore (vendored) | 3

@@ -6,4 +6,5 @@ __pycache__
 /vall_e/version.py
 /.cache
 /voices
 /wandb
+/.nltk
README.md | 11

@@ -20,15 +20,14 @@ Besides a working PyTorch environment, the only hard requirement is [`espeak-ng`
 
 Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.
 
-I've tested this repo under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
+This repo is tested under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
 
 ## Pre-Trained Model
 
-My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).
-
-A script to setup a proper environment and download the weights can be invoked with `./scripts/setup.sh`. This will automatically create a `venv`, and download the `ar+nar-llama-8` weights and config file to the right place.
-
-When inferencing, either through the web UI or CLI, if no model is passed, the default model will download automatically instead, and should automatically update.
+Pre-trained weights can be acquired from
+* [here](https://huggingface.co/ecker/vall-e) or automatically when either inferencing or running the web UI.
+* `./scripts/setup.sh`, a script to setup a proper environment and download the weights. This will also automatically create a `venv`.
+* when inferencing, either through the web UI or CLI, if no model is passed, the default model will download automatically instead, and should automatically update.
 
 ## Documentation

@@ -46,6 +46,8 @@ However, at this point and time, the implementation is rather divorced from VALL
   - KV caching both yields broken output and quadratically slow output, unless I'm doing something grossly wrong.
 * [x] provide a pure NAR model that foregoes most of the inferencing slowdowns a regular AR+NAR model will provide.
 * [ ] HF-ify the model
   * [x] write a weights converter
   * [ ] implement a pure llama_HF implementation
     - this might be easily possible by subjugating the tokenizer to handle all the embeddings / classifiers
     - this will pave the way to use the model under an easy marriage of `llama.cpp` and `encodec.cpp`
 * [ ] replace the phonemizer with something that doesn't depend on espeak

@@ -55,7 +57,7 @@ However, at this point and time, the implementation is rather divorced from VALL
   - espeak is nice, but I can only really put my whole trust with phonemizing English.
   - a small model trained to handle converting text to phonemes might work, but has it's own problems (another model to carry around, as accurate as the dataset it was trained against, requires training for each language... etc).
 * [ ] smarter/clever inferencing, such as:
-  * [ ] "rolling" context, where the last generated sentence is the prefix for the next sentence.
+  * [x] "rolling" context, where the last generated sentence is the prefix for the next sentence.
 * [ ] explore exotic features like:
   * using a pure text vocab rather than IPA phonemes (as a transformer should be "smart" enough to map text tokens)
   * interleaving by using summed embedding tokens:

@@ -79,9 +81,11 @@ However, while this solution boasts being lightweight, there are some caveats fo
 * `hf`-ifying it is possible, but it'd be a chore to set up the tokenizer properly
 * it still seems like the phase of the moon matters with how it wants to cooperate
   * some eval tests it seems fine, other times issues like word errors will crop up
-* the `NAR-len` requires CFGs > 2-ish to cooperate
+* the `NAR-len` requires CFGs > 2-ish to cooperate (or a prefix)
   * this isn't *so* much of an issue, but this can lead to user error, and CFG incurs an additional sampling step per step.
   * guidance distillation would be nice, but distillation in general harms finetuning (assuming this just as likely harms it)
+  * rolling context/prefix does solve this
   * VALL-E Continuous (prefixing with the input prompt) could also fix this, but technically makes it one-shot and not zero-shot
 
 ## Notices and Citations
@@ -96,7 +96,7 @@ It is ***crucial*** to:
   * without this, you ***will*** get stuttering and unaligned utterances. I do not know why this is such a big problem but I imagine this "interleaves" many different sequences between each step.
 * use unfiltered/unprocessed logit scores:
   * not that crucial, but helps stability.
-* use a CFG strength of at least 2
+* use a CFG strength of at least 2 (or a prefix)
 
 It is not required to train a model from scratch to use this modality, as training from existing weights works just as well, if not better (as it can piggyback off the original model).
 * additional training is still required to help confidence issues and to condition the model to not fall apart for longer durations.
@@ -17,6 +17,9 @@ def main():
     parser.add_argument("--modality", type=str, default="auto")
     parser.add_argument("--out-path", type=Path, default=None)
+
+    parser.add_argument("--split-text-by", type=str, default="\n")
+    parser.add_argument("--context-history", type=int, default=0)
 
     parser.add_argument("--yaml", type=Path, default=None)
     parser.add_argument("--model", type=Path, default=None)
     parser.add_argument("--lora", type=Path, default=None)

@@ -81,6 +84,8 @@ def main():
     tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
 
     sampling_kwargs = dict(
+        split_text_by=args.split_text_by,
+        context_history=args.context_history,
         max_steps=args.max_steps,
         max_levels=args.max_levels,
         max_duration=args.max_duration,
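A hedged usage note for the two new options above. The entry point and positional arguments in the comment are assumptions (only `--out-path`, `--split-text-by`, and `--context-history` come from this diff); the kwargs are read back inside `TTS.inference()` as shown later in the commit:

```python
# Hypothetical CLI call (module entry point and positional text/reference args are assumed):
#   python3 -m vall_e "some long passage of text" ./reference.wav --out-path=./out.wav \
#       --split-text-by="sentences" --context-history=2
#
# which lands in TTS.inference(**sampling_kwargs) roughly as:
sampling_kwargs = dict(
    split_text_by="sentences",   # or "\n" (the CLI default) to split per line, or None for no splitting
    context_history=2,           # roll the last 2 generated utterances into the prefix; 0 disables it
)
# TTS.inference() retrieves them with sampling_kwargs.get("split_text_by", "sentences")
# and sampling_kwargs.get("context_history", 0).
```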
@@ -35,6 +35,34 @@ from tqdm.auto import tqdm
 
 _logger = logging.getLogger(__name__)
 
+# cringe
+try:
+    import nltk
+
+    if not Path(".nltk").exists():
+        nltk.download('punkt_tab', download_dir="./.nltk/")
+except Exception as e:
+    nltk = None
+    _logger.warning(f"Error while querying for NLTK: {str(e)}")
+
+def sentence_split( s, split_by="sentences", quote_placeholder="<QUOTE>" ):
+    if split_by is None:
+        return [s]
+
+    # NLTK is not available, fallback
+    if nltk is None:
+        split_by = "\n"
+
+    # split by delimiter instead
+    if split_by != "sentences":
+        return s.split(split_by)
+
+    # use NLTK to handle splitting by sentences, because I don't want to write my own parser to split by punctuation
+    # nltk does not split quotations all that nicely, so we coerce them into placeholders, then replace afterwards
+    s = s.replace('"', quote_placeholder)
+    sentences = nltk.sent_tokenize(s)
+    return [ sentence.replace(quote_placeholder, '"') for sentence in sentences ]
+
 @cache
 def get_random_prompts( validation=False, min_length=0, tokenized=False ):
     duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
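A quick usage sketch for the `sentence_split` helper added above. The exact sentence boundaries depend on NLTK's `punkt_tab` model (fetched into `./.nltk/` on first import, hence the new `.gitignore` entry), so the first output is approximate, and the import path is the obvious assumption:

```python
from vall_e.data import sentence_split  # assumed import path; triggers the one-time ./.nltk/ download

# default: NLTK sentence tokenization; double quotes are swapped for a placeholder
# before tokenizing and restored afterwards, since punkt mangles them otherwise
sentence_split("First sentence. Second one? Third.")
# ≈ ['First sentence.', 'Second one?', 'Third.']   (per NLTK's punkt model)

# any other delimiter falls back to a plain str.split(); this is also the NLTK-less fallback
sentence_split("line one\nline two", split_by="\n")
# -> ['line one', 'line two']

# no splitting at all
sentence_split("the whole prompt, untouched", split_by=None)
# -> ['the whole prompt, untouched']
```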
@@ -19,7 +19,7 @@ from .config import cfg, Config
 from .models import get_models
 from .models.lora import enable_lora
 from .engines import load_engines, deepspeed_available
-from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
+from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize, sentence_split
 from .models import download_model, DEFAULT_MODEL_PATH
 
 if deepspeed_available:

@@ -195,7 +195,6 @@ class TTS():
         modality="auto",
 
         input_prompt_length = 0,
-        load_from_artifact = False,
 
         seed = None,
         out_path=None,

@@ -203,7 +202,7 @@ class TTS():
         use_lora=None,
         **sampling_kwargs,
     ):
-        lines = text.split("\n")
+        lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
 
         wavs = []
         sr = None
@@ -253,6 +252,11 @@ class TTS():
 
             return text_list[0]
 
+        # stuff for rolling context
+        prefix_context = None
+        prefix_contexts = []
+        context_history = sampling_kwargs.get("context_history", 0)
+
         for line in lines:
             if out_path is None:
                 output_dir = Path("./data/results/")

@@ -275,30 +279,14 @@ class TTS():
                 duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
-                nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
 
-                len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # don't need more than that
+                len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # "max_duration" is max tokens
 
                 # add an additional X seconds
                 len_list = [ int(l * duration_padding) for l in len_list ]
 
                 kwargs = {}
-                # nasty hardcode to load a reference file and have that as the input target
-                if load_from_artifact and load_from_artifact.exists():
-                    artifact = np.load(load_from_artifact, allow_pickle=True)[()]
-
-                    phns = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=self.device)
-                    resp = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=self.device)
-                    len_list = [ resp.shape[0] ]
-
-                    kwargs["resps_list"] = [ resp[:, 0] ]
-                # kludge experiment
-                elif nar_len_prefix_length > 0:
-                    resps_list = model_nar(
-                        text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
-                        disable_tqdm=not tqdm,
-                        use_lora=use_lora,
-                        **(sampling_kwargs | {"max_duration": nar_len_prefix_length}),
-                    )
-                    kwargs["resps_list"] = [ resp if resp.dim() == 1 else resp[:, 0] for resp in resps_list ]
+                if prefix_context is not None:
+                    kwargs["prefix_context"] = prefix_context
 
                 resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"],
                     disable_tqdm=not tqdm,

@@ -306,12 +294,17 @@ class TTS():
                     **(sampling_kwargs | kwargs),
                 )
             elif model_ar is not None:
+                kwargs = {}
+                if prefix_context is not None:
+                    kwargs["prefix_context"] = prefix_context
+
                 resps_list = model_ar(
                     text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"],
                     disable_tqdm=not tqdm,
                     use_lora=use_lora,
-                    **sampling_kwargs,
+                    **(sampling_kwargs | kwargs),
                 )
 
                 resps_list = model_nar(
                     text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"],
                     disable_tqdm=not tqdm,
@@ -321,8 +314,25 @@ class TTS():
             else:
                 raise Exception("!")
 
-            wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
+            # to-do: care about batching later
+            resps = resps_list[0]
+
+            # store current context to use as the initial input for later
+            if context_history > 0:
+                # add to history
+                prefix_contexts.append(( phns, resps, resps.shape[0] ))
+                # then generate the prefix based on how much history to provide
+                prefix_context = (
+                    [ torch.concat( [ x[0] for x in prefix_contexts[-context_history:] ] ) ],
+                    [ torch.concat( [ x[1] for x in prefix_contexts[-context_history:] ] ) ],
+                    [ sum([ x[2] for x in prefix_contexts[-context_history:] ]) ]
+                )
+
+            # write to file
+            wav, sr = qnt.decode_to_file(resps, out_path, device=self.device)
+            # add utterances
             wavs.append(wav)
 
         # combine all utterances
         return (torch.concat(wavs, dim=-1), sr)
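To make the bookkeeping above concrete: with `context_history=2`, `prefix_context` ends up holding the concatenated phonemes, the concatenated codes, and the summed length of just the last two generated lines (toy 1-D tensors here; the real `resps` are code sequences on the target device):

```python
import torch

context_history = 2
prefix_contexts = [                      # ( phonemes, codes, length ) per previously generated line
    ( torch.tensor([1, 10, 11, 2]), torch.arange(40), 40 ),
    ( torch.tensor([1, 12, 13, 2]), torch.arange(55), 55 ),
    ( torch.tensor([1, 14, 15, 2]), torch.arange(30), 30 ),
]

window = prefix_contexts[-context_history:]          # only the last two utterances survive
prefix_context = (
    [ torch.concat([ x[0] for x in window ]) ],      # phoneme ids of the last N lines, concatenated
    [ torch.concat([ x[1] for x in window ]) ],      # audio codes of the last N lines, concatenated
    [ sum( x[2] for x in window ) ],                 # total prefix length in frames: 55 + 30 = 85
)
```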
@@ -269,8 +269,15 @@ class AR_NAR(Base):
 
         # force set CFG because too low / no CFG causes issues
         minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 3.0)
+        original_cfg_strength = cfg_strength
         cfg_strength = max( cfg_strength, minimum_cfg_strength )
 
+        prefix_context = sampling_kwargs.get("prefix_context", None)
+        # we can get away with just providing a list of resps to prefix later, and it will magically get removed anyways when masking and scoring
+        if prefix_context is not None:
+            text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_context[0], text_list ) ]
+            prefix_resps_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_context[1] ]
+
         # if we're denoising from an existing sequence
         if start_noise > 0.0 and resps_list is not None:
             # flatten if needed
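A toy illustration of the splice above: the prefix's trailing token and the new line's leading token are dropped so the `<eos>`/`<bos>` pair at the seam disappears, while CFG still gets floored at `minimum_cfg_strength` (the caller's original request is only honoured further down, once a prefix is present). The token ids are illustrative:

```python
import torch

# illustrative ids: assume 1 = <bos>, 2 = <eos> (matching the null_text = [1, 2] used for CFG)
prefix_phonemes = torch.tensor([1, 10, 11, 12, 2])   # rolling-context phonemes
line_phonemes   = torch.tensor([1, 20, 21, 2])       # phonemes of the line being generated

spliced = torch.concat([ prefix_phonemes[:-1], line_phonemes[1:] ])
# -> tensor([ 1, 10, 11, 12, 20, 21,  2])   one <bos>, one <eos>, no seam tokens

# CFG handling from the same hunk:
minimum_cfg_strength  = 3.0
original_cfg_strength = cfg_strength = 0.0              # caller asked for no CFG
cfg_strength = max(cfg_strength, minimum_cfg_strength)  # still forced up to 3.0 for the NAR-len
```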
@@ -279,19 +286,6 @@ class AR_NAR(Base):
             noise_p = math.cos( start_noise * math.pi * 0.5 )
             # generate scoring mask (because the above mask will get masked off per the scores, so we do not need to mask beforehand)
             scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
-        # deduce that this is a prefix
-        elif resps_list is not None:
-            # number of remaining tokens
-            tokens_to_mask = [ l - resps.shape[0] for resps, l in zip( resps_list, len_list ) ]
-            # pad with masked tokens
-            resps_list = [ torch.concat([ resps if resps.dim() == 1 else resps[:, 0], torch.tensor( [ self.stop_token ] * l, dtype=resps.dtype, device=resps.device ) ]) for resps, l in zip( resps_list, tokens_to_mask ) ]
-            # update scores to ignore the prefix
-            scores = [ torch.concat( [ torch.zeros((resps.shape[0],), dtype=torch.int16, device=device), torch.ones((l), dtype=torch.int16, device=device) ] ) for resps, l in zip( resps_list, tokens_to_mask ) ]
-            # set start noise
-            # only the first because we do not have variable noising at the moment
-            # *technically* the prefix can be a fixed portion for all inputs in a batch, rather than a fixed length
-            # this will set the starting noise_p with the right ratio
-            start_noise = 2 / math.pi * math.acos(resps_list[0].shape[0] / len_list[0])
         else:
             # fill with masked tokens (even though they get masked anyways)
             resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
@@ -302,7 +296,8 @@ class AR_NAR(Base):
         null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
         null_prom = [ None for _ in range(batch_size) ]
 
-        for timestep in tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm):
+        iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm)
+        for timestep in iterator:
             # update previous list of tokens
             prev_list = resps_list
             # ramp down over time
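This loop rewrite (repeated in the later AR/NAR hunks) keeps an explicit handle on the `tqdm`/`trange` iterator so the progress bar can be closed cleanly when breaking out early, instead of leaving a dangling partial bar. The pattern in isolation:

```python
from tqdm.auto import trange

iterator = trange(100, desc="demo", disable=False)
for step in iterator:
    if step == 10:          # stand-in for the real early-exit conditions (stop token, level cap, ...)
        iterator.close()    # finalize the bar before bailing out
        break
```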
@@ -327,11 +322,19 @@ class AR_NAR(Base):
             if sampling_cfg < minimum_cfg_strength * 0.5:
                 sampling_cfg = 0
 
+            if prefix_context is not None:
+                input_resps_list = [ torch.concat( [ prefix, resps ] ) for prefix, resps in zip( prefix_resps_list, resps_list ) ]
+                # originally requested no CFG, safe to ignore if we have a prefix
+                if original_cfg_strength == 0:
+                    sampling_cfg = 0
+            else:
+                input_resps_list = resps_list
+
             # setup inputs
             inputs = super().inputs(
                 text_list=text_list,
                 proms_list=proms_list,
-                resps_list=resps_list,
+                resps_list=input_resps_list,
                 lang_list=lang_list,
                 tone_list=tone_list,
                 time_list=time_list,

@@ -349,7 +352,7 @@ class AR_NAR(Base):
             null_inputs = super().inputs(
                 text_list=null_text,
                 proms_list=null_prom,
-                resps_list=resps_list,
+                resps_list=input_resps_list,
                 lang_list=lang_list,
                 tone_list=tone_list,
                 time_list=time_list,
@@ -433,6 +436,17 @@ class AR_NAR(Base):
         if max_levels == 0:
             max_levels = self.n_max_levels - 1
 
+        # prefixed context provided
+        """
+        prefix_context = sampling_kwargs.get("prefix_context", None)
+        if prefix_context is not None:
+            prefix_text, prefix_resps, _ = prefix_context
+            # to-do: check if we actually need to drop the middle "<eos><bos>"
+            text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ]
+            # feeding this into the NAR-len should automatically handle things
+            resps_list = [ resps for resps in prefix_resps ]
+        """
+
         """
         sampling_layer_skip_variables = {} if sampling_layer_skip else None

@@ -468,9 +482,11 @@ class AR_NAR(Base):
         null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
         null_prom = [ None for _ in range(batch_size) ]
 
-        for n in trange( max_levels, desc="NAR", disable=disable_tqdm ):
+        iterator = trange( max_levels, desc="NAR", disable=disable_tqdm )
+        for n in iterator:
             level = prev_list[0].shape[-1]
             if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
+                iterator.close()
                 break
 
             if cfg.lora is not None:

@@ -578,7 +594,8 @@ class AR_NAR(Base):
         task_list = [ "len" for _ in range(batch_size) ]
         quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
 
-        for n in trange(10, desc="AR", disable=disable_tqdm):
+        iterator = trange(10, desc="AR", disable=disable_tqdm)
+        for n in iterator:
             len_list = sequence_list
 
             inputs = self.inputs(

@@ -613,6 +630,7 @@ class AR_NAR(Base):
             # stop token found
             stopped |= r == stop_token
             if stopped.all().item():
+                iterator.close()
                 break
 
         # convert tokens into int
@@ -660,11 +678,21 @@ class AR_NAR(Base):
                 sequence_list[i] = sequence_list[i][:, 0]
             # start_slice[i] = sequence_list[i].shape[0]
 
+        # prefixed context provided
+        prefix_context = sampling_kwargs.get("prefix_context", None)
+        if prefix_context is not None:
+            prefix_text, prefix_resps, _ = prefix_context
+            # to-do: check if we actually need to drop the middle "<eos><bos>"
+            text_list = [ torch.concat([prefix[:-1], text[1:]]) for prefix, text in zip( prefix_text, text_list ) ]
+            # feeding this into the NAR-len should automatically handle things
+            sequence_list = [ resps if resps.dim() == 1 else resps[:, 0] for resps in prefix_resps ]
+
         null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
         null_prom = [ None for _ in range(batch_size) ]
 
         # get next in sequence
-        for n in trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
+        iterator = trange(max_duration // max(1, self.causal_size), desc="AR", disable=disable_tqdm)
+        for n in iterator:
             # it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
             text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
             resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]

@@ -750,6 +778,7 @@ class AR_NAR(Base):
             # stop token found
             # stopped |= r == stop_token
             if stopped.all().item():
+                iterator.close()
                 break
 
         # to-do for layerskip / speculative sampling: rerun the last sequence again at max depth
@@ -785,7 +814,12 @@ class AR_NAR(Base):
             refined_list = [ logit.argmax(dim=-1) for logit in logits ]
             # to-do: compare scores
             # set the "refined" list as the output
             sequence_list = refined_list
 
+        # slice out prefix
+        if prefix_context is not None:
+            prefix_text, prefix_resps, prefix_lens = prefix_context
+            sequence_list = [ resps[l:] for resps, l in zip(sequence_list, prefix_lens) ]
+
         return sequence_list
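Tying the AR-side hunks together: the prefix codes seeded into `sequence_list` earlier are stripped back off here using the lengths stored in `prefix_context[2]`, so only the newly generated frames are returned. A toy sketch:

```python
import torch

prefix_lens   = [85]                              # third element of prefix_context
sequence_list = [ torch.arange(85 + 120) ]        # prefix frames followed by newly generated frames

# the "slice out prefix" step above:
sequence_list = [ resps[l:] for resps, l in zip(sequence_list, prefix_lens) ]
# -> one tensor of 120 frames; the rolled-in prefix audio is not returned a second time
```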
@@ -837,6 +871,8 @@ class AR_NAR(Base):
                 lang_list=lang_list,
                 tone_list=tone_list,
                 len_list=len_list,
                 disable_tqdm=disable_tqdm,
                 use_lora=use_lora,
             )
 
         # is NAR

@@ -849,6 +885,8 @@ class AR_NAR(Base):
                 lang_list=lang_list,
                 tone_list=tone_list,
                 len_list=len_list,
                 disable_tqdm=disable_tqdm,
                 use_lora=use_lora,
                 **sampling_kwargs,
             )

@@ -861,6 +899,8 @@ class AR_NAR(Base):
                 lang_list=lang_list,
                 tone_list=tone_list,
                 len_list=len_list,
                 disable_tqdm=disable_tqdm,
                 use_lora=use_lora,
                 **sampling_kwargs,
             )
@@ -204,6 +204,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--modality", type=str, default=kwargs["modality"])
     parser.add_argument("--references", type=str, default=kwargs["reference"])
     parser.add_argument("--language", type=str, default=kwargs["language"])
+    parser.add_argument("--split-text-by", type=str, default=kwargs["split-text-by"])
+    parser.add_argument("--context-history", type=int, default=kwargs["context-history"])
     parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
     parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
     parser.add_argument("--max-duration", type=int, default=int(kwargs["max-duration"]*cfg.dataset.frames_per_second))
@@ -257,11 +259,18 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     if kwargs.pop("refine-on-stop", False):
         args.refine_on_stop = True
 
+    if args.split_text_by == "lines":
+        args.split_text_by = "\n"
+    elif args.split_text_by == "none":
+        args.split_text_by = None
+
     tts = init_tts()
 
     gr.Info(f"Inferencing... (Modality: {tts.modality(args.modality.lower())})")
 
     sampling_kwargs = dict(
+        split_text_by=args.split_text_by,
+        context_history=args.context_history,
         max_steps=args.max_steps,
         max_levels=args.max_levels,
         max_duration=args.max_duration,
@@ -437,6 +446,9 @@ with ui:
             layout["inference_tts"]["inputs"]["cfg-strength"] = gr.Slider(value=1.0, minimum=0.0, maximum=14.0, step=0.05, label="CFG Strength", info="Classifier Free Guidance scale (AR needs 1, NAR-len needs 3).")
             layout["inference_tts"]["inputs"]["cfg-rescale"] = gr.Slider(value=0.75, minimum=0.0, maximum=1.0, step=0.05, label="CFG Rescale (Phi)", info="Factor when rescaling for Classifier Free Guidance (0 to disable).")
             layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
+        with gr.Row():
+            layout["inference_tts"]["inputs"]["split-text-by"] = gr.Dropdown(choices=["sentences", "lines"], label="Text Delimiter", info="Splits the text into pieces.", value="sentences")
+            layout["inference_tts"]["inputs"]["context-history"] = gr.Slider(value=0, minimum=0, maximum=4, step=1, label="(Rolling) Context History", info="How many prior lines to serve as the context/prefix (0 to disable).")
     with gr.Tab("Sampler Settings"):
         with gr.Row():
             layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")