more tweaks

parent 0dfab973e7
commit 07f4935a75
@@ -172,6 +172,7 @@ For audio backends:
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
* `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
* `fused_attn`: uses a `triton`-based implementation (tested on my 7900XTX and V100s), but it seems to introduce errors after a while when used for training
* `default`: uses the naive path of the internal implementation (used for attention-debugging purposes)
* `transformers` Llama\*Attention implementations:
  * `eager`: default `LlamaAttention`
  * `sdpa`: integrated `LlamaSdpaAttention` attention model
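Illustrative sketch only (not this repo's code; the dispatcher below and its `backend` argument are hypothetical): the backends listed above ultimately pick between different kernels for the same scaled dot-product attention, roughly along these lines.

```python
import torch
import torch.nn.functional as F

def attention_forward(q, k, v, backend: str = "sdpa"):
    # hypothetical dispatcher; only the two backends PyTorch itself ships are sketched here
    if backend == "sdpa":
        # PyTorch >= 2.0 fused scaled dot-product attention
        return F.scaled_dot_product_attention(q, k, v)
    if backend == "default":
        # naive path, handy for debugging attention numerics
        scale = q.shape[-1] ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale
        return torch.softmax(scores, dim=-1) @ v
    raise ValueError(f"unsupported attention backend: {backend}")
```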
@@ -337,6 +338,7 @@ Despite how lightweight it is in comparison to other TTS's I've meddled with, th
* speakers that aren't similar to an audiobook narrator voice have similarity issues, since the majority of training used `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling (a rough sketch of the difference follows this list).
  + although LoRAs help a ton for fixing results for a single voice.
  + a diverse dataset in prosody and speaker (such as a corpus sourced from dramatic media like video games) helps a ton.
* On my test system (7900XTX), inferencing quality seems to depend on the moon phase; I don't know whether the culprit is ROCm nuances (I've always found it to not be up to par with actual CUDA) or `bfloat16` (the model was trained under `float16`+AMP), but your mileage *will* vary depending on the system + dtype + sampler settings.
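Not part of the README, just an illustration of the sampling distinction mentioned above (the dataset layout and helper names here are made up; the project's actual dataloader is more involved):

```python
import random

# toy dataset: two utterances for speaker A, one for speaker B
dataset = [
    {"path": "spk_a/utt1.wav", "speaker": "spk_a"},
    {"path": "spk_a/utt2.wav", "speaker": "spk_a"},
    {"path": "spk_b/utt1.wav", "speaker": "spk_b"},
]

def sample_by_path(data):
    # every utterance is equally likely, so prolific speakers dominate training
    return random.choice(data)

def sample_by_speaker(data):
    # pick a speaker first, then one of their utterances, evening out speaker coverage
    speakers = sorted({d["speaker"] for d in data})
    speaker = random.choice(speakers)
    return random.choice([d for d in data if d["speaker"] == speaker])
```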

## Notices and Citations
@@ -26,6 +26,7 @@ from .utils import set_seed, prune_missing
@dataclass()
class BaseConfig:
    yaml_path: str | None = None # path passed in through --yaml
    model_path: str | None = None # path passed in through --model

    @property
    def cfg_path(self):
@@ -114,12 +115,12 @@ class BaseConfig:
if arg.startswith("yaml"):
    args[i] = f'--{arg}'

parser = argparse.ArgumentParser(allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
args, unknown = parser.parse_known_args(args=args)

if args.yaml:
    return cls.from_yaml( args.yaml )
    return cls.from_yaml( args.yaml )

return cls(**{})
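A side note on the `add_help=False` change that recurs throughout this commit (the snippet below is a standalone illustration, not project code): when a helper parser pre-scans argv with `parse_known_args`, disabling its automatic `-h`/`--help` keeps those flags from being consumed (and triggering an early exit) before the application's real parser sees them.

```python
import argparse

def peek_yaml(argv):
    # pre-scan only for --yaml; add_help=False leaves -h/--help (and any other
    # unrecognized flags) untouched for the main parser to handle later
    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
    parser.add_argument("--yaml", default=None)
    args, unknown = parser.parse_known_args(argv)
    return args.yaml, unknown

yaml_path, remaining = peek_yaml(["--yaml", "config.yaml", "--help"])
print(yaml_path)  # config.yaml
print(remaining)  # ['--help'] survives for the real CLI parser
```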
@@ -33,6 +33,8 @@ from .emb.qnt import decode_to_file
from tqdm import tqdm, trange

def encode(path):
    # embed the audio sample as a base64 data URI, or return an empty string if it's missing
    if path is None or not path.exists():
        return ""
    return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')

# Would be downright sugoi if I could incorporate this into __main__
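For context (not part of the diff): the data URI returned by `encode()` is meant to be dropped straight into the demo page's audio elements, roughly like the hypothetical usage below (the sample path is a placeholder).

```python
from pathlib import Path

# hypothetical sample path; encode() returns "" when the file is absent
src = encode(Path("./demo/librispeech/sample_0/ours.wav"))
audio_tag = f'<audio controls src="{src}"></audio>' if src else "<em>missing sample</em>"
```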
@@ -122,7 +124,8 @@ def main():
    comparison_kwargs["enabled"]["use_lora"] = True
elif args.comparison == "entropix-sampling":
    comparison_kwargs["suffix"] = "entropix_sampling"
    comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
    comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]

    comparison_kwargs["disabled"]["entropix_sampling"] = False
    comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
    comparison_kwargs["disabled"]["top_k"] = args.top_k
@@ -132,17 +135,46 @@ def main():
    comparison_kwargs["enabled"]["top_k"] = 27
    comparison_kwargs["enabled"]["top_p"] = 0.9
elif args.comparison == "ar-temp":
    current_temp = args.ar_temp
    other_temp = 1.0

    comparison_kwargs["suffix"] = "temperature"
    comparison_kwargs["titles"] = [f"Temp: {args.ar_temp:.2f}", "Temp: 1.0"]
    comparison_kwargs["titles"] = [f"Temp: {current_temp:.2f}", f"Temp: {other_temp:.2f}"]

    comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
    comparison_kwargs["enabled"]["ar_temp"] = 1.0
    comparison_kwargs["disabled"]["ar_temp"] = current_temp
    comparison_kwargs["enabled"]["ar_temp"] = other_temp
elif args.comparison == "input-prompt-length":
    comparison_kwargs["suffix"] = "input_prompt_length"
    comparison_kwargs["titles"] = [f"Prompt Length: {args.input_prompt_length:.2f}s", "Prompt Length: 6.0s"]
    current_length = args.input_prompt_length
    other_length = 3.0

    comparison_kwargs["disabled"]["input-prompt-length"] = args.input_prompt_length
    comparison_kwargs["enabled"]["input-prompt-length"] = 6.0
    comparison_kwargs["suffix"] = "input_prompt_length"
    comparison_kwargs["titles"] = [f"Prompt Length: {current_length:.2f}s", f"Prompt Length: {other_length:.2f}s"]

    comparison_kwargs["disabled"]["input_prompt_length"] = current_length
    comparison_kwargs["enabled"]["input_prompt_length"] = other_length
elif args.comparison == "dtype":
    current_dtype = cfg.inference.weight_dtype
    other_dtype = "float32"

    if current_dtype == "float16":
        other_dtype = "bfloat16"
    elif current_dtype == "bfloat16":
        other_dtype = "float16"

    comparison_kwargs["suffix"] = f"dtype_{other_dtype}"
    comparison_kwargs["titles"] = [f"With {current_dtype}", f"With {other_dtype}"]

    comparison_kwargs["disabled"]["dtype"] = current_dtype
    comparison_kwargs["enabled"]["dtype"] = other_dtype
elif args.comparison == "amp":
    current_amp = cfg.inference.weight_amp
    other_amp = not current_amp

    comparison_kwargs["suffix"] = f"with{'out' if not other_amp else ''}_amp"
    comparison_kwargs["titles"] = [f"With {current_amp}", f"With {other_amp}"]

    comparison_kwargs["disabled"]["amp"] = current_amp
    comparison_kwargs["enabled"]["amp"] = other_amp
else:
    raise Exception(f"Unrecognized comparison flag: {args.comparison}")
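To summarize the pattern above (illustration only, not code from the diff; the numeric values are placeholders for the `ar-temp` case): every branch fills in the same four fields of `comparison_kwargs`.

```python
# shape of comparison_kwargs as assembled by the branches above;
# the concrete values are placeholder examples for the "ar-temp" comparison
comparison_kwargs = {
    "suffix": "temperature",                 # used to tag the comparison outputs
    "titles": ["Temp: 0.95", "Temp: 1.00"],  # column headers on the demo page
    "disabled": {"ar_temp": 0.95},           # overrides for the baseline pass
    "enabled": {"ar_temp": 1.0},             # overrides for the comparison pass
}
```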
@@ -234,7 +266,7 @@ def main():
audio_samples = [ prompt, out_path ]
if args.comparison:
    audio_samples += [ out_path_comparison ]
audio_samples += [ p for p in external_sources if p.exists() ]
audio_samples += [ p if p.exists() else None for p in external_sources ]

if not args.random_prompts or k == "librispeech":
    audio_samples += [ reference ]
@@ -246,6 +278,13 @@ def main():

seed = args.seed if args.seed else int(time.time())

"""
# manual invocation
cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}'
# F5
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
"""

kwargs = dict(
    text=text,
    references=[prompt],
@@ -268,6 +307,14 @@ def main():
def safe_inference( out_path=out_path ):
    if args.skip_existing and out_path.exists():
        return

    # swap model config when the comparison needs a different dtype/amp
    # (rebinding `tts` here assumes it was declared nonlocal/global above this hunk)
    if "dtype" in kwargs or "amp" in kwargs:
        dtype = kwargs.pop("dtype", args.dtype)
        amp = kwargs.pop("amp", args.amp)

        del tts
        tts = TTS( config=args.yaml, device=args.device, dtype=dtype, amp=amp )
    try:
        tts.inference( out_path=out_path, **kwargs )
    except Exception as e:
@@ -295,11 +342,11 @@ def main():
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )

if args.comparison:
    before, after = comparison_kwargs["titles"]
    disabled, enabled = comparison_kwargs["titles"]
    if args.random_prompts:
        html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
        html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
    else:
        html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
        html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")

# write demo page
open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html )
@@ -117,7 +117,7 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=No
del tts
tts = None

parser = argparse.ArgumentParser(allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--device", type=str, default=device)
parser.add_argument("--amp", action="store_true")
@@ -140,7 +140,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
kwargs['min-ar-temp'] = -1
kwargs['min-nar-temp'] = -1

parser = argparse.ArgumentParser(allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
# I'm very sure I can procedurally generate this list
parser.add_argument("--text", type=str, default=kwargs["text"])
parser.add_argument("--task", type=str, default="tts")
@@ -226,7 +226,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
else:
    kwargs['min-ar-temp'] = -1

parser = argparse.ArgumentParser(allow_abbrev=False)
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
# I'm very sure I can procedurally generate this list
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"])