diff --git a/README.md b/README.md index 1c3eb9b..0f3222f 100755 --- a/README.md +++ b/README.md @@ -172,6 +172,7 @@ For audio backends: * `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper) * `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently) * `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while + * `default`: uses the naive path for hte internal implementation (used for attention-debugging purposed) * `transformers` Llama\*Attention implementations: * `eager`: default `LlamaAttention` * `sdpa`: integrated `LlamaSdpaAttention` attention model @@ -337,6 +338,7 @@ Despite how lightweight it is in comparison to other TTS's I've meddled with, th * speakers that aren't similar to an audiobook narrator voice has similarity issues due to the majority of training used `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling. + although LoRAs help a ton for fixing results for a single voice. + a diverse dataset in prosidy and speaker (such as a corpus sourced from dramatic media like video games) helps a ton. +* On my test system (7900XTX), it seems inferencing quality depends on the moon phase; I don't know if it's a matter of ROCm nuances (since I've always found it to not be up to par with actual CUDA) or `bfloat16` (due to the model being trained under `float16`+AMP) being the culprit, but your mileage *will* vary depending on the system + dtype + sampler settings. ## Notices and Citations diff --git a/vall_e/config.py b/vall_e/config.py index 9ade194..eab5d7f 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -26,6 +26,7 @@ from .utils import set_seed, prune_missing @dataclass() class BaseConfig: yaml_path: str | None = None # path passed in through --yaml + model_path: str | None = None # path passed in through --model @property def cfg_path(self): @@ -114,12 +115,12 @@ class BaseConfig: if arg.startswith("yaml"): args[i] = f'--{arg}' - parser = argparse.ArgumentParser(allow_abbrev=False) + parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False) parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too args, unknown = parser.parse_known_args(args=args) if args.yaml: - return cls.from_yaml( args.yaml ) + return cls.from_yaml( args.yaml ) return cls(**{}) diff --git a/vall_e/demo.py b/vall_e/demo.py index cef1077..9c65ebf 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -33,6 +33,8 @@ from .emb.qnt import decode_to_file from tqdm import tqdm, trange def encode(path): + if path is None or path.exists(): + return "" return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8') # Would be downright sugoi if I could incorporate this with into __main__ @@ -122,7 +124,8 @@ def main(): comparison_kwargs["enabled"]["use_lora"] = True elif args.comparison == "entropix-sampling": comparison_kwargs["suffix"] = "entropix_sampling" - comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"] + comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"] + comparison_kwargs["disabled"]["entropix_sampling"] = False comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp comparison_kwargs["disabled"]["top_k"] = args.top_k @@ -132,17 +135,46 @@ def main(): comparison_kwargs["enabled"]["top_k"] = 27 comparison_kwargs["enabled"]["top_p"] = 0.9 elif args.comparison == "ar-temp": + current_temp = args.ar_temp + other_temp = 1.0 + comparison_kwargs["suffix"] = "temperature" - comparison_kwargs["titles"] = [f"Temp: {args.ar_temp:.2f}", "Temp: 1.0"] + comparison_kwargs["titles"] = [f"Temp: {current_temp:.2f}", f"Temp: {other_temp:.2f}"] - comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp - comparison_kwargs["enabled"]["ar_temp"] = 1.0 + comparison_kwargs["disabled"]["ar_temp"] = current_temp + comparison_kwargs["enabled"]["ar_temp"] = other_temp elif args.comparison == "input-prompt-length": - comparison_kwargs["suffix"] = "input_prompt_length" - comparison_kwargs["titles"] = [f"Prompt Length: {args.input_prompt_length:.2f}s", "Prompt Length: 6.0s"] + current_length = args.input_prompt_length + other_length = 3.0 - comparison_kwargs["disabled"]["input-prompt-length"] = args.input_prompt_length - comparison_kwargs["enabled"]["input-prompt-length"] = 6.0 + comparison_kwargs["suffix"] = "input_prompt_length" + comparison_kwargs["titles"] = [f"Prompt Length: {current_length:.2f}s", f"Prompt Length: {other_length:.2f}s"] + + comparison_kwargs["disabled"]["input_prompt_length"] = current_length + comparison_kwargs["enabled"]["input_prompt_length"] = other_length + elif args.comparison == "dtype": + current_dtype = cfg.inference.weight_dtype + other_dtype = "float32" + + if current_dtype == "float16": + other_dtype = "bfloat16" + elif current_dtype == "bfloat16": + other_dtype = "float16" + + comparison_kwargs["suffix"] = f"dtype_{other_dtype}" + comparison_kwargs["titles"] = [f"With {current_dtype}", f"With {other_dtype}"] + + comparison_kwargs["disabled"]["dtype"] = current_dtype + comparison_kwargs["enabled"]["dtype"] = other_dtype + elif args.comparison == "amp": + current_amp = cfg.inference.weight_amp + other_amp = not current_amp + + comparison_kwargs["suffix"] = f"with{'out' if not other_amp else ''}_amp" + comparison_kwargs["titles"] = [f"With {current_amp}", f"With {other_amp}"] + + comparison_kwargs["disabled"]["amp"] = current_amp + comparison_kwargs["enabled"]["amp"] = other_amp else: raise Exception(f"Unrecognized comparison flag: {args.comparison}") @@ -234,7 +266,7 @@ def main(): audio_samples = [ prompt, out_path ] if args.comparison: audio_samples += [ out_path_comparison ] - audio_samples += [ p for p in external_sources if p.exists() ] + audio_samples += [ p for p in external_sources if p.exists() else None ] if not args.random_prompts or k == "librispeech": audio_samples += [ reference ] @@ -246,6 +278,13 @@ def main(): seed = args.seed if args.seed else int(time.time()) + """ + # manual invocation + cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}' + # F5 + cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"' + """ + kwargs = dict( text=text, references=[prompt], @@ -268,6 +307,14 @@ def main(): def safe_inference( out_path=out_path ): if args.skip_existing and out_path.exists(): return + + # swap model config swap + if "dtype" in kwargs or "amp" in kwargs: + dtype = kwargs.pop("dtype", args.dtype) + amp = kwargs.pop("amp", args.amp) + + del tts + tts = TTS( config=args.yaml, device=args.device, dtype=dtype, amp=amp ) try: tts.inference( out_path=out_path, **kwargs ) except Exception as e: @@ -295,11 +342,11 @@ def main(): html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) ) if args.comparison: - before, after = comparison_kwargs["titles"] + disabled, enabled = comparison_kwargs["titles"] if args.random_prompts: - html = html.replace("