fixed NAR-len issues with non-english maybe (langs weren't being passed), added interface to inference in batches through tts.batched_inference (no support for rolling context/prefixes because there's no way to do that), demo page uses batched inferencing now

This commit is contained in:
mrq 2024-12-07 19:21:05 -06:00
parent 1f54bf5b40
commit 5d80a2d0d4
4 changed files with 225 additions and 116 deletions

View File

@ -744,6 +744,8 @@ class Inference:
normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though
batch_size: int = 16 # I don't know what would be a good batch size
@property @property
def dtype(self): def dtype(self):
if self.weight_dtype == "float16": if self.weight_dtype == "float16":

View File

@ -37,12 +37,40 @@ def encode(path):
return "" return ""
return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8') return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
def safe_inference( tts, out_path, **kwargs ):
if args.skip_existing and out_path.exists():
return
try:
tts.inference( out_path=out_path, **kwargs )
except Exception as e:
raise e
print(f'Error while processing {out_path}: {e}')
def safe_batched_inference( tts, **kwargs ):
try:
tts.batched_inference( **kwargs )
except Exception as e:
raise e
print(f'Error while processing batch: {e}')
def process_batch( tts, inputs, kwargs={} ):
kwargs = kwargs | dict(
texts=[ x[0] for x in inputs ],
references=[ x[1] for x in inputs ],
languages=[ x[2] for x in inputs ],
out_paths=[ x[3] for x in inputs ],
)
safe_batched_inference( tts, **kwargs )
# Would be downright sugoi if I could incorporate this with into __main__ # Would be downright sugoi if I could incorporate this with into __main__
def main(): def main():
parser = argparse.ArgumentParser("VALL-E TTS Demo") parser = argparse.ArgumentParser("VALL-E TTS Demo")
parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--model", type=Path, default=None) parser.add_argument("--model", type=Path, default=None)
parser.add_argument("--batch-size", type=int, default=0)
parser.add_argument("--demo-dir", type=Path, default=None) parser.add_argument("--demo-dir", type=Path, default=None)
parser.add_argument("--skip-existing", action="store_true") parser.add_argument("--skip-existing", action="store_true")
@ -61,14 +89,14 @@ def main():
parser.add_argument("--out-path", type=Path, default=None) parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second) parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
parser.add_argument("--max-steps", type=int, default=25) parser.add_argument("--max-steps", type=int, default=50)
parser.add_argument("--max-levels", type=int, default=7) parser.add_argument("--max-levels", type=int, default=7)
parser.add_argument("--ar-temperature", type=float, default=1.0) parser.add_argument("--ar-temperature", type=float, default=1.0)
parser.add_argument("--nar-temperature", type=float, default=0.0) parser.add_argument("--nar-temperature", type=float, default=0.0)
parser.add_argument("--min-ar-temperature", type=float, default=-1.0) parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
parser.add_argument("--min-nar-temperature", type=float, default=-1.0) parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
parser.add_argument("--input-prompt-length", type=float, default=3.0) parser.add_argument("--input-prompt-length", type=float, default=5.0)
parser.add_argument("--input-prompt-prefix", action="store_true") parser.add_argument("--input-prompt-prefix", action="store_true")
parser.add_argument("--prefix-silence", type=float, default=0.0) parser.add_argument("--prefix-silence", type=float, default=0.0)
parser.add_argument("--cfg-strength", type=float, default=1.0) parser.add_argument("--cfg-strength", type=float, default=1.0)
@ -90,18 +118,6 @@ def main():
parser.add_argument("--dry-base", type=float, default=1.75) parser.add_argument("--dry-base", type=float, default=1.75)
parser.add_argument("--dry-allowed-length", type=int, default=2) parser.add_argument("--dry-allowed-length", type=int, default=2)
parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--layer-skip", action="store_true")
parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
parser.add_argument("--refine-on-stop", action="store_true")
# experimental settings
parser.add_argument("--load-from-artifact", type=Path, default=None)
parser.add_argument("--denoise-start", type=float, default=0.0)
parser.add_argument("--seed", type=int, default=None) parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--device", type=str, default=None) parser.add_argument("--device", type=str, default=None)
@ -151,30 +167,6 @@ def main():
comparison_kwargs["disabled"]["ar_temperature"] = 0.0 comparison_kwargs["disabled"]["ar_temperature"] = 0.0
comparison_kwargs["enabled"]["use_lora"] = False comparison_kwargs["enabled"]["use_lora"] = False
comparison_kwargs["enabled"]["ar_temperature"] = 0.95 comparison_kwargs["enabled"]["ar_temperature"] = 0.95
elif args.comparison == "entropix-sampling":
comparison_kwargs["suffix"] = "entropix_sampling"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
comparison_kwargs["disabled"]["entropix_sampling"] = False
comparison_kwargs["disabled"]["ar_temperature"] = args.ar_temperature
comparison_kwargs["disabled"]["top_k"] = args.top_k
comparison_kwargs["disabled"]["top_p"] = args.top_p
comparison_kwargs["enabled"]["entropix_sampling"] = True
comparison_kwargs["enabled"]["ar_temperature"] = 0.666
comparison_kwargs["enabled"]["top_k"] = 27
comparison_kwargs["enabled"]["top_p"] = 0.9
elif args.comparison == "layerskip":
comparison_kwargs["suffix"] = "layerskip"
comparison_kwargs["titles"] = [f"Without LayerSkip", "With LayerSkip"]
comparison_kwargs["disabled"]["layer_skip"] = False
comparison_kwargs["enabled"]["layer_skip"] = True
elif args.comparison == "refine-on-stop":
comparison_kwargs["suffix"] = "refine-on-stop"
comparison_kwargs["titles"] = [f"Without Ro<S>", "With Ro<S>"]
comparison_kwargs["disabled"]["refine_on_stop"] = False
comparison_kwargs["enabled"]["refine_on_stop"] = True
elif args.comparison == "ar-temp": elif args.comparison == "ar-temp":
current_temperature = args.ar_temperature current_temperature = args.ar_temperature
other_temperature = 1.0 other_temperature = 1.0
@ -254,18 +246,15 @@ def main():
beam_width=args.beam_width, beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length, dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
layer_skip=args.layer_skip,
layer_skip_exit_layer=args.layer_skip_exit_layer,
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
refine_on_stop=args.refine_on_stop,
denoise_start=args.denoise_start,
input_prompt_length=args.input_prompt_length, input_prompt_length=args.input_prompt_length,
input_prompt_prefix=args.input_prompt_prefix, input_prompt_prefix=args.input_prompt_prefix,
prefix_silence=args.prefix_silence, prefix_silence=args.prefix_silence,
cfg_strength=args.cfg_strength, cfg_strength=args.cfg_strength,
cfg_rescale=args.cfg_rescale, cfg_rescale=args.cfg_rescale,
seed = args.seed if args.seed else int(time.time()),
tqdm = True,
batch_size = args.batch_size,
) )
# replace values in our template # replace values in our template
@ -326,6 +315,9 @@ def main():
decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" ) decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" )
decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" ) decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" )
inputs = []
outputs = []
comparison_inputs = []
for k, sample_dir in samples_dirs.items(): for k, sample_dir in samples_dirs.items():
if not sample_dir.exists(): if not sample_dir.exists():
continue continue
@ -349,6 +341,13 @@ def main():
audio_samples += [ out_path_comparison ] audio_samples += [ out_path_comparison ]
audio_samples += [ p if p.exists() else None for p in external_sources ] audio_samples += [ p if p.exists() else None for p in external_sources ]
"""
# manual invocation
cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}'
# F5
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
"""
if not args.random_prompts or k == "librispeech": if not args.random_prompts or k == "librispeech":
audio_samples += [ reference ] audio_samples += [ reference ]
@ -357,51 +356,22 @@ def main():
audio_samples, audio_samples,
)) ))
seed = args.seed if args.seed else int(time.time()) # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs)
"""
# manual invocation
cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}'
# F5
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
"""
kwargs = dict(
text=text,
references=[prompt],
language=language,
seed=seed,
tqdm=False,
**sampling_kwargs,
)
def safe_inference( out_path=out_path ):
if args.skip_existing and out_path.exists():
return
# swap model config swap
"""
if "dtype" in kwargs or "amp" in kwargs:
dtype = kwargs.pop("dtype", args.dtype)
amp = kwargs.pop("amp", args.amp)
del tts
tts = TTS( config=args.yaml, device=args.device, dtype=dtype, amp=amp )
"""
try:
tts.inference( out_path=out_path, **kwargs )
except Exception as e:
raise e
print(f'Error while processing {out_path}: {e}')
if args.comparison: if args.comparison:
kwargs.update( comparison_kwargs["enabled"] ) comparison_inputs.append((text, prompt, language, out_path_comparison))
safe_inference(out_path_comparison)
kwargs.update( comparison_kwargs["disabled"] )
safe_inference() inputs.append((text, prompt, language, out_path))
outputs.append((k, samples))
if inputs:
process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) )
if comparison_inputs:
process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) )
# collate entries into HTML # collate entries into HTML
for k, samples in outputs:
samples = [ samples = [
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+ f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
"".join( [ "".join( [

View File

@ -68,6 +68,7 @@ class TTS():
self.device = device self.device = device
self.dtype = cfg.inference.dtype self.dtype = cfg.inference.dtype
self.amp = amp self.amp = amp
self.batch_size = cfg.inference.batch_size
self.model_kwargs = {} self.model_kwargs = {}
if attention: if attention:
@ -120,10 +121,13 @@ class TTS():
if isinstance( paths, str ): if isinstance( paths, str ):
paths = [ Path(p) for p in paths.split(";") ] paths = [ Path(p) for p in paths.split(";") ]
# merge inputs # not already a list
if isinstance( paths, Path ):
paths = [ paths ]
proms = [] proms = []
# merge inputs
for path in paths: for path in paths:
prom = qnt.encode_from_file(path) prom = qnt.encode_from_file(path)
if hasattr( prom, "codes" ): if hasattr( prom, "codes" ):
@ -185,26 +189,159 @@ class TTS():
modality = cfg.model.name modality = cfg.model.name
return modality return modality
# makes use of being able to batch inputs seamlessly by automatically batching
# this is NOT the default because it absolutely cannot make use of rolling context / prefixing
@torch.inference_mode()
def batched_inference(
self,
texts,
references=None,
languages=None,
text_languages=None,
out_paths=None,
**sampling_kwargs,
):
batch_size = sampling_kwargs.pop("batch_size", self.batch_size)
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
modality = sampling_kwargs.pop("modality", "auto")
seed = sampling_kwargs.pop("seed", None)
tqdm = sampling_kwargs.pop("tqdm", True)
use_lora = sampling_kwargs.pop("use_lora", None)
dtype = sampling_kwargs.pop("dtype", self.dtype)
amp = sampling_kwargs.pop("amp", self.amp)
model_ar = None
model_len = None
model_nar = None
for name, engine in self.engines.items():
if model_ar is None and "ar" in engine.hyper_config.capabilities:
model_ar = engine.module
if model_len is None and "len" in engine.hyper_config.capabilities:
model_len = engine.module
if model_nar is None and "nar" in engine.hyper_config.capabilities:
model_nar = engine.module
modality = self.modality( modality )
# force AR+NAR
if modality == "ar+nar":
model_len = None
# force NAR-len
elif modality == "nar-len":
model_ar = None
samples = len(texts)
# fill with null input proms
if not references:
references = [ None for _ in range(samples) ]
# fill with english
if not languages:
languages = [ "en" for _ in range(samples) ]
if not out_paths:
out_paths = [ None for _ in range(samples) ]
# use the audio language to phonemize the text
if not text_languages:
text_languages = languages
# tensorfy inputs
for i in range( samples ):
texts[i] = self.encode_text( texts[i], language=text_languages[i] )
references[i] = self.encode_audio( references[i], trim_length=input_prompt_length ) if references[i] else None
languages[i] = self.encode_lang( languages[i] )
texts[i] = to_device(texts[i], device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
references[i] = to_device(references[i], device=self.device, dtype=torch.int16)
languages[i] = to_device(languages[i], device=self.device, dtype=torch.uint8)
# create batches
batches = []
buffer = ([], [], [], [])
for batch in zip( texts, references, languages, out_paths ):
# flush
if len(buffer[0]) >= batch_size:
batches.append(buffer)
buffer = ([], [], [], [])
# insert into buffer
for i, x in enumerate( batch ):
buffer[i].append(x)
# flush
if len(buffer[0]) >= batch_size:
batches.append(buffer)
buffer = ([], [], [], [])
wavs = []
for texts, proms, langs, out_paths in batches:
seed = set_seed(seed)
batch_size = len(texts)
input_kwargs = dict(
text_list=texts,
proms_list=proms,
lang_list=langs,
disable_tqdm=not tqdm,
use_lora=use_lora,
)
with torch.autocast("cuda", dtype=dtype, enabled=amp):
if model_len is not None:
# extra kwargs
duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
len_list = model_len( **input_kwargs, task_list=["len"]*batch_size, **{"max_duration": 5} ) # "max_duration" is max tokens
# add an additional X seconds
len_list = [ int(l * duration_padding) for l in len_list ]
resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"]*batch_size,
**sampling_kwargs,
)
elif model_ar is not None:
resps_list = model_ar(
**input_kwargs, task_list=["tts"]*batch_size,
**sampling_kwargs,
)
resps_list = model_nar(
**input_kwargs, resps_list=resps_list, task_list=["tts"]*batch_size,
**sampling_kwargs,
)
else:
raise Exception("!")
for resp, out_path in zip( resps_list, out_paths ):
if out_path:
wav, sr = qnt.decode_to_file(resp, out_path, device=self.device)
else:
wav, sr = qnt.decode(resp, device=self.device)
wavs.append(wav)
return wavs
# naive serial inferencing
# will automatically split a text into pieces (if requested) piece by piece
@torch.inference_mode() @torch.inference_mode()
def inference( def inference(
self, self,
text, text,
references, references,
text_language=None,
language="en", language="en",
text_language=None,
task="tts", task="tts",
modality="auto",
input_prompt_length = 0,
seed = None,
out_path=None, out_path=None,
tqdm=True,
use_lora=None,
**sampling_kwargs, **sampling_kwargs,
): ):
input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
modality = sampling_kwargs.pop("modality", "auto")
seed = sampling_kwargs.pop("seed", None)
tqdm = sampling_kwargs.pop("tqdm", True)
use_lora = sampling_kwargs.pop("use_lora", None)
dtype = sampling_kwargs.pop("dtype", self.dtype)
amp = sampling_kwargs.pop("amp", self.amp)
if not text_language: if not text_language:
text_language = language text_language = language
lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences")) lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences"))
wavs = [] wavs = []
@ -239,7 +376,7 @@ class TTS():
resp = to_device(resp, device=self.device, dtype=torch.int16) resp = to_device(resp, device=self.device, dtype=torch.int16)
lang = to_device(lang, device=self.device, dtype=torch.uint8) lang = to_device(lang, device=self.device, dtype=torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): with torch.autocast("cuda", dtype=dtype, enabled=amp):
model = model_ar if model_ar is not None else model_nar model = model_ar if model_ar is not None else model_nar
if model is not None: if model is not None:
text_list = model( text_list = model(
@ -275,14 +412,20 @@ class TTS():
phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16) phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
lang = to_device(lang, device=self.device, dtype=torch.uint8) lang = to_device(lang, device=self.device, dtype=torch.uint8)
# to-do: add in case for experimental.hf model with torch.autocast("cuda", dtype=dtype, enabled=amp):
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): input_kwargs = dict(
text_list=[phns],
proms_list=[prom],
lang_list=[lang],
disable_tqdm=not tqdm,
use_lora=use_lora,
)
if model_len is not None: if model_len is not None:
# extra kwargs # extra kwargs
duration_padding = sampling_kwargs.pop("duration_padding", 1.05) duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0) nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # "max_duration" is max tokens len_list = model_len( **input_kwargs, task_list=["len"], **{"max_duration": 5} ) # "max_duration" is max tokens
# add an additional X seconds # add an additional X seconds
len_list = [ int(l * duration_padding) for l in len_list ] len_list = [ int(l * duration_padding) for l in len_list ]
@ -291,9 +434,7 @@ class TTS():
if prefix_context is not None: if prefix_context is not None:
kwargs["prefix_context"] = prefix_context kwargs["prefix_context"] = prefix_context
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"], resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**(sampling_kwargs | kwargs), **(sampling_kwargs | kwargs),
) )
elif model_ar is not None: elif model_ar is not None:
@ -302,16 +443,12 @@ class TTS():
kwargs["prefix_context"] = prefix_context kwargs["prefix_context"] = prefix_context
resps_list = model_ar( resps_list = model_ar(
text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"], **input_kwargs, task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**(sampling_kwargs | kwargs), **(sampling_kwargs | kwargs),
) )
resps_list = model_nar( resps_list = model_nar(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"], **input_kwargs, resps_list=resps_list, task_list=["tts"],
disable_tqdm=not tqdm,
use_lora=use_lora,
**sampling_kwargs, **sampling_kwargs,
) )
else: else:

View File

@ -622,7 +622,7 @@ class AR_NAR(Base):
r = [ logit[-1:].argmax(dim=1) for logit in logits ] r = [ logit[-1:].argmax(dim=1) for logit in logits ]
# sanitize # sanitize
for i, token in enumerate(r): for i, token in enumerate(r):
if token > 10: if token > stop_token:
r[i][0] = stop_token r[i][0] = stop_token
# append tokens # append tokens