more tweaks
This commit is contained in:
parent
0dfab973e7
commit
07f4935a75
|
@ -172,6 +172,7 @@ For audio backends:
|
||||||
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
|
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
|
||||||
* `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
|
* `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
|
||||||
* `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
|
* `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
|
||||||
|
* `default`: uses the naive path for the internal implementation (used for attention-debugging purposes)
|
||||||
* `transformers` Llama\*Attention implementations:
|
* `transformers` Llama\*Attention implementations:
|
||||||
* `eager`: default `LlamaAttention`
|
* `eager`: default `LlamaAttention`
|
||||||
* `sdpa`: integrated `LlamaSdpaAttention` attention model
|
* `sdpa`: integrated `LlamaSdpaAttention` attention model
|
||||||
|
@ -337,6 +338,7 @@ Despite how lightweight it is in comparison to other TTS's I've meddled with, th
|
||||||
* speakers that aren't similar to an audiobook narrator voice have similarity issues, because the majority of training used `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling.
|
* speakers that aren't similar to an audiobook narrator voice have similarity issues, because the majority of training used `path`-based dataloader sampling instead of `speaker`-based (or `group`-based) dataloader sampling.
|
||||||
+ although LoRAs help a ton for fixing results for a single voice.
|
+ although LoRAs help a ton for fixing results for a single voice.
|
||||||
+ a diverse dataset in prosody and speaker (such as a corpus sourced from dramatic media like video games) helps a ton.
|
+ a diverse dataset in prosody and speaker (such as a corpus sourced from dramatic media like video games) helps a ton.
|
||||||
|
* On my test system (7900XTX), it seems inferencing quality depends on the moon phase; I don't know if it's a matter of ROCm nuances (since I've always found it to not be up to par with actual CUDA) or `bfloat16` (due to the model being trained under `float16`+AMP) being the culprit, but your mileage *will* vary depending on the system + dtype + sampler settings.
|
||||||
|
|
||||||
## Notices and Citations
|
## Notices and Citations
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ from .utils import set_seed, prune_missing
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class BaseConfig:
|
class BaseConfig:
|
||||||
yaml_path: str | None = None # path passed in through --yaml
|
yaml_path: str | None = None # path passed in through --yaml
|
||||||
|
model_path: str | None = None # path passed in through --model
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cfg_path(self):
|
def cfg_path(self):
|
||||||
|
@ -114,12 +115,12 @@ class BaseConfig:
|
||||||
if arg.startswith("yaml"):
|
if arg.startswith("yaml"):
|
||||||
args[i] = f'--{arg}'
|
args[i] = f'--{arg}'
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
|
||||||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
||||||
args, unknown = parser.parse_known_args(args=args)
|
args, unknown = parser.parse_known_args(args=args)
|
||||||
|
|
||||||
if args.yaml:
|
if args.yaml:
|
||||||
return cls.from_yaml( args.yaml )
|
return cls.from_yaml( args.yaml )
|
||||||
|
|
||||||
return cls(**{})
|
return cls(**{})
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,8 @@ from .emb.qnt import decode_to_file
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
def encode(path):
	"""Return the file at `path` as a base64 `data:audio/wav` URI.

	Args:
		path: a `pathlib.Path`-like object (must provide `.exists()`), or None.

	Returns:
		The data URI string, or an empty string when `path` is None or the
		file does not exist.
	"""
	# The guard must skip *missing* files; testing `path.exists()` without the
	# negation returned "" for every real file and tried to open absent ones.
	if path is None or not path.exists():
		return ""

	# Context manager so the file handle is closed deterministically.
	with open(path, "rb") as f:
		payload = base64.b64encode(f.read()).decode('utf-8')

	return "data:audio/wav;base64," + payload
|
||||||
|
|
||||||
# Would be downright sugoi if I could incorporate this into __main__
|
# Would be downright sugoi if I could incorporate this into __main__
|
||||||
|
@ -122,7 +124,8 @@ def main():
|
||||||
comparison_kwargs["enabled"]["use_lora"] = True
|
comparison_kwargs["enabled"]["use_lora"] = True
|
||||||
elif args.comparison == "entropix-sampling":
|
elif args.comparison == "entropix-sampling":
|
||||||
comparison_kwargs["suffix"] = "entropix_sampling"
|
comparison_kwargs["suffix"] = "entropix_sampling"
|
||||||
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
||||||
|
|
||||||
comparison_kwargs["disabled"]["entropix_sampling"] = False
|
comparison_kwargs["disabled"]["entropix_sampling"] = False
|
||||||
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
|
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
|
||||||
comparison_kwargs["disabled"]["top_k"] = args.top_k
|
comparison_kwargs["disabled"]["top_k"] = args.top_k
|
||||||
|
@ -132,17 +135,46 @@ def main():
|
||||||
comparison_kwargs["enabled"]["top_k"] = 27
|
comparison_kwargs["enabled"]["top_k"] = 27
|
||||||
comparison_kwargs["enabled"]["top_p"] = 0.9
|
comparison_kwargs["enabled"]["top_p"] = 0.9
|
||||||
elif args.comparison == "ar-temp":
|
elif args.comparison == "ar-temp":
|
||||||
|
current_temp = args.ar_temp
|
||||||
|
other_temp = 1.0
|
||||||
|
|
||||||
comparison_kwargs["suffix"] = "temperature"
|
comparison_kwargs["suffix"] = "temperature"
|
||||||
comparison_kwargs["titles"] = [f"Temp: {args.ar_temp:.2f}", "Temp: 1.0"]
|
comparison_kwargs["titles"] = [f"Temp: {current_temp:.2f}", f"Temp: {other_temp:.2f}"]
|
||||||
|
|
||||||
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
|
comparison_kwargs["disabled"]["ar_temp"] = current_temp
|
||||||
comparison_kwargs["enabled"]["ar_temp"] = 1.0
|
comparison_kwargs["enabled"]["ar_temp"] = other_temp
|
||||||
elif args.comparison == "input-prompt-length":
|
elif args.comparison == "input-prompt-length":
|
||||||
comparison_kwargs["suffix"] = "input_prompt_length"
|
current_length = args.input_prompt_length
|
||||||
comparison_kwargs["titles"] = [f"Prompt Length: {args.input_prompt_length:.2f}s", "Prompt Length: 6.0s"]
|
other_length = 3.0
|
||||||
|
|
||||||
comparison_kwargs["disabled"]["input-prompt-length"] = args.input_prompt_length
|
comparison_kwargs["suffix"] = "input_prompt_length"
|
||||||
comparison_kwargs["enabled"]["input-prompt-length"] = 6.0
|
comparison_kwargs["titles"] = [f"Prompt Length: {current_length:.2f}s", f"Prompt Length: {other_length:.2f}s"]
|
||||||
|
|
||||||
|
comparison_kwargs["disabled"]["input_prompt_length"] = current_length
|
||||||
|
comparison_kwargs["enabled"]["input_prompt_length"] = other_length
|
||||||
|
elif args.comparison == "dtype":
|
||||||
|
current_dtype = cfg.inference.weight_dtype
|
||||||
|
other_dtype = "float32"
|
||||||
|
|
||||||
|
if current_dtype == "float16":
|
||||||
|
other_dtype = "bfloat16"
|
||||||
|
elif current_dtype == "bfloat16":
|
||||||
|
other_dtype = "float16"
|
||||||
|
|
||||||
|
comparison_kwargs["suffix"] = f"dtype_{other_dtype}"
|
||||||
|
comparison_kwargs["titles"] = [f"With {current_dtype}", f"With {other_dtype}"]
|
||||||
|
|
||||||
|
comparison_kwargs["disabled"]["dtype"] = current_dtype
|
||||||
|
comparison_kwargs["enabled"]["dtype"] = other_dtype
|
||||||
|
elif args.comparison == "amp":
|
||||||
|
current_amp = cfg.inference.weight_amp
|
||||||
|
other_amp = not current_amp
|
||||||
|
|
||||||
|
comparison_kwargs["suffix"] = f"with{'out' if not other_amp else ''}_amp"
|
||||||
|
comparison_kwargs["titles"] = [f"With {current_amp}", f"With {other_amp}"]
|
||||||
|
|
||||||
|
comparison_kwargs["disabled"]["amp"] = current_amp
|
||||||
|
comparison_kwargs["enabled"]["amp"] = other_amp
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
|
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
|
||||||
|
|
||||||
|
@ -234,7 +266,7 @@ def main():
|
||||||
audio_samples = [ prompt, out_path ]
|
audio_samples = [ prompt, out_path ]
|
||||||
if args.comparison:
|
if args.comparison:
|
||||||
audio_samples += [ out_path_comparison ]
|
audio_samples += [ out_path_comparison ]
|
||||||
audio_samples += [ p for p in external_sources if p.exists() ]
|
audio_samples += [ p for p in external_sources if p.exists() else None ]
|
||||||
|
|
||||||
if not args.random_prompts or k == "librispeech":
|
if not args.random_prompts or k == "librispeech":
|
||||||
audio_samples += [ reference ]
|
audio_samples += [ reference ]
|
||||||
|
@ -246,6 +278,13 @@ def main():
|
||||||
|
|
||||||
seed = args.seed if args.seed else int(time.time())
|
seed = args.seed if args.seed else int(time.time())
|
||||||
|
|
||||||
|
"""
|
||||||
|
# manual invocation
|
||||||
|
cmd = f'python3 -m vall_e --yaml="{args.yaml}" "{reference}" "{text}" --out-path={out_path}'
|
||||||
|
# F5
|
||||||
|
cmd = f'python inference-cli.py --model "F5-TTS" --ref_audio "{reference}" --gen_text "{text}" --output_dir "{out_path.parent}"'
|
||||||
|
"""
|
||||||
|
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
text=text,
|
text=text,
|
||||||
references=[prompt],
|
references=[prompt],
|
||||||
|
@ -268,6 +307,14 @@ def main():
|
||||||
def safe_inference( out_path=out_path ):
|
def safe_inference( out_path=out_path ):
|
||||||
if args.skip_existing and out_path.exists():
|
if args.skip_existing and out_path.exists():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# swap model config swap
|
||||||
|
if "dtype" in kwargs or "amp" in kwargs:
|
||||||
|
dtype = kwargs.pop("dtype", args.dtype)
|
||||||
|
amp = kwargs.pop("amp", args.amp)
|
||||||
|
|
||||||
|
del tts
|
||||||
|
tts = TTS( config=args.yaml, device=args.device, dtype=dtype, amp=amp )
|
||||||
try:
|
try:
|
||||||
tts.inference( out_path=out_path, **kwargs )
|
tts.inference( out_path=out_path, **kwargs )
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -295,11 +342,11 @@ def main():
|
||||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
||||||
|
|
||||||
if args.comparison:
|
if args.comparison:
|
||||||
before, after = comparison_kwargs["titles"]
|
disabled, enabled = comparison_kwargs["titles"]
|
||||||
if args.random_prompts:
|
if args.random_prompts:
|
||||||
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
|
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
|
||||||
else:
|
else:
|
||||||
html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
|
html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({disabled})</th>\n\t\t\t\t\t<th>Our VALL-E ({enabled})</th>")
|
||||||
|
|
||||||
# write demo page
|
# write demo page
|
||||||
open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html )
|
open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html )
|
||||||
|
|
|
@ -117,7 +117,7 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=No
|
||||||
del tts
|
del tts
|
||||||
tts = None
|
tts = None
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
|
||||||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
|
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
|
||||||
parser.add_argument("--device", type=str, default=device)
|
parser.add_argument("--device", type=str, default=device)
|
||||||
parser.add_argument("--amp", action="store_true")
|
parser.add_argument("--amp", action="store_true")
|
||||||
|
@ -140,7 +140,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
kwargs['min-ar-temp'] = -1
|
kwargs['min-ar-temp'] = -1
|
||||||
kwargs['min-nar-temp'] = -1
|
kwargs['min-nar-temp'] = -1
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
|
||||||
# I'm very sure I can procedurally generate this list
|
# I'm very sure I can procedurally generate this list
|
||||||
parser.add_argument("--text", type=str, default=kwargs["text"])
|
parser.add_argument("--text", type=str, default=kwargs["text"])
|
||||||
parser.add_argument("--task", type=str, default="tts")
|
parser.add_argument("--task", type=str, default="tts")
|
||||||
|
@ -226,7 +226,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
else:
|
else:
|
||||||
kwargs['min-ar-temp'] = -1
|
kwargs['min-ar-temp'] = -1
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
|
||||||
# I'm very sure I can procedurally generate this list
|
# I'm very sure I can procedurally generate this list
|
||||||
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
||||||
parser.add_argument("--language", type=str, default=kwargs["language"])
|
parser.add_argument("--language", type=str, default=kwargs["language"])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user