modified the demo page to be more modular for demoing comparisons, actually provided a path to use the modified naive attention, and entropix sampling is no longer tied to an experimental yaml flag

mrq 2024-10-12 11:27:55 -05:00
parent 666e8038fb
commit 04e983b86b
9 changed files with 77 additions and 29 deletions

View File

@@ -43,6 +43,8 @@ def main():
     parser.add_argument("--dry-base", type=float, default=1.75)
     parser.add_argument("--dry-allowed-length", type=int, default=2)
+    parser.add_argument("--entropix-sampling", action="store_true")
     parser.add_argument("--seed", type=int, default=None)
     parser.add_argument("--device", type=str, default=None)
@@ -69,6 +71,7 @@ def main():
         beam_width=args.beam_width,
         mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
         dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+        entropix_sampling=args.entropix_sampling,
         seed=args.seed,
     )
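
Since the flag is registered with action="store_true", it defaults to False, so entropix sampling is now opt-in per invocation instead of living in the YAML. A minimal, self-contained sketch of that behavior (the invocation in the comment is hypothetical):

    import argparse

    # hypothetical CLI usage: python3 -m vall_e "text" prompt.wav out.wav --entropix-sampling
    parser = argparse.ArgumentParser()
    parser.add_argument("--entropix-sampling", action="store_true")

    assert parser.parse_args([]).entropix_sampling is False                      # omitted -> off
    assert parser.parse_args(["--entropix-sampling"]).entropix_sampling is True  # passed -> on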

View File

@@ -237,8 +237,6 @@ class ModelExperimentalSettings:
     p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
     # to-to: just incorporate this as a task instead
-    entropix_sampling: bool = False # experimental sampling based on https://github.com/xjdr-alt/entropix, experimental flag because it requires using naive attention for output scores

 # I really need to clean this up
 @dataclass()
 class Model:

View File

@@ -76,6 +76,8 @@ def main():
     parser.add_argument("--dry-base", type=float, default=1.75)
     parser.add_argument("--dry-allowed-length", type=int, default=2)
+    parser.add_argument("--entropix-sampling", action="store_true")
     parser.add_argument("--seed", type=int, default=None)
     parser.add_argument("--device", type=str, default=None)
@@ -84,6 +86,7 @@ def main():
     parser.add_argument("--random-prompts", action="store_true")
     parser.add_argument("--lora", action="store_true")
+    parser.add_argument("--comparison", action="store_true")
     args = parser.parse_args()
@@ -98,6 +101,30 @@ def main():
         'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
     ])

+    # comparison kwargs
+    comparison_kwargs = {
+        "enabled": False,
+        "titles": [],
+        "suffix": "_after",
+        "before": {},
+        "after": {}
+    }
+
+    if args.lora:
+        comparison_kwargs["enabled"] = True
+        comparison_kwargs["suffix"] = "_lora"
+        comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
+        comparison_kwargs["before"]["use_lora"] = True
+        comparison_kwargs["after"]["use_lora"] = False
+    # to-do: make this user definable
+    elif args.comparison:
+        comparison_kwargs["enabled"] = True
+        comparison_kwargs["suffix"] = "_entropix"
+        comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
+        comparison_kwargs["before"]["entropix_sampling"] = True
+        comparison_kwargs["after"]["entropix_sampling"] = False
+
     # read html template
     html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
@@ -114,6 +141,7 @@ def main():
         beam_width=args.beam_width,
         mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
         dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+        entropix_sampling=args.entropix_sampling,
     )) )

     # pull from provided samples
@@ -127,7 +155,7 @@ def main():
     # pull from dataset samples
     if args.sample_from_dataset:
         cfg.dataset.cache = False
-        cfg.dataset.sample_type = "path" if args.lora else "speaker"
+        cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
         cfg.dataset.tasks_list = [ 'tts' ]

         samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name
@@ -180,9 +208,9 @@ def main():
             prompt = dir / "prompt.wav"
             reference = dir / "reference.wav"
             out_path = dir / "out" / "ours.wav"
-            out_path_lora = dir / "out" / "ours_lora.wav"
-            extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
+            out_path_comparison = dir / "out" / f"ours{comparison_kwargs['suffix']}.wav"
+            extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_comparison ] if comparison_kwargs["enabled"] else [])

             if not args.random_prompts or k == "librispeech":
                 extra_sources += [ reference ]
@@ -210,23 +238,24 @@ def main():
                 length_penalty=args.length_penalty,
                 beam_width=args.beam_width,
                 mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
                 dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+                entropix_sampling=args.entropix_sampling,
                 seed=seed,
                 tqdm=False,
             )

-            if args.lora:
-                tts.enable_lora() # I don't think this is necessary with the below
-                kwargs["use_lora"] = True
-
-                try:
-                    tts.inference( out_path=out_path_lora, **kwargs )
-                except Exception as e:
-                    print(f'Error while processing {out_path}: {e}')
-
-                tts.disable_lora()
-                kwargs["use_lora"] = False
-            try:
-                tts.inference( out_path=out_path, **kwargs )
-            except Exception as e:
-                print(f'Error while processing {out_path}: {e}')
+            def safe_inference( out_path=out_path ):
+                try:
+                    tts.inference( out_path=out_path, **kwargs )
+                except Exception as e:
+                    print(f'Error while processing {out_path}: {e}')
+
+            if comparison_kwargs["enabled"]:
+                kwargs.update( comparison_kwargs["before"] )
+                safe_inference( out_path_comparison )
+                kwargs.update( comparison_kwargs["after"] )
+
+            safe_inference()
@@ -243,11 +272,12 @@ def main():
             # write audio into template
             html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )

-        if args.lora:
+        if comparison_kwargs["enabled"]:
+            before, after = comparison_kwargs["titles"]
             if args.random_prompts:
-                html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
+                html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
             else:
-                html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
+                html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")

     # write demo page
     open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html )
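
The payoff of this refactor: any before/after pair of inference kwargs can drive the demo page's comparison column, not just LoRA on/off. A stripped-down sketch of the pattern, with a run stub standing in for tts.inference:

    # minimal sketch: "before" overrides apply to the comparison render,
    # "after" overrides restore the settings for the main render
    comparison_kwargs = {
        "suffix": "_entropix",
        "before": {"entropix_sampling": True},
        "after":  {"entropix_sampling": False},
    }
    kwargs = {"max_ar_steps": 500}  # hypothetical shared inference kwargs

    def run(out_path):  # stand-in for tts.inference(out_path=out_path, **kwargs)
        print(out_path, kwargs)

    kwargs.update(comparison_kwargs["before"])
    run(f"out/ours{comparison_kwargs['suffix']}.wav")  # out/ours_entropix.wav
    kwargs.update(comparison_kwargs["after"])
    run("out/ours.wav")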

View File

@@ -206,7 +206,9 @@ class TTS():
         dry_multiplier=0.0,
         dry_base=1.75,
         dry_allowed_length=2,
+        #
+        entropix_sampling=False,
         #
         seed = None,
         out_path=None,
@@ -255,6 +257,7 @@ class TTS():
             sampling_dry_multiplier=dry_multiplier,
             sampling_dry_base=dry_base,
             sampling_dry_allowed_length=dry_allowed_length,
+            sampling_entropix=entropix_sampling,
             disable_tqdm=not tqdm,
             use_lora=use_lora,
@@ -299,6 +302,7 @@ class TTS():
             sampling_dry_multiplier=dry_multiplier,
             sampling_dry_base=dry_base,
             sampling_dry_allowed_length=dry_allowed_length,
+            sampling_entropix=entropix_sampling,
             disable_tqdm=not tqdm,
             use_lora=use_lora,

View File

@@ -64,8 +64,7 @@ class AR_NAR(Base):
         sampling_dry_multiplier=0.0,
         sampling_dry_base=1.75,
         sampling_dry_allowed_length=2,
-        sampling_entropix=None,
+        sampling_entropix=False,

         disable_tqdm=False,
         use_lora=None,
@@ -269,9 +268,6 @@ class AR_NAR(Base):
         scores = [ 1.0 ] * sampling_beam_width
         entropies = []

-        if sampling_entropix is None:
-            sampling_entropix = self.config.experimental.entropix_sampling
-
         for i, sequence in enumerate( sequence_list ):
             # add <bos> to text for STT
             if task_list[i] in text_task:
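
For context, the removed lines were the usual "None means defer to the config" idiom; with the YAML flag gone, the caller is the single source of truth. A minimal illustration (the stub class is assumed, not the real ModelExperimentalSettings):

    class _Experimental:  # stand-in for ModelExperimentalSettings
        entropix_sampling = False

    def forward_old(sampling_entropix=None):
        if sampling_entropix is None:  # removed: None deferred to the YAML flag
            sampling_entropix = _Experimental.entropix_sampling
        return sampling_entropix

    def forward_new(sampling_entropix=False):  # now purely caller-controlled
        return sampling_entropix

    assert forward_old() is False and forward_new(sampling_entropix=True) is True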

View File

@@ -126,7 +126,8 @@ if torch.backends.cuda.cudnn_sdp_enabled():
     AVAILABLE_ATTENTIONS.append("cudnn")

 if AVAILABLE_ATTENTIONS:
     AVAILABLE_ATTENTIONS.append("sdpa")
+    AVAILABLE_ATTENTIONS.append("default")

 class LlamaAttention_Adapted(LlamaAttention):
     def __init__(self, *args, **kwargs):
@@ -144,6 +145,8 @@ class LlamaAttention_Adapted(LlamaAttention):
             self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
         elif self.mode == "cudnn":
             self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+        else:
+            self.mode = None

         super().__init__(*args, **kwargs)
@@ -256,7 +259,7 @@ class LlamaAttention_Adapted(LlamaAttention):
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
+        if output_attentions or not self.mode:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             return self._forward(
                 hidden_states=hidden_states,
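
This is what the new "default" mode buys: flash/SDPA kernels never materialize the attention matrix, so only a manual path can hand scores back to the entropix sampler. A minimal sketch of naive attention that returns its weights (shapes assumed; not the adapter's exact _forward, which also handles masking and caching):

    import torch

    def naive_attention(q, k, v):
        # q, k, v: (batch, heads, seq_len, head_dim); unlike SDPA/flash kernels,
        # this materializes the full attention weight matrix and returns it
        scale = q.shape[-1] ** -0.5
        weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        return weights @ v, weights

    q = k = v = torch.randn(1, 2, 4, 8)
    out, attn = naive_attention(q, k, v)
    assert out.shape == (1, 2, 4, 8) and attn.shape == (1, 2, 4, 4)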

View File

@@ -531,7 +531,7 @@ class Base(nn.Module):
         if AVAILABLE_ATTENTIONS:
             attention_backend = AVAILABLE_ATTENTIONS[0]
         else:
-            attention_backend = "eager"
+            attention_backend = "default"

         hf_attention = attention_backend
         HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"]
@@ -1530,13 +1530,14 @@ class Base(nn.Module):
             logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]

         # (AR) entropix sampling
+        # we do it after the penalizers because entropix's internal sampling doesn't account for them (but does do top_k/top_p/min_p)
         if attentions is not None and quant_levels is None:
             # move to CPU for speedups
             logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]

             res = [ sample_entropix(
                 logit,
                 attentions[-1], # original code just uses the last attention scores
                 temperature,
                 top_k,
                 top_p,
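
The new comment is worth underscoring: the penalizers mutate the logits first, and entropix then samples from the already-penalized distribution (applying only top_k/top_p/min_p itself). A toy illustration of that ordering, with both functions stubbed:

    import torch

    def length_penalize(logits, token, factor):  # stand-in for the real penalizer
        logits[..., token] += factor
        return logits

    def sample_entropix_stub(logits):  # stand-in for sample_entropix
        return torch.argmax(logits, dim=-1)

    logits = torch.zeros(10)
    logits = length_penalize(logits, token=9, factor=1.0)  # penalties first...
    token = sample_entropix_stub(logits)                   # ...then entropix sees them
    assert token.item() == 9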

View File

@@ -407,7 +407,7 @@ def sample_entropix(
     else:
         metrics["action"] = 4

-        log_softmax = torch.nn.functional.log_softmax(logits)
+        log_softmax = torch.nn.functional.log_softmax(logits, dim=-1)
         logits_uncertainty = ent + vent
         attn_uncertainty = attn_ent + attn_vent
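
Context for the one-line fix: calling log_softmax without dim hits PyTorch's deprecated implicit-dimension path (with a UserWarning), while dim=-1 guarantees normalization over the last axis regardless of batch shape. For example:

    import torch

    logits = torch.randn(2, 5)
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    # each row is now a proper log-distribution over the last axis
    assert torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(2))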

View File

@@ -167,6 +167,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
     parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
     parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
+    parser.add_argument("--entropix-sampling", action="store_true")
     args, unknown = parser.parse_known_args()

     tmp = tempfile.NamedTemporaryFile(suffix='.wav')
@@ -176,6 +177,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         raise Exception("No reference audio provided.")
     """

+    if kwargs.pop("entropix-sampling", False):
+        args.entropix_sampling = True
+
     tts = init_tts()

     gr.Info("Inferencing...")
@@ -206,6 +210,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         dry_multiplier=args.dry_multiplier,
         dry_base=args.dry_base,
         dry_allowed_length=args.dry_allowed_length,
+        entropix_sampling=args.entropix_sampling,
     )

     wav = wav.squeeze(0).cpu().numpy()
@@ -240,8 +245,10 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
     parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
     parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
+    parser.add_argument("--entropix-sampling", action="store_true")
     args, unknown = parser.parse_known_args()

     """
     if not args.references:
         raise Exception("No reference audio provided.")
@@ -254,6 +261,9 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         duration = metadata.num_frames / metadata.sample_rate
         args.max_ar_steps += duration
     args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second

+    if kwargs.pop("entropix-sampling", False):
+        args.entropix_sampling = True
+
     tts = init_tts()
@@ -278,6 +288,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         dry_multiplier=args.dry_multiplier,
         dry_base=args.dry_base,
         dry_allowed_length=args.dry_allowed_length,
+        entropix_sampling=args.entropix_sampling,
     )

     return text
@@ -342,6 +353,7 @@ with ui:
                 with gr.Row():
                     #layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
                     layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+                    layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
                     layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
             with gr.Tab("Sampler Settings"):
                 with gr.Row():
@@ -382,6 +394,7 @@ with ui:
                     layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
                 with gr.Row():
                     layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+                    layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
                     layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
             with gr.Tab("Sampler Settings"):
                 with gr.Row():
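
Tying the UI additions together: the checkbox value arrives in kwargs under its hyphenated layout key, gets popped, and is promoted onto the argparse namespace that the inference call reads. A self-contained sketch of that hand-off:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--entropix-sampling", action="store_true")
    args, unknown = parser.parse_known_args([])

    kwargs = {"entropix-sampling": True}  # as delivered by the gr.Checkbox
    if kwargs.pop("entropix-sampling", False):
        args.entropix_sampling = True

    assert args.entropix_sampling  # this is what flows into tts.inference(...)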