more tweaks

This commit is contained in:
mrq 2024-10-18 13:19:36 -05:00
parent 0dfab973e7
commit 07f4935a75
4 changed files with 67 additions and 17 deletions

View File

@ -172,6 +172,7 @@ For audio backends:
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
* `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
* `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
* `default`: uses the naive path for the internal implementation (used for attention-debugging purposes)
* `transformers` Llama\*Attention implementations:
* `eager`: default `LlamaAttention`
* `sdpa`: integrated `LlamaSdpaAttention` attention model
@ -337,6 +338,7 @@ Despite how lightweight it is in comparison to other TTS's I've meddled with, th
* speakers that aren't similar to an audiobook narrator voice have similarity issues due to the majority of training using `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling.
+ although LoRAs help a ton for fixing results for a single voice.
+ a diverse dataset in prosody and speaker (such as a corpus sourced from dramatic media like video games) helps a ton.
* On my test system (7900XTX), it seems inferencing quality depends on the moon phase; I don't know if it's a matter of ROCm nuances (since I've always found it to not be up to par with actual CUDA) or `bfloat16` (due to the model being trained under `float16`+AMP) being the culprit, but your mileage *will* vary depending on the system + dtype + sampler settings.
## Notices and Citations

View File

@ -26,6 +26,7 @@ from .utils import set_seed, prune_missing
@dataclass()
class BaseConfig:
yaml_path: str | None = None # path passed in through --yaml
model_path: str | None = None # path passed in through --model
@property
def cfg_path(self):
@ -114,12 +115,12 @@ class BaseConfig:
if arg.startswith("yaml"):
args[i] = f'--{arg}'
parser = argparse.ArgumentParser(allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
args, unknown = parser.parse_known_args(args=args)
if args.yaml:
return cls.from_yaml( args.yaml )
return cls.from_yaml( args.yaml )
return cls(**{})

View File

@ -33,6 +33,8 @@ from .emb.qnt import decode_to_file
from tqdm import tqdm, trange
def encode(path):
    """Return the file at ``path`` as a base64 ``audio/wav`` data URI.

    Args:
        path: A ``pathlib.Path``-like object exposing ``exists()``, or
            ``None`` when there is no file for this slot.

    Returns:
        ``"data:audio/wav;base64,<payload>"`` for a file that exists on disk,
        or an empty string when ``path`` is ``None`` or the file is missing.
    """
    # Bail out only when there is nothing to read. The previous check was
    # inverted: it returned "" for files that exist and tried to open the
    # ones that do not.
    if path is None or not path.exists():
        return ""
    # Context manager so the handle is closed deterministically.
    with open(path, "rb") as handle:
        payload = handle.read()
    return "data:audio/wav;base64," + base64.b64encode(payload).decode('utf-8')
# Would be downright sugoi if I could incorporate this into __main__
@ -122,7 +124,8 @@ def main():
comparison_kwargs["enabled"]["use_lora"] = True
elif args.comparison == "entropix-sampling":
comparison_kwargs["suffix"] = "entropix_sampling"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
comparison_kwargs["disabled"]["entropix_sampling"] = False
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
comparison_kwargs["disabled"]["top_k"] = args.top_k
@ -132,17 +135,46 @@ def main():
comparison_kwargs["enabled"]["top_k"] = 27
comparison_kwargs["enabled"]["top_p"] = 0.9
elif args.comparison == "ar-temp":
current_temp = args.ar_temp
other_temp = 1.0
comparison_kwargs["suffix"] = "temperature"
comparison_kwargs["titles"] = [f"Temp: {args.ar_temp:.2f}", "Temp: 1.0"]
comparison_kwargs["titles"] = [f"Temp: {current_temp:.2f}", f"Temp: {other_temp:.2f}"]
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
comparison_kwargs["enabled"]["ar_temp"] = 1.0
comparison_kwargs["disabled"]["ar_temp"] = current_temp
comparison_kwargs["enabled"]["ar_temp"] = other_temp
elif args.comparison == "input-prompt-length":
comparison_kwargs["suffix"] = "input_prompt_length"
comparison_kwargs["titles"] = [f"Prompt Length: {args.input_prompt_length:.2f}s", "Prompt Length: 6.0s"]
current_length = args.input_prompt_length
other_length = 3.0
comparison_kwargs["disabled"]["input-prompt-length"] = args.input_prompt_length
comparison_kwargs["enabled"]["input-prompt-length"] = 6.0
comparison_kwargs["suffix"] = "input_prompt_length"
comparison_kwargs["titles"] = [f"Prompt Length: {current_length:.2f}s", f"Prompt Length: {other_length:.2f}s"]
comparison_kwargs["disabled"]["input_prompt_length"] = current_length
comparison_kwargs["enabled"]["input_prompt_length"] = other_length
elif args.comparison == "dtype":
current_dtype = cfg.inference.weight_dtype
other_dtype = "float32"
if current_dtype == "float16":
other_dtype = "bfloat16"
elif current_dtype == "bfloat16":
other_dtype = "float16"
comparison_kwargs["suffix"] = f"dtype_{other_dtype}"
comparison_kwargs["titles"] = [f"With {current_dtype}", f"With {other_dtype}"]
comparison_kwargs["disabled"]["dtype"] = current_dtype
comparison_kwargs["enabled"]["dtype"] = other_dtype
elif args.comparison == "amp":
current_amp = cfg.inference.weight_amp
other_amp = not current_amp
comparison_kwargs["suffix"] = f"with{'out' if not other_amp else ''}_amp"
comparison_kwargs["titles"] = [f"With {current_amp}", f"With {other_amp}"]
comparison_kwargs["disabled"]["amp"] = current_amp
comparison_kwargs["enabled"]["amp"] = other_amp
else:
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
@ -234,7 +266,7 @@ def main():
audio_samples = [ prompt, out_path ]
if args.comparison:
audio_samples += [ out_path_comparison ]
audio_samples += [ p for p in external_sources if p.exists() ]
# Missing external sources are kept as None placeholders rather than dropped
# (`encode` maps None to an empty string). The conditional expression has to be
# the comprehension's element: a trailing `if ... else ...` filter is a syntax error.
audio_samples += [ p if p.exists() else None for p in external_sources ]
if not args.random_prompts or k == "librispeech":
audio_samples += [ reference ]
@ -246,6 +278,13 @@ def main():
seed = args.seed if args.seed else int(time.time())
"""
# manual invocation
cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}'
# F5
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
"""
kwargs = dict(
text=text,
references=[prompt],
@ -268,6 +307,14 @@ def main():
def safe_inference( out_path=out_path ):
if args.skip_existing and out_path.exists():
return
# swap model config swap
if "dtype" in kwargs or "amp" in kwargs:
dtype = kwargs.pop("dtype", args.dtype)
amp = kwargs.pop("amp", args.amp)
del tts
tts = TTS( config=args.yaml, device=args.device, dtype=dtype, amp=amp )
try:
tts.inference( out_path=out_path, **kwargs )
except Exception as e:
@ -295,11 +342,11 @@ def main():
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
if args.comparison:
before, after = comparison_kwargs["titles"]
disabled, enabled = comparison_kwargs["titles"]
if args.random_prompts:
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
else:
html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
# write demo page
open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html )

View File

@ -117,7 +117,7 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=No
del tts
tts = None
parser = argparse.ArgumentParser(allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--device", type=str, default=device)
parser.add_argument("--amp", action="store_true")
@ -140,7 +140,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
kwargs['min-ar-temp'] = -1
kwargs['min-nar-temp'] = -1
parser = argparse.ArgumentParser(allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
# I'm very sure I can procedurally generate this list
parser.add_argument("--text", type=str, default=kwargs["text"])
parser.add_argument("--task", type=str, default="tts")
@ -226,7 +226,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
else:
kwargs['min-ar-temp'] = -1
parser = argparse.ArgumentParser(allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
# I'm very sure I can procedurally generate this list
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"])