modified demo page to be more modular for demoing comparisons, actually provide a path to use the modified naive attention, and entropix sampling is no longer tied to an experimental YAML flag

mrq 2024-10-12 11:27:55 -05:00
parent 666e8038fb
commit 04e983b86b
9 changed files with 77 additions and 29 deletions

View File

@@ -43,6 +43,8 @@ def main():
parser.add_argument("--dry-base", type=float, default=1.75)
parser.add_argument("--dry-allowed-length", type=int, default=2)
parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--device", type=str, default=None)
@@ -69,6 +71,7 @@ def main():
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
seed=args.seed,
)
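
For reference, the new flag can then be toggled per run; a hypothetical invocation (the positional text/reference arguments here are assumptions about the CLI, not shown in this diff):

python3 -m vall_e "The quick brown fox jumps over the lazy dog." ./reference.wav --entropix-sampling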

View File

@@ -237,8 +237,6 @@ class ModelExperimentalSettings:
p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
# to-do: just incorporate this as a task instead
entropix_sampling: bool = False # experimental sampling based on https://github.com/xjdr-alt/entropix, experimental flag because it requires using naive attention for output scores
# I really need to clean this up
@dataclass()
class Model:

View File

@@ -76,6 +76,8 @@ def main():
parser.add_argument("--dry-base", type=float, default=1.75)
parser.add_argument("--dry-allowed-length", type=int, default=2)
parser.add_argument("--entropix-sampling", action="store_true")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--device", type=str, default=None)
@@ -84,6 +86,7 @@ def main():
parser.add_argument("--random-prompts", action="store_true")
parser.add_argument("--lora", action="store_true")
parser.add_argument("--comparison", action="store_true")
args = parser.parse_args()
@@ -98,6 +101,30 @@ def main():
'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
])
# comparison kwargs
comparison_kwargs = {
"enabled": False,
"titles": [],
"suffix": "_after",
"before": {},
"after": {}
}
if args.lora:
comparison_kwargs["enabled"] = True
comparison_kwargs["suffix"] = "_lora"
comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
comparison_kwargs["before"]["use_lora"] = True
comparison_kwargs["after"]["use_lora"] = False
# to-do: make this user definable
elif args.comparison:
comparison_kwargs["enabled"] = True
comparison_kwargs["suffix"] = "_entropix"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
comparison_kwargs["before"]["entropix_sampling"] = True
comparison_kwargs["after"]["entropix_sampling"] = False
# read html template
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
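
On the to-do above: since comparison_kwargs is already plain data, a user-definable comparison could be a thin layer over it. A minimal sketch, assuming a hypothetical --comparison-kwargs argument that accepts a JSON spec (not part of this commit):

import json

# hypothetical usage:
# --comparison-kwargs '{"suffix": "_dry", "titles": ["No DRY", "DRY"], "before": {"dry_multiplier": 0.8}, "after": {"dry_multiplier": 0.0}}'
if args.comparison_kwargs:
    comparison_kwargs.update( json.loads( args.comparison_kwargs ) )
    comparison_kwargs["enabled"] = True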
@@ -114,6 +141,7 @@ def main():
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
)) )
# pull from provided samples
@@ -127,7 +155,7 @@ def main():
# pull from dataset samples
if args.sample_from_dataset:
cfg.dataset.cache = False
cfg.dataset.sample_type = "path" if args.lora else "speaker"
cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
cfg.dataset.tasks_list = [ 'tts' ]
samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
@@ -180,9 +208,9 @@ def main():
prompt = dir / "prompt.wav"
reference = dir / "reference.wav"
out_path = dir / "out" / "ours.wav"
out_path_lora = dir / "out" / "ours_lora.wav"
out_path_comparison = dir / "out" / f"ours{comparison_kwargs['suffix']}.wav"
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_comparison ] if comparison_kwargs["enabled"] else [])
if not args.random_prompts or k == "librispeech":
extra_sources += [ reference ]
@@ -210,23 +238,24 @@ def main():
length_penalty=args.length_penalty,
beam_width=args.beam_width,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
seed=seed,
tqdm=False,
)
if args.lora:
tts.enable_lora() # I don't think this is necessary with the below
kwargs["use_lora"] = True
def safe_inference( out_path ):
try:
tts.inference( out_path=out_path_lora, **kwargs )
tts.inference( out_path=out_path, **kwargs )
except Exception as e:
print(f'Error while processing {out_path}: {e}')
tts.disable_lora()
kwargs["use_lora"] = False
try:
tts.inference( out_path=out_path, **kwargs )
except Exception as e:
print(f'Error while processing {out_path}: {e}')
if comparison_kwargs["enabled"]:
kwargs.update( comparison_kwargs["before"] )
safe_inference( out_path_comparison )
kwargs.update( comparison_kwargs["after"] )
safe_inference( out_path )
# collate entries into HTML
@@ -243,11 +272,12 @@ def main():
# write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
if args.lora:
if comparison_kwargs["enabled"]:
before, after = comparison_kwargs["titles"]
if args.random_prompts:
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
else:
html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
# write demo page
open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html )
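
A hypothetical invocation of the updated script for an entropix A/B page, assuming the demo script is run as a module (the exact module path is an assumption, not taken from this diff):

python3 -m vall_e.demo --comparison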

View File

@@ -206,7 +206,9 @@ class TTS():
dry_multiplier=0.0,
dry_base=1.75,
dry_allowed_length=2,
#
entropix_sampling=False,
#
seed = None,
out_path=None,
@@ -255,6 +257,7 @@ class TTS():
sampling_dry_multiplier=dry_multiplier,
sampling_dry_base=dry_base,
sampling_dry_allowed_length=dry_allowed_length,
sampling_entropix=entropix_sampling,
disable_tqdm=not tqdm,
use_lora=use_lora,
@@ -299,6 +302,7 @@ class TTS():
sampling_dry_multiplier=dry_multiplier,
sampling_dry_base=dry_base,
sampling_dry_allowed_length=dry_allowed_length,
sampling_entropix=entropix_sampling,
disable_tqdm=not tqdm,
use_lora=use_lora,

View File

@@ -64,8 +64,7 @@ class AR_NAR(Base):
sampling_dry_multiplier=0.0,
sampling_dry_base=1.75,
sampling_dry_allowed_length=2,
sampling_entropix=None,
sampling_entropix=False,
disable_tqdm=False,
use_lora=None,
@@ -269,9 +268,6 @@ class AR_NAR(Base):
scores = [ 1.0 ] * sampling_beam_width
entropies = []
if sampling_entropix is None:
sampling_entropix = self.config.experimental.entropix_sampling
for i, sequence in enumerate( sequence_list ):
# add <bos> to text for STT
if task_list[i] in text_task:

View File

@@ -126,7 +126,8 @@ if torch.backends.cuda.cudnn_sdp_enabled():
AVAILABLE_ATTENTIONS.append("cudnn")
if AVAILABLE_ATTENTIONS:
AVAILABLE_ATTENTIONS.append("sdpa")
AVAILABLE_ATTENTIONS.append("sdpa")
AVAILABLE_ATTENTIONS.append("default")
class LlamaAttention_Adapted(LlamaAttention):
def __init__(self, *args, **kwargs):
@@ -144,6 +145,8 @@ class LlamaAttention_Adapted(LlamaAttention):
self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
elif self.mode == "cudnn":
self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
else:
self.mode = None
super().__init__(*args, **kwargs)
@@ -256,7 +259,7 @@ class LlamaAttention_Adapted(LlamaAttention):
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
if output_attentions or not self.mode:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
return self._forward(
hidden_states=hidden_states,
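
For context on why self.mode = None falls through to _forward: the fused SDPA backends only return the attended output, so per-head attention weights (which entropix consumes) are only obtainable from a manual implementation. A minimal sketch of the general naive-attention technique, not the exact _forward body:

import math
import torch

def naive_attention(q, k, v, mask=None):
    # q, k, v: [batch, heads, seq_len, head_dim]
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask  # additive mask, e.g. -inf over future positions
    weights = torch.softmax(scores, dim=-1)  # the attention scores entropix needs
    return weights @ v, weights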

View File

@@ -531,7 +531,7 @@ class Base(nn.Module):
if AVAILABLE_ATTENTIONS:
attention_backend = AVAILABLE_ATTENTIONS[0]
else:
attention_backend = "eager"
attention_backend = "default"
hf_attention = attention_backend
HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"]
@@ -1530,13 +1530,14 @@ class Base(nn.Module):
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
# (AR) entropix sampling
# we do it after the penalizers because entropix's internal sampling doesn't account for them (but does do top_k/top_p/min_p)
if attentions is not None and quant_levels is None:
# move to CPU for speedups
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
res = [ sample_entropix(
logit,
attentions[-1], #torch.stack(attentions, dim=1),
attentions[-1], # original code just uses the last attention scores
temperature,
top_k,
top_p,
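
A note on the CPU move above: float16 logits are upcast because fp16 variants of CPU ops like log_softmax are unsupported or slow depending on the torch build, while fp32/bf16 logits move as-is. A standalone restatement of that guard:

import torch

def to_cpu_for_sampling(logit: torch.Tensor) -> torch.Tensor:
    # upcast fp16 to fp32 for CPU-side sampling; leave other dtypes untouched
    dtype = torch.float32 if logit.dtype == torch.float16 else logit.dtype
    return logit.to(device="cpu", dtype=dtype)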

View File

@@ -407,7 +407,7 @@ def sample_entropix(
else:
metrics["action"] = 4
log_softmax = torch.nn.functional.log_softmax(logits)
log_softmax = torch.nn.functional.log_softmax(logits, dim=-1)
logits_uncertainty = ent + vent
attn_uncertainty = attn_ent + attn_vent
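
The dim=-1 fix matters beyond silencing a deprecation warning: without an explicit dim, torch picks an axis via an internal heuristic, which can normalize over the wrong dimension for batched logits. For context, a sketch of the entropy/varentropy statistics entropix derives from these log-probabilities (following the definitions in the linked entropix repo; the exact scaling there may differ):

import torch

def entropy_varentropy(logits: torch.Tensor):
    # entropy H = -sum(p * log p); varentropy = E[(-log p - H)^2],
    # i.e. the variance of the per-token surprisal
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)
    varentropy = (probs * (log_probs + entropy.unsqueeze(-1)) ** 2).sum(dim=-1)
    return entropy, varentropy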

View File

@@ -167,6 +167,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
parser.add_argument("--entropix-sampling", action="store_true")
args, unknown = parser.parse_known_args()
tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -176,6 +177,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
raise Exception("No reference audio provided.")
"""
if kwargs.pop("entropix-sampling", False):
args.entropix_sampling = True
tts = init_tts()
gr.Info("Inferencing...")
@@ -206,6 +210,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
dry_multiplier=args.dry_multiplier,
dry_base=args.dry_base,
dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling
)
wav = wav.squeeze(0).cpu().numpy()
@@ -240,8 +245,10 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
parser.add_argument("--entropix-sampling", action="store_true")
args, unknown = parser.parse_known_args()
"""
if not args.references:
raise Exception("No reference audio provided.")
@@ -254,6 +261,9 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
duration = metadata.num_frames / metadata.sample_rate
args.max_ar_steps += duration
args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
if kwargs.pop("entropix-sampling", False):
args.entropix_sampling = True
tts = init_tts()
@@ -278,6 +288,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
dry_multiplier=args.dry_multiplier,
dry_base=args.dry_base,
dry_allowed_length=args.dry_allowed_length,
entropix_sampling=args.entropix_sampling,
)
return text
@@ -342,6 +353,7 @@ with ui:
with gr.Row():
#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():
@@ -382,6 +394,7 @@ with ui:
layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
with gr.Row():
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():