modified the demo page to be more modular with support for demoing comparisons, actually provided a path to use the modified naive attention, and untied entropix sampling from the experimental yaml flag
parent 666e8038fb
commit 04e983b86b

@@ -43,6 +43,8 @@ def main():
     parser.add_argument("--dry-base", type=float, default=1.75)
     parser.add_argument("--dry-allowed-length", type=int, default=2)
 
+    parser.add_argument("--entropix-sampling", action="store_true")
+
     parser.add_argument("--seed", type=int, default=None)
     parser.add_argument("--device", type=str, default=None)

@@ -69,6 +71,7 @@ def main():
         beam_width=args.beam_width,
         mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
         dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+        entropix_sampling=args.entropix_sampling,
         seed=args.seed,
     )

@@ -237,8 +237,6 @@ class ModelExperimentalSettings:
     p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
     # to-do: just incorporate this as a task instead
 
-    entropix_sampling: bool = False # experimental sampling based on https://github.com/xjdr-alt/entropix, experimental flag because it requires using naive attention for output scores
-
 # I really need to clean this up
 @dataclass()
 class Model:

@@ -76,6 +76,8 @@ def main():
     parser.add_argument("--dry-base", type=float, default=1.75)
     parser.add_argument("--dry-allowed-length", type=int, default=2)
 
+    parser.add_argument("--entropix-sampling", action="store_true")
+
     parser.add_argument("--seed", type=int, default=None)
     parser.add_argument("--device", type=str, default=None)

@@ -84,6 +86,7 @@ def main():
 
     parser.add_argument("--random-prompts", action="store_true")
     parser.add_argument("--lora", action="store_true")
+    parser.add_argument("--comparison", action="store_true")
 
     args = parser.parse_args()

@@ -98,6 +101,30 @@ def main():
         'Unlike the original VALL-E demo page, I\'m placing emphasis on the input prompt, as the model adheres to it stronger than others.',
     ])
 
+    # comparison kwargs
+    comparison_kwargs = {
+        "enabled": False,
+        "titles": [],
+        "suffix": "_after",
+        "before": {},
+        "after": {}
+    }
+
+    if args.lora:
+        comparison_kwargs["enabled"] = True
+        comparison_kwargs["suffix"] = "_lora"
+        comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
+        comparison_kwargs["before"]["use_lora"] = True
+        comparison_kwargs["after"]["use_lora"] = False
+    # to-do: make this user definable
+    elif args.comparison:
+        comparison_kwargs["enabled"] = True
+        comparison_kwargs["suffix"] = "_entropix"
+        comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
+        comparison_kwargs["before"]["entropix_sampling"] = True
+        comparison_kwargs["after"]["entropix_sampling"] = False
+
     # read html template
     html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()

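The "# to-do: make this user definable" note above suggests lifting the hard-coded branches into data. A minimal sketch of one way that could look; the COMPARISONS registry and helper below are hypothetical, not part of this commit:

    # hypothetical registry keyed by comparison name, so new A/B pairs
    # don't each need another elif branch
    COMPARISONS = {
        "lora": ( ["No LoRA", "LoRA"], {"use_lora": True}, {"use_lora": False} ),
        "entropix": ( ["Without Entropix", "With Entropix"], {"entropix_sampling": True}, {"entropix_sampling": False} ),
    }

    def build_comparison_kwargs( name ):
        titles, before, after = COMPARISONS[name]
        return {
            "enabled": True,
            "suffix": f"_{name}",
            "titles": titles,
            "before": dict(before),
            "after": dict(after),
        }
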
@@ -114,6 +141,7 @@ def main():
             beam_width=args.beam_width,
             mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
             dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+            entropix_sampling=args.entropix_sampling,
         )) )
 
     # pull from provided samples

@@ -127,7 +155,7 @@ def main():
     # pull from dataset samples
     if args.sample_from_dataset:
         cfg.dataset.cache = False
-        cfg.dataset.sample_type = "path" if args.lora else "speaker"
+        cfg.dataset.sample_type = "path" if len(cfg.dataset.training) < cfg.evaluation.batch_size else "speaker"
         cfg.dataset.tasks_list = [ 'tts' ]
 
         samples_dirs["dataset"] = args.demo_dir / args.dataset_dir_name

@@ -180,9 +208,9 @@ def main():
             prompt = dir / "prompt.wav"
             reference = dir / "reference.wav"
             out_path = dir / "out" / "ours.wav"
-            out_path_lora = dir / "out" / "ours_lora.wav"
+            out_path_comparison = dir / "out" / f"ours{comparison_kwargs['suffix']}.wav"
 
-            extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_lora ] if args.lora else [])
+            extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_comparison ] if comparison_kwargs["enabled"] else [])
 
             if not args.random_prompts or k == "librispeech":
                 extra_sources += [ reference ]

@@ -210,23 +238,24 @@ def main():
                 length_penalty=args.length_penalty,
                 beam_width=args.beam_width,
                 mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
                 dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
+                entropix_sampling=args.entropix_sampling,
                 seed=seed,
                 tqdm=False,
             )
 
-            if args.lora:
-                tts.enable_lora() # I don't think this is necessary with the below
-                kwargs["use_lora"] = True
-                try:
-                    tts.inference( out_path=out_path_lora, **kwargs )
-                except Exception as e:
-                    print(f'Error while processing {out_path}: {e}')
-                tts.disable_lora()
-                kwargs["use_lora"] = False
-            try:
-                tts.inference( out_path=out_path, **kwargs )
-            except Exception as e:
-                print(f'Error while processing {out_path}: {e}')
+            def safe_inference( out_path=out_path ):
+                try:
+                    tts.inference( out_path=out_path, **kwargs )
+                except Exception as e:
+                    print(f'Error while processing {out_path}: {e}')
+
+            if comparison_kwargs["enabled"]:
+                kwargs.update( comparison_kwargs["before"] )
+                safe_inference( out_path_comparison )
+                kwargs.update( comparison_kwargs["after"] )
+
+            safe_inference()
 
             # collate entries into HTML

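One caveat with the kwargs.update() dance above: if safe_inference() raised before the "after" update ran, the baseline kwargs would stay polluted; the try/except inside the helper is what papers over that here. A sketch of an alternative that restores state unconditionally (the override_kwargs helper is illustrative, not from this commit):

    from contextlib import contextmanager

    @contextmanager
    def override_kwargs( kwargs, overrides, restore ):
        # apply the "before" overrides, then always re-apply the "after"
        # values, even if inference raises
        kwargs.update( overrides )
        try:
            yield kwargs
        finally:
            kwargs.update( restore )

    # usage:
    # with override_kwargs( kwargs, comparison_kwargs["before"], comparison_kwargs["after"] ):
    #     safe_inference( out_path_comparison )
    # safe_inference()
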
@@ -243,11 +272,12 @@ def main():
         # write audio into template
         html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
 
-    if args.lora:
+    if comparison_kwargs["enabled"]:
+        before, after = comparison_kwargs["titles"]
         if args.random_prompts:
-            html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
+            html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
         else:
-            html = html.replace("<th>Our VALL-E</th>", "<th>Our VALL-E (No LoRA)</th>\n\t\t\t\t\t<th>Our VALL-E (LoRA)</th>")
+            html = html.replace("<th>Our VALL-E</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
 
     # write demo page
     open( args.demo_dir / args.output_filename, "w", encoding="utf-8" ).write( html )

@@ -206,7 +206,9 @@ class TTS():
         dry_multiplier=0.0,
         dry_base=1.75,
         dry_allowed_length=2,
 
         #
+        entropix_sampling=False,
+        #
         seed = None,
 
         out_path=None,

@@ -255,6 +257,7 @@ class TTS():
             sampling_dry_multiplier=dry_multiplier,
             sampling_dry_base=dry_base,
             sampling_dry_allowed_length=dry_allowed_length,
+            sampling_entropix=entropix_sampling,
 
             disable_tqdm=not tqdm,
             use_lora=use_lora,

@@ -299,6 +302,7 @@ class TTS():
             sampling_dry_multiplier=dry_multiplier,
             sampling_dry_base=dry_base,
             sampling_dry_allowed_length=dry_allowed_length,
+            sampling_entropix=entropix_sampling,
 
             disable_tqdm=not tqdm,
             use_lora=use_lora,

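For reference, the new keyword threads straight from the public API down to the sampler. A usage sketch, assuming the TTS class as modified above; the module path, constructor defaults, and argument values here are assumptions for illustration:

    from vall_e.inference import TTS  # assumed module path

    tts = TTS()
    tts.inference(
        text="Hello world.",
        references=["./reference.wav"],
        entropix_sampling=True,
        out_path="./out.wav",
    )
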
@@ -64,8 +64,7 @@ class AR_NAR(Base):
         sampling_dry_multiplier=0.0,
         sampling_dry_base=1.75,
         sampling_dry_allowed_length=2,
 
-        sampling_entropix=None,
+        sampling_entropix=False,
 
         disable_tqdm=False,
         use_lora=None,

@@ -269,9 +268,6 @@ class AR_NAR(Base):
         scores = [ 1.0 ] * sampling_beam_width
         entropies = []
 
-        if sampling_entropix is None:
-            sampling_entropix = self.config.experimental.entropix_sampling
-
         for i, sequence in enumerate( sequence_list ):
             # add <bos> to text for STT
             if task_list[i] in text_task:

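Together with the signature change above (None to False), this is the "not tied to an experimental yaml flag" part of the commit: a None default used to mean "defer to the experimental config flag", while an explicit False makes the sampler purely opt-in per call. A tiny sketch of the removed resolution logic, with hypothetical names:

    def resolve_entropix( sampling_entropix, config_flag ):
        # old behavior: None deferred to cfg.model.experimental.entropix_sampling
        return config_flag if sampling_entropix is None else sampling_entropix

    assert resolve_entropix( None, True ) is True    # old default: the yaml flag decided
    assert resolve_entropix( False, True ) is False  # new default: the caller decides
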
@@ -126,7 +126,8 @@ if torch.backends.cuda.cudnn_sdp_enabled():
     AVAILABLE_ATTENTIONS.append("cudnn")
 
 if AVAILABLE_ATTENTIONS:
     AVAILABLE_ATTENTIONS.append("sdpa")
+    AVAILABLE_ATTENTIONS.append("default")
 
 class LlamaAttention_Adapted(LlamaAttention):
     def __init__(self, *args, **kwargs):

@@ -144,6 +145,8 @@ class LlamaAttention_Adapted(LlamaAttention):
             self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
         elif self.mode == "cudnn":
             self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+        else:
+            self.mode = None
 
         super().__init__(*args, **kwargs)

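With the new else branch, self.mode being None now doubles as the signal for the fallback path in forward() below. For orientation, a resolved SDPBackend is typically applied with torch.nn.attention.sdpa_kernel; this is a sketch of the general PyTorch pattern, not this repo's exact forward:

    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel

    def sdpa(q, k, v, mode=None, is_causal=True):
        if mode is None:
            # no pinned backend: let torch choose (or take a manual path elsewhere)
            return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        # pin the fused kernel that was detected at import time
        with sdpa_kernel(mode):
            return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
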
@@ -256,7 +259,7 @@ class LlamaAttention_Adapted(LlamaAttention):
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
+        if output_attentions or not self.mode:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             return self._forward(
                 hidden_states=hidden_states,

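This is the "path to use modified naive attention" from the commit message: fused SDPA kernels never materialize the attention weights, so anything that needs the scores (entropix does) must route through a manual implementation. A minimal sketch of what such a path computes, ignoring the KV cache, GQA, and masking details of the real _forward:

    import math
    import torch
    import torch.nn.functional as F

    def naive_attention(q, k, v, mask=None):
        # materialize the full score matrix so the weights can be returned
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        if mask is not None:
            scores = scores + mask
        weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        return weights @ v, weights
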
@@ -531,7 +531,7 @@ class Base(nn.Module):
         if AVAILABLE_ATTENTIONS:
             attention_backend = AVAILABLE_ATTENTIONS[0]
         else:
-            attention_backend = "eager"
+            attention_backend = "default"
 
         hf_attention = attention_backend
         HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"]

@@ -1530,13 +1530,14 @@ class Base(nn.Module):
         logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
     # (AR) entropix sampling
+    # we do it after the penalizers because entropix's internal sampling doesn't account for them (but does do top_k/top_p/min_p)
     if attentions is not None and quant_levels is None:
         # move to CPU for speedups
         logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
 
         res = [ sample_entropix(
             logit,
-            attentions[-1], #torch.stack(attentions, dim=1),
+            attentions[-1], # original code just uses the last attention scores
             temperature,
             top_k,
             top_p,

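For orientation, entropix keys its decisions off the entropy and "varentropy" (variance of surprisal) of the next-token distribution, plus the same statistics over the attention scores. A self-contained sketch of the logit-side computation, following the definitions in the upstream entropix repo (which additionally works in bits, i.e. divides by ln 2):

    import torch

    def entropy_varentropy(logits):
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        entropy = -(probs * log_probs).sum(dim=-1)
        # probability-weighted squared deviation of surprisal from the entropy
        varentropy = (probs * (log_probs + entropy.unsqueeze(-1)) ** 2).sum(dim=-1)
        return entropy, varentropy
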
@@ -407,7 +407,7 @@ def sample_entropix(
     else:
         metrics["action"] = 4
 
-    log_softmax = torch.nn.functional.log_softmax(logits)
+    log_softmax = torch.nn.functional.log_softmax(logits, dim=-1)
     logits_uncertainty = ent + vent
     attn_uncertainty = attn_ent + attn_vent

@@ -167,6 +167,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
     parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
     parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
+    parser.add_argument("--entropix-sampling", action="store_true")
     args, unknown = parser.parse_known_args()
 
     tmp = tempfile.NamedTemporaryFile(suffix='.wav')

@@ -176,6 +177,9 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         raise Exception("No reference audio provided.")
     """
 
+    if kwargs.pop("entropix-sampling", False):
+        args.entropix_sampling = True
+
     tts = init_tts()
 
     gr.Info("Inferencing...")

@@ -206,6 +210,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         dry_multiplier=args.dry_multiplier,
         dry_base=args.dry_base,
         dry_allowed_length=args.dry_allowed_length,
+        entropix_sampling=args.entropix_sampling
     )
 
     wav = wav.squeeze(0).cpu().numpy()

@@ -240,8 +245,10 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
     parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
     parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
+    parser.add_argument("--entropix-sampling", action="store_true")
     args, unknown = parser.parse_known_args()
 
     """
     if not args.references:
         raise Exception("No reference audio provided.")

@@ -254,6 +261,9 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         duration = metadata.num_frames / metadata.sample_rate
         args.max_ar_steps += duration
     args.max_ar_steps = math.floor( args.max_ar_steps * 20 ) # assume 20 tokens per second
 
+    if kwargs.pop("entropix-sampling", False):
+        args.entropix_sampling = True
+
     tts = init_tts()

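The step budget above is plain arithmetic: the accumulated reference duration in seconds times an assumed 20 tokens per second, floored. A worked sketch with a hypothetical helper name:

    import math

    def ar_step_budget( base_steps, durations_sec, tokens_per_sec=20 ):
        # e.g. a 4.5 s reference with no base budget: floor(4.5 * 20) = 90 steps
        return math.floor( (base_steps + sum(durations_sec)) * tokens_per_sec )

    assert ar_step_budget( 0, [4.5] ) == 90
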
@@ -278,6 +288,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
         dry_multiplier=args.dry_multiplier,
         dry_base=args.dry_base,
         dry_allowed_length=args.dry_allowed_length,
+        entropix_sampling=args.entropix_sampling,
     )
 
     return text

@@ -342,6 +353,7 @@ with ui:
                 with gr.Row():
                     #layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
                     layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+                    layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
                     layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
             with gr.Tab("Sampler Settings"):
                 with gr.Row():

@@ -382,6 +394,7 @@ with ui:
                 layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
                 with gr.Row():
                     layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
+                    layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
                    layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
             with gr.Tab("Sampler Settings"):
                 with gr.Row():