diff --git a/vall_e/__main__.py b/vall_e/__main__.py index a478921..a1806c7 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -30,6 +30,7 @@ def main(): parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=0) + parser.add_argument("--min-p", type=float, default=0.0) parser.add_argument("--repetition-penalty", type=float, default=1.0) parser.add_argument("--repetition-penalty-decay", type=float, default=0.0) parser.add_argument("--length-penalty", type=float, default=0.0) @@ -62,7 +63,7 @@ def main(): max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, ar_temp=args.ar_temp, nar_temp=args.nar_temp, min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp, - top_p=args.top_p, top_k=args.top_k, + top_p=args.top_p, top_k=args.top_k, min_p=args.min_p, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, beam_width=args.beam_width, diff --git a/vall_e/inference.py b/vall_e/inference.py index 5dd1459..07b4ed1 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -192,6 +192,7 @@ class TTS(): # top_p=1.0, top_k=0, + min_p=0.0, # repetition_penalty=1.0, repetition_penalty_decay=0.0, @@ -245,7 +246,7 @@ class TTS(): text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_min_temperature=min_ar_temp, - sampling_top_p=top_p, sampling_top_k=top_k, + sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width, @@ -289,7 +290,7 @@ class TTS(): input_prompt_prefix=input_prompt_prefix, sampling_temperature=ar_temp, sampling_min_temperature=min_ar_temp, - sampling_top_p=top_p, sampling_top_k=top_k, + sampling_top_p=top_p, sampling_top_k=top_k, 
sampling_min_p=min_p, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, sampling_length_penalty=length_penalty, sampling_beam_width=beam_width, @@ -308,7 +309,7 @@ class TTS(): max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_min_temperature=min_nar_temp, - sampling_top_p=top_p, sampling_top_k=top_k, + sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, disable_tqdm=not tqdm, @@ -320,7 +321,7 @@ class TTS(): max_levels=max_nar_levels, sampling_temperature=nar_temp, sampling_min_temperature=min_nar_temp, - sampling_top_p=top_p, sampling_top_k=top_k, + sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p, sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay, disable_tqdm=not tqdm, diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index 06d9897..b5b9dbb 100644 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -47,6 +47,7 @@ class AR(Base): sampling_min_temperature: float = -1.0, sampling_top_k: int = -100, sampling_top_p: float = 1.0, + sampling_min_p: float = 0.0, sampling_repetition_penalty: float = 1.0, sampling_repetition_penalty_decay: float = 0.0, sampling_length_penalty: float = 0.0, @@ -202,6 +203,7 @@ class AR(Base): min_temperature=sampling_min_temperature, top_p=sampling_top_p, top_k=sampling_top_k, + min_p=sampling_min_p, repetition_penalty=sampling_repetition_penalty, repetition_penalty_decay=sampling_repetition_penalty_decay, length_penalty=sampling_length_penalty, diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index f95b764..d08abfc 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -54,6 +54,7 @@ class AR_NAR(Base): sampling_min_temperature: float = -1.0, sampling_top_k: int = -100, sampling_top_p: float = 1.0, + sampling_min_p: float = 0.0, 
sampling_repetition_penalty: float = 1.0, sampling_repetition_penalty_decay: float = 0.0, sampling_length_penalty: float = 0.0, @@ -235,6 +236,7 @@ class AR_NAR(Base): min_temperature=sampling_min_temperature, top_p=sampling_top_p, top_k=sampling_top_k, + min_p=sampling_min_p, #repetition_penalty=sampling_repetition_penalty, #repetition_penalty_decay=sampling_repetition_penalty_decay, #length_penalty=sampling_length_penalty, @@ -314,6 +316,7 @@ class AR_NAR(Base): min_temperature=sampling_min_temperature, top_p=sampling_top_p, top_k=sampling_top_k, + min_p=sampling_min_p, repetition_penalty=sampling_repetition_penalty, repetition_penalty_decay=sampling_repetition_penalty_decay, length_penalty=sampling_length_penalty, diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 0603643..c2d976f 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -45,6 +45,9 @@ Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more from ..utils.pattern import DelayedPatternProvider, VALLEPattern """ +def clamp(n, lo, hi): + return max(lo, min(n, hi)) + def _create_mask(l, device): """1 is valid region and 0 is invalid.""" seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) @@ -1473,6 +1476,7 @@ class Base(nn.Module): min_temperature: float = -1.0, # activates dynamic temperature sampling top_k: int = -100, top_p: float = 1.0, + min_p: float = 0.0, # repetition penalty parameters repetition_penalty: float = 1.0, repetition_penalty_decay: float = 0.0, @@ -1508,6 +1512,9 @@ class Base(nn.Module): if attentions is not None: entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ] + if attentions is not None: + entropix_enabled = True + # this might actually slow things down a bit slightly-er? 
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] @@ -1523,6 +1530,7 @@ class Base(nn.Module): # adjust sample settings cfg = EntropixSamplerConfig() + entropy[0]["action"] = -1 # Low Entropy, Low Varentropy: "flowing with unspoken intent" if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh: entropy[0]["action"] = 0 @@ -1551,13 +1559,14 @@ class Base(nn.Module): attn_uncertainty = attn_ent + attn_vent temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement) - top_p = torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0).item() + top_p = float(torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0)) top_k = int(torch.clip( torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)), min=cfg.top_k_min, max=cfg.top_k_max )) - min_p = torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5) + min_p = float(torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5)) + temperature = clamp( temperature, cfg.temperature_min, cfg.temperature_max ) def _sample( logits ): # perform repetition penalizing @@ -1569,6 +1578,9 @@ class Base(nn.Module): if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0: logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ] + if min_p > 0.0: + logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ] + # perform top_k/top_p filtering of our logits if top_k > 0 or top_p < 1.0: logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ] @@ -1586,30 +1598,44 @@ class Base(nn.Module): return [ Categorical(logits=logit).sample() for logit in logits ] - samples = [ _sample([ logit.clone() for logit in 
logits ]) for _ in range(cfg.n_adaptive_samples) ] + if entropix_enabled: + samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ] - def score_sample(sample): - one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1]) - log_prob = torch.sum(log_softmax * one_hot) + def score_sample(sample): + one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1]) + log_prob = torch.sum(log_softmax * one_hot) - confidence_score = ( - (1 - ent) * cfg.ada_score_logits_ent + - (1 - attn_ent) * cfg.ada_score_attn_ent + - (1 - vent) * cfg.ada_score_logits_vent + - (1 - attn_vent) * cfg.ada_score_attn_vent + - agreement * cfg.ada_score_agree + - interaction_strength * cfg.ada_score_int - ) - return log_prob + confidence_score + confidence_score = ( + (1 - ent) * cfg.ada_score_logits_ent + + (1 - attn_ent) * cfg.ada_score_attn_ent + + (1 - vent) * cfg.ada_score_logits_vent + + (1 - attn_vent) * cfg.ada_score_attn_vent + + agreement * cfg.ada_score_agree + + interaction_strength * cfg.ada_score_int + ) + return log_prob + confidence_score - sample_scores = [ score_sample(sample) for sample in samples ] - best_sample_idx = torch.argmax(torch.asarray(sample_scores)) - - res = samples[best_sample_idx] - scores = sample_scores - return Sampled(res, scores, entropy) + sample_scores = [ score_sample(sample) for sample in samples ] + best_sample_idx = torch.argmax(torch.asarray(sample_scores)) + + res = samples[best_sample_idx] + scores = sample_scores + return Sampled(res, scores, entropy) - temperature = min(1.5, float(temperature)) + temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max ) + min_temperature = temperature + + entropy[0]["temperature"] = temperature + entropy[0]["top_k"] = top_k + entropy[0]["top_p"] = top_p + entropy[0]["min_p"] = min_p + + if not entropix_enabled: + temperature = 1.0 + min_temperature = 1.0 + top_k = 0 + top_p = 1.0 + min_p = 0.0 # (NAR) disable stop token if 
quant_levels is not None and "ar" in self.capabilities: @@ -1633,6 +1659,10 @@ class Base(nn.Module): if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0: logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ] + # perform min_p filtering of our logits + if min_p > 0.0: + logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ] + # perform top_k/top_p filtering of our logits if top_k > 0 or top_p < 1.0: logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ] diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 6e07bbb..74cbf76 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -45,6 +45,7 @@ class NAR(Base): sampling_min_temperature: float = -1.0, sampling_top_k: int = -100, sampling_top_p: float = 1.0, + sampling_min_p: float = 0.0, sampling_repetition_penalty: float = 1.0, sampling_repetition_penalty_decay: float = 0.0, sampling_length_penalty: float = 0.0, @@ -191,6 +192,7 @@ class NAR(Base): min_temperature=sampling_min_temperature, top_p=sampling_top_p, top_k=sampling_top_k, + min_p=sampling_min_p, repetition_penalty=sampling_repetition_penalty, repetition_penalty_decay=sampling_repetition_penalty_decay, #length_penalty=sampling_length_penalty, diff --git a/vall_e/samplers.py b/vall_e/samplers.py index f99bee2..b5f887e 100644 --- a/vall_e/samplers.py +++ b/vall_e/samplers.py @@ -50,6 +50,26 @@ def ban_tokens( logits, tokens ): logits[:, token] = -float("inf") return logits +# Performs min_p filtering +# From https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/logits_process.py#L537 +def min_p_filtering( logits, min_p=0.0, min_tokens_to_keep=32 ): + if min_p <= 0.0: + return logits + + # Convert logits to probabilities + probs = torch.softmax(logits, dim=-1) + # Get the probability of the top token for each sequence in the batch + 
top_probs, _ = probs.max(dim=-1, keepdim=True) + # Calculate the actual min_p threshold by scaling min_p with the top token's probability + scaled_min_p = min_p * top_probs + + sorted_indices = torch.argsort(logits, descending=True, dim=-1) + sorted_indices_to_remove = torch.gather(probs < scaled_min_p, dim=-1, index=sorted_indices) + sorted_indices_to_remove[..., :min_tokens_to_keep] = False + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + return logits.masked_fill(indices_to_remove, -float("inf")) + # Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ): """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering @@ -243,46 +263,48 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ): # to-do: play around with these values @dataclass() class EntropixSamplerConfig: - temp: float = 0.999 - top_p: float = 0.90 - top_k: int = 32 - min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth + temp: float = 0.85 + top_p: float = 0.90 + top_k: int = 27 + min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth - low_ent_thresh: float = 0.1 - low_vent_thresh: float = 0.1 - med_ent_thresh: float = 3.0 - high_ent_thresh: float = 5.0 - high_vent_thresh: float = 5.0 + low_ent_thresh: float = 0.1 # 3.0 + low_vent_thresh: float = 0.1 # 3.0 + med_ent_thresh: float = 3.0 # 6.0 + high_ent_thresh: float = 5.0 # 9.0 + high_vent_thresh: float = 5.0 # 9.0 - # TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible - helv_attn_ent_offset: float = 1.3 - helv_attn_ent_coef: float = 0.2 + # TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible + helv_attn_ent_offset: float = 1.3 + 
helv_attn_ent_coef: float = 0.2 - lehv_interaction_strength_offset: float = 1.2 - lehv_interaction_strength_coef: float = 0.3 + lehv_interaction_strength_offset: float = 1.2 + lehv_interaction_strength_coef: float = 0.3 - hehv_attn_ent_coef: float = 0.2 - hehv_attn_vent_offset: float = 2.0 - hehv_attn_vent_coef: float = 0.5 + hehv_attn_ent_coef: float = 0.2 + hehv_attn_vent_offset: float = 2.0 + hehv_attn_vent_coef: float = 0.5 - # TODO not convinced this should - n_adaptive_samples: int = 5 + # TODO not convinced this should + n_adaptive_samples: int = 5 - # Adaptive sampling parameters - ada_temp_logits: float = 0.3 - ada_temp_attn: float = 0.2 - ada_temp_agree: float = 0.2 - ada_top_p: float = 0.1 - ada_top_k_int: float = 0.3 - ada_top_k_agree: float = 0.2 - ada_min_p: float = 0.5 - ada_score_logits_ent: float = 0.1 - ada_score_attn_ent: float = 0.2 - ada_score_logits_vent: float = 0.3 - ada_score_attn_vent: float = 0.4 - ada_score_agree: float = 0.5 - ada_score_int: float = 0.6 + # Adaptive sampling parameters + ada_temp_logits: float = 0.3 + ada_temp_attn: float = 0.2 + ada_temp_agree: float = 0.2 + ada_top_p: float = 0.1 + ada_top_k_int: float = 0.3 + ada_top_k_agree: float = 0.2 + ada_min_p: float = 0.5 + ada_score_logits_ent: float = 0.1 + ada_score_attn_ent: float = 0.2 + ada_score_logits_vent: float = 0.3 + ada_score_attn_vent: float = 0.4 + ada_score_agree: float = 0.5 + ada_score_int: float = 0.6 - # extra stuff - top_k_min: int = 32 - top_k_max: int = 128 \ No newline at end of file + # extra stuff + top_k_min: int = 1 + top_k_max: int = 1024 + temperature_max: float = 1.25 + temperature_min: float = 0.5 \ No newline at end of file diff --git a/vall_e/webui.py b/vall_e/webui.py index f8307ca..25ece58 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -157,6 +157,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"]) 
parser.add_argument("--top-p", type=float, default=kwargs["top-p"]) parser.add_argument("--top-k", type=int, default=kwargs["top-k"]) + parser.add_argument("--min-p", type=float, default=kwargs["min-p"]) parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"]) parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"]) @@ -196,6 +197,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): min_nar_temp=args.min_nar_temp, top_p=args.top_p, top_k=args.top_k, + min_p=args.min_p, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, @@ -228,6 +230,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"]) parser.add_argument("--top-p", type=float, default=kwargs["top-p"]) parser.add_argument("--top-k", type=int, default=kwargs["top-k"]) + parser.add_argument("--min-p", type=float, default=kwargs["min-p"]) parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"]) parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"]) parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"]) @@ -266,6 +269,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): min_ar_temp=args.min_ar_temp, top_p=args.top_p, top_k=args.top_k, + min_p=args.min_p, repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, @@ -343,6 +347,7 @@ with ui: with gr.Row(): layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that 
are outside the top P% of probabilities.") layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.") + layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P") layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.") with gr.Row(): layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.") @@ -382,6 +387,7 @@ with ui: with gr.Row(): layout["inference_stt"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.") layout["inference_stt"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.") + layout["inference_stt"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P") layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.") with gr.Row(): layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.25, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")