From 5d80a2d0d44f9b1949f7b43cdeebe48e2772649c Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 7 Dec 2024 19:21:05 -0600 Subject: [PATCH] fixed NAR-len issues with non-english maybe (langs weren't being passed), added interface to inference in batches through tts.batched_inference (no support for rolling context/prefixes because there's no way to do that), demo page uses batched inferencing now --- vall_e/config.py | 2 + vall_e/demo.py | 156 ++++++++++++++-------------------- vall_e/inference.py | 181 +++++++++++++++++++++++++++++++++++----- vall_e/models/ar_nar.py | 2 +- 4 files changed, 225 insertions(+), 116 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 6fd6547..2164914 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -744,6 +744,8 @@ class Inference: normalize: bool = False # to-do: actually normalize input / output audio, I believe this might cause issues though + batch_size: int = 16 # I don't know what would be a good batch size + @property def dtype(self): if self.weight_dtype == "float16": diff --git a/vall_e/demo.py b/vall_e/demo.py index 603a61a..4d73143 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -37,12 +37,40 @@ def encode(path): return "" return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8') +def safe_inference( tts, out_path, **kwargs ): + if args.skip_existing and out_path.exists(): + return + + try: + tts.inference( out_path=out_path, **kwargs ) + except Exception as e: + raise e + print(f'Error while processing {out_path}: {e}') + +def safe_batched_inference( tts, **kwargs ): + try: + tts.batched_inference( **kwargs ) + except Exception as e: + raise e + print(f'Error while processing batch: {e}') + +def process_batch( tts, inputs, kwargs={} ): + kwargs = kwargs | dict( + texts=[ x[0] for x in inputs ], + references=[ x[1] for x in inputs ], + languages=[ x[2] for x in inputs ], + out_paths=[ x[3] for x in inputs ], + ) + + safe_batched_inference( tts, **kwargs ) + # Would be downright sugoi if I could incorporate this with into __main__ def main(): parser = argparse.ArgumentParser("VALL-E TTS Demo") parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--model", type=Path, default=None) + parser.add_argument("--batch-size", type=int, default=0) parser.add_argument("--demo-dir", type=Path, default=None) parser.add_argument("--skip-existing", action="store_true") @@ -61,14 +89,14 @@ def main(): parser.add_argument("--out-path", type=Path, default=None) parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second) - parser.add_argument("--max-steps", type=int, default=25) + parser.add_argument("--max-steps", type=int, default=50) parser.add_argument("--max-levels", type=int, default=7) parser.add_argument("--ar-temperature", type=float, default=1.0) parser.add_argument("--nar-temperature", type=float, default=0.0) parser.add_argument("--min-ar-temperature", type=float, default=-1.0) parser.add_argument("--min-nar-temperature", type=float, default=-1.0) - parser.add_argument("--input-prompt-length", type=float, default=3.0) + parser.add_argument("--input-prompt-length", type=float, default=5.0) parser.add_argument("--input-prompt-prefix", action="store_true") parser.add_argument("--prefix-silence", type=float, default=0.0) parser.add_argument("--cfg-strength", type=float, default=1.0) @@ -90,18 +118,6 @@ def main(): parser.add_argument("--dry-base", type=float, default=1.75) parser.add_argument("--dry-allowed-length", type=int, default=2) - 
parser.add_argument("--entropix-sampling", action="store_true") - - parser.add_argument("--layer-skip", action="store_true") - parser.add_argument("--layer-skip-exit-layer", type=int, default=None) - parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1) - parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1) - parser.add_argument("--refine-on-stop", action="store_true") - - # experimental settings - parser.add_argument("--load-from-artifact", type=Path, default=None) - parser.add_argument("--denoise-start", type=float, default=0.0) - parser.add_argument("--seed", type=int, default=None) parser.add_argument("--device", type=str, default=None) @@ -151,30 +167,6 @@ def main(): comparison_kwargs["disabled"]["ar_temperature"] = 0.0 comparison_kwargs["enabled"]["use_lora"] = False comparison_kwargs["enabled"]["ar_temperature"] = 0.95 - elif args.comparison == "entropix-sampling": - comparison_kwargs["suffix"] = "entropix_sampling" - comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"] - - comparison_kwargs["disabled"]["entropix_sampling"] = False - comparison_kwargs["disabled"]["ar_temperature"] = args.ar_temperature - comparison_kwargs["disabled"]["top_k"] = args.top_k - comparison_kwargs["disabled"]["top_p"] = args.top_p - comparison_kwargs["enabled"]["entropix_sampling"] = True - comparison_kwargs["enabled"]["ar_temperature"] = 0.666 - comparison_kwargs["enabled"]["top_k"] = 27 - comparison_kwargs["enabled"]["top_p"] = 0.9 - elif args.comparison == "layerskip": - comparison_kwargs["suffix"] = "layerskip" - comparison_kwargs["titles"] = [f"Without LayerSkip", "With LayerSkip"] - - comparison_kwargs["disabled"]["layer_skip"] = False - comparison_kwargs["enabled"]["layer_skip"] = True - elif args.comparison == "refine-on-stop": - comparison_kwargs["suffix"] = "refine-on-stop" - comparison_kwargs["titles"] = [f"Without Ro", "With Ro"] - - comparison_kwargs["disabled"]["refine_on_stop"] = False - comparison_kwargs["enabled"]["refine_on_stop"] = True elif args.comparison == "ar-temp": current_temperature = args.ar_temperature other_temperature = 1.0 @@ -254,18 +246,15 @@ def main(): beam_width=args.beam_width, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length, - entropix_sampling=args.entropix_sampling, - layer_skip=args.layer_skip, - layer_skip_exit_layer=args.layer_skip_exit_layer, - layer_skip_entropy_threshold=args.layer_skip_entropy_threshold, - layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold, - refine_on_stop=args.refine_on_stop, - denoise_start=args.denoise_start, input_prompt_length=args.input_prompt_length, input_prompt_prefix=args.input_prompt_prefix, prefix_silence=args.prefix_silence, cfg_strength=args.cfg_strength, cfg_rescale=args.cfg_rescale, + + seed = args.seed if args.seed else int(time.time()), + tqdm = True, + batch_size = args.batch_size, ) # replace values in our template @@ -326,6 +315,9 @@ def main(): decode_to_file( batch["proms"].to("cuda"), prompt, device="cuda" ) decode_to_file( batch["resps"].to("cuda"), reference, device="cuda" ) + inputs = [] + outputs = [] + comparison_inputs = [] for k, sample_dir in samples_dirs.items(): if not sample_dir.exists(): continue @@ -349,6 +341,13 @@ def main(): audio_samples += [ out_path_comparison ] audio_samples += [ p if p.exists() else None for p in external_sources ] + """ + # manual invocation + cmd = f'python3 -m vall_e 
--yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}' + # F5 + cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"' + """ + if not args.random_prompts or k == "librispeech": audio_samples += [ reference ] @@ -357,51 +356,22 @@ def main(): audio_samples, )) - seed = args.seed if args.seed else int(time.time()) - - """ - # manual invocation - cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}' - # F5 - cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"' - """ - - kwargs = dict( - text=text, - references=[prompt], - language=language, - seed=seed, - tqdm=False, - **sampling_kwargs, - ) - - def safe_inference( out_path=out_path ): - if args.skip_existing and out_path.exists(): - return - - # swap model config swap - """ - if "dtype" in kwargs or "amp" in kwargs: - dtype = kwargs.pop("dtype", args.dtype) - amp = kwargs.pop("amp", args.amp) - - del tts - tts = TTS( config=args.yaml, device=args.device, dtype=dtype, amp=amp ) - """ - try: - tts.inference( out_path=out_path, **kwargs ) - except Exception as e: - raise e - print(f'Error while processing {out_path}: {e}') - + # segregate comparisons into its own batch because they use different kwargs (and I do not support variadic-batched kwargs) if args.comparison: - kwargs.update( comparison_kwargs["enabled"] ) - safe_inference(out_path_comparison) - kwargs.update( comparison_kwargs["disabled"] ) + comparison_inputs.append((text, prompt, language, out_path_comparison)) - safe_inference() + inputs.append((text, prompt, language, out_path)) - # collate entries into HTML + outputs.append((k, samples)) + + if inputs: + process_batch( tts, inputs, sampling_kwargs | (comparison_kwargs["disabled"] if args.comparison else {}) ) + + if comparison_inputs: + process_batch( tts, comparison_inputs, sampling_kwargs | (comparison_kwargs["enabled"] if args.comparison else {}) ) + + # collate entries into HTML + for k, samples in outputs: samples = [ f'\n\t\t\t\n\t\t\t\t{text}'+ "".join( [ @@ -415,12 +385,12 @@ def main(): # write audio into template html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) ) - if args.comparison: - disabled, enabled = comparison_kwargs["titles"] - if args.random_prompts: - html = html.replace("Our VALL-E\n\t\t\t\t\tGround Truth", f"Our VALL-E ({disabled})\n\t\t\t\t\tOur VALL-E ({enabled})") - else: - html = html.replace("Our VALL-E", f"Our VALL-E ({disabled})\n\t\t\t\t\tOur VALL-E ({enabled})") + if args.comparison: + disabled, enabled = comparison_kwargs["titles"] + if args.random_prompts: + html = html.replace("Our VALL-E\n\t\t\t\t\tGround Truth", f"Our VALL-E ({disabled})\n\t\t\t\t\tOur VALL-E ({enabled})") + else: + html = html.replace("Our VALL-E", f"Our VALL-E ({disabled})\n\t\t\t\t\tOur VALL-E ({enabled})") # write demo page open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html ) diff --git a/vall_e/inference.py b/vall_e/inference.py index 7df9534..e65027c 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -68,6 +68,7 @@ class TTS(): self.device = device self.dtype = cfg.inference.dtype self.amp = amp + self.batch_size = cfg.inference.batch_size self.model_kwargs = {} if attention: @@ -120,10 +121,13 @@ class TTS(): if isinstance( paths, str ): paths = [ Path(p) for p in paths.split(";") ] - # merge inputs + # not already a list + if isinstance( paths, Path ): + paths = [ 
paths ]
 
 		proms = []
+		# merge inputs
 		for path in paths:
 			prom = qnt.encode_from_file(path)
 			if hasattr( prom, "codes" ):
@@ -185,26 +189,159 @@ class TTS():
 			modality = cfg.model.name
 		return modality
 
+	# makes use of being able to batch inputs seamlessly by automatically batching
+	# this is NOT the default because it absolutely cannot make use of rolling context / prefixing
+	@torch.inference_mode()
+	def batched_inference(
+		self,
+		texts,
+		references=None,
+		languages=None,
+		text_languages=None,
+		out_paths=None,
+		**sampling_kwargs,
+	):
+		batch_size = sampling_kwargs.pop("batch_size", self.batch_size)
+		input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0)
+		modality = sampling_kwargs.pop("modality", "auto")
+		seed = sampling_kwargs.pop("seed", None)
+		tqdm = sampling_kwargs.pop("tqdm", True)
+		use_lora = sampling_kwargs.pop("use_lora", None)
+		dtype = sampling_kwargs.pop("dtype", self.dtype)
+		amp = sampling_kwargs.pop("amp", self.amp)
+
+		model_ar = None
+		model_len = None
+		model_nar = None
+
+		for name, engine in self.engines.items():
+			if model_ar is None and "ar" in engine.hyper_config.capabilities:
+				model_ar = engine.module
+			if model_len is None and "len" in engine.hyper_config.capabilities:
+				model_len = engine.module
+			if model_nar is None and "nar" in engine.hyper_config.capabilities:
+				model_nar = engine.module
+
+		modality = self.modality( modality )
+		# force AR+NAR
+		if modality == "ar+nar":
+			model_len = None
+		# force NAR-len
+		elif modality == "nar-len":
+			model_ar = None
+
+		samples = len(texts)
+		# fill with null input proms
+		if not references:
+			references = [ None for _ in range(samples) ]
+		# fill with english
+		if not languages:
+			languages = [ "en" for _ in range(samples) ]
+		if not out_paths:
+			out_paths = [ None for _ in range(samples) ]
+		# use the audio language to phonemize the text
+		if not text_languages:
+			text_languages = languages
+
+		# tensorfy inputs
+		for i in range( samples ):
+			texts[i] = self.encode_text( texts[i], language=text_languages[i] )
+			references[i] = self.encode_audio( references[i], trim_length=input_prompt_length ) if references[i] else None
+			languages[i] = self.encode_lang( languages[i] )
+
+			texts[i] = to_device(texts[i], device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16)
+			references[i] = to_device(references[i], device=self.device, dtype=torch.int16)
+			languages[i] = to_device(languages[i], device=self.device, dtype=torch.uint8)
+
+		# create batches
+		batches = []
+		buffer = ([], [], [], [])
+		for batch in zip( texts, references, languages, out_paths ):
+			# flush
+			if len(buffer[0]) >= batch_size:
+				batches.append(buffer)
+				buffer = ([], [], [], [])
+
+			# insert into buffer
+			for i, x in enumerate( batch ):
+				buffer[i].append(x)
+
+		# flush the remaining (possibly partial) batch
+		if len(buffer[0]) > 0:
+			batches.append(buffer)
+			buffer = ([], [], [], [])
+
+		wavs = []
+		for texts, proms, langs, out_paths in batches:
+			seed = set_seed(seed)
+			batch_size = len(texts)
+			input_kwargs = dict(
+				text_list=texts,
+				proms_list=proms,
+				lang_list=langs,
+				disable_tqdm=not tqdm,
+				use_lora=use_lora,
+			)
+
+			with torch.autocast("cuda", dtype=dtype, enabled=amp):
+				if model_len is not None:
+					# extra kwargs
+					duration_padding = sampling_kwargs.pop("duration_padding", 1.05)
+					nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0)
+
+					len_list = model_len( **input_kwargs, task_list=["len"]*batch_size, **{"max_duration": 5} ) # "max_duration" is max tokens
+
+					# add an additional X seconds
+					len_list = [ 
int(l * duration_padding) for l in len_list ] + + resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"]*batch_size, + **sampling_kwargs, + ) + elif model_ar is not None: + resps_list = model_ar( + **input_kwargs, task_list=["tts"]*batch_size, + **sampling_kwargs, + ) + + resps_list = model_nar( + **input_kwargs, resps_list=resps_list, task_list=["tts"]*batch_size, + **sampling_kwargs, + ) + else: + raise Exception("!") + + for resp, out_path in zip( resps_list, out_paths ): + if out_path: + wav, sr = qnt.decode_to_file(resp, out_path, device=self.device) + else: + wav, sr = qnt.decode(resp, device=self.device) + wavs.append(wav) + return wavs + + # naive serial inferencing + # will automatically split a text into pieces (if requested) piece by piece @torch.inference_mode() def inference( self, text, references, - text_language=None, language="en", + text_language=None, task="tts", - modality="auto", - - input_prompt_length = 0, - - seed = None, out_path=None, - tqdm=True, - use_lora=None, **sampling_kwargs, ): + input_prompt_length = sampling_kwargs.pop("input_prompt_length", 0) + modality = sampling_kwargs.pop("modality", "auto") + seed = sampling_kwargs.pop("seed", None) + tqdm = sampling_kwargs.pop("tqdm", True) + use_lora = sampling_kwargs.pop("use_lora", None) + dtype = sampling_kwargs.pop("dtype", self.dtype) + amp = sampling_kwargs.pop("amp", self.amp) + if not text_language: text_language = language + lines = sentence_split(text, split_by=sampling_kwargs.get("split_text_by", "sentences")) wavs = [] @@ -239,7 +376,7 @@ class TTS(): resp = to_device(resp, device=self.device, dtype=torch.int16) lang = to_device(lang, device=self.device, dtype=torch.uint8) - with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): + with torch.autocast("cuda", dtype=dtype, enabled=amp): model = model_ar if model_ar is not None else model_nar if model is not None: text_list = model( @@ -275,14 +412,20 @@ class TTS(): phns = to_device(phns, device=self.device, dtype=torch.uint8 if len(self.symmap) < 256 else torch.int16) lang = to_device(lang, device=self.device, dtype=torch.uint8) - # to-do: add in case for experimental.hf model - with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): + with torch.autocast("cuda", dtype=dtype, enabled=amp): + input_kwargs = dict( + text_list=[phns], + proms_list=[prom], + lang_list=[lang], + disable_tqdm=not tqdm, + use_lora=use_lora, + ) if model_len is not None: # extra kwargs duration_padding = sampling_kwargs.pop("duration_padding", 1.05) nar_len_prefix_length = sampling_kwargs.pop("nar_len_prefix_length", 0) - len_list = model_len( text_list=[phns], proms_list=[prom], task_list=["len"], disable_tqdm=not tqdm, **{"max_duration": 5} ) # "max_duration" is max tokens + len_list = model_len( **input_kwargs, task_list=["len"], **{"max_duration": 5} ) # "max_duration" is max tokens # add an additional X seconds len_list = [ int(l * duration_padding) for l in len_list ] @@ -291,9 +434,7 @@ class TTS(): if prefix_context is not None: kwargs["prefix_context"] = prefix_context - resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list, task_list=["tts"], - disable_tqdm=not tqdm, - use_lora=use_lora, + resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"], **(sampling_kwargs | kwargs), ) elif model_ar is not None: @@ -302,16 +443,12 @@ class TTS(): kwargs["prefix_context"] = prefix_context resps_list = model_ar( - text_list=[phns], proms_list=[prom], lang_list=[lang], task_list=["tts"], - 
disable_tqdm=not tqdm, - use_lora=use_lora, + **input_kwargs, task_list=["tts"], **(sampling_kwargs | kwargs), ) resps_list = model_nar( - text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, task_list=["tts"], - disable_tqdm=not tqdm, - use_lora=use_lora, + **input_kwargs, resps_list=resps_list, task_list=["tts"], **sampling_kwargs, ) else: diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index a771e81..1067a14 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -622,7 +622,7 @@ class AR_NAR(Base): r = [ logit[-1:].argmax(dim=1) for logit in logits ] # sanitize for i, token in enumerate(r): - if token > 10: + if token > stop_token: r[i][0] = stop_token # append tokens
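
Usage sketch for the batched interface added above (illustrative only, not part of the patch): the YAML path, prompt WAVs, and texts are hypothetical placeholders, and the TTS( config=... ) construction mirrors the commented-out model-swap snippet in demo.py. batched_inference takes parallel lists and falls back to the new cfg.inference.batch_size when no batch_size kwarg is given.

    from pathlib import Path

    from vall_e.inference import TTS

    # hypothetical config and inputs; all list arguments are parallel, one entry per sample
    tts = TTS( config=Path("./training/config.yaml") )

    wavs = tts.batched_inference(
        texts=[ "The quick brown fox jumps over the lazy dog.", "A second, unrelated utterance." ],
        references=[ Path("./prompts/speaker_a.wav"), Path("./prompts/speaker_b.wav") ],
        languages=[ "en", "en" ],
        out_paths=[ Path("./out/a.wav"), Path("./out/b.wav") ],
        batch_size=16,             # optional; defaults to cfg.inference.batch_size
        input_prompt_length=5.0,   # trims each reference prompt, as with the serial inference()
        modality="auto",           # "ar+nar" / "nar-len" force one decoding path, as in TTS.modality()
    )
    # returns the decoded waveforms in input order; out_paths are also written when given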
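
For reference, the buffer/flush loop in batched_inference groups the parallel input lists into batches of at most batch_size entries, with a trailing partial batch when the sample count is not a multiple of batch_size. A standalone sketch of the same grouping (hypothetical helper, not code from the patch):

    def chunk( items, batch_size ):
        # batches of up to batch_size items; the last one may be smaller
        return [ items[i:i + batch_size] for i in range( 0, len(items), batch_size ) ]

    assert chunk( list(range(5)), batch_size=2 ) == [ [0, 1], [2, 3], [4] ]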
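
The serial inference() now reads input_prompt_length, modality, seed, tqdm, use_lora, dtype, and amp out of **sampling_kwargs instead of taking them as named parameters, so callers pass them as ordinary keyword arguments alongside the sampler settings. A minimal sketch, continuing the hypothetical setup above (paths and values are placeholders):

    tts.inference(
        text="Hello world.",
        references=[ Path("./prompts/speaker_a.wav") ],
        language="en",
        out_path=Path("./out/hello.wav"),
        seed=1234,          # consumed via sampling_kwargs.pop("seed", None)
        modality="auto",    # likewise popped from sampling_kwargs
        max_steps=50,       # remaining kwargs flow through to the model samplers
    )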