added min-p (it really does not seem useful since it's very sensitive); more tweaks to entropix

mrq 2024-10-11 22:36:06 -05:00
parent bef43a0c18
commit d0ab7d755a
8 changed files with 130 additions and 63 deletions
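On the "very sensitive" remark in the commit message: min-p keeps only tokens whose probability is at least min_p times the top token's probability, so the effective cutoff moves with how peaked the distribution is. A minimal sketch, independent of this repo, of how quickly the kept set shrinks as min_p grows:

    import torch

    # five-token toy distribution; probs ~ [0.48, 0.29, 0.18, 0.04, 0.01]
    probs = torch.softmax(torch.tensor([3.0, 2.5, 2.0, 0.5, -1.0]), dim=-1)
    top = probs.max()
    for min_p in (0.01, 0.05, 0.2, 0.5):
        kept = int((probs >= min_p * top).sum())
        print(f"min_p={min_p}: keeps {kept} of {len(probs)} tokens")
    # keeps 5, 4, 3, 2 tokens respectively; small nudges to min_p cut deep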

View File

@@ -30,6 +30,7 @@ def main():
     parser.add_argument("--top-p", type=float, default=1.0)
     parser.add_argument("--top-k", type=int, default=0)
+    parser.add_argument("--min-p", type=float, default=0.0)
     parser.add_argument("--repetition-penalty", type=float, default=1.0)
     parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
     parser.add_argument("--length-penalty", type=float, default=0.0)
@@ -62,7 +63,7 @@ def main():
     max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
     ar_temp=args.ar_temp, nar_temp=args.nar_temp,
     min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
-    top_p=args.top_p, top_k=args.top_k,
+    top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
     repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
     length_penalty=args.length_penalty,
     beam_width=args.beam_width,

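A hypothetical invocation exercising the new flag (the positional argument layout is assumed from this parser, not shown in the hunk above):

    python3 -m vall_e "The quick brown fox." ./reference.wav ./output.wav --min-p 0.05 --ar-temp 0.95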
View File

@@ -192,6 +192,7 @@ class TTS():
     #
     top_p=1.0,
     top_k=0,
+    min_p=0.0,
     #
     repetition_penalty=1.0,
     repetition_penalty_decay=0.0,
@@ -245,7 +246,7 @@ class TTS():
     text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps,
     sampling_temperature=ar_temp,
     sampling_min_temperature=min_ar_temp,
-    sampling_top_p=top_p, sampling_top_k=top_k,
+    sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
     sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
     sampling_length_penalty=length_penalty,
     sampling_beam_width=beam_width,
@@ -289,7 +290,7 @@ class TTS():
     input_prompt_prefix=input_prompt_prefix,
     sampling_temperature=ar_temp,
     sampling_min_temperature=min_ar_temp,
-    sampling_top_p=top_p, sampling_top_k=top_k,
+    sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
     sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
     sampling_length_penalty=length_penalty,
     sampling_beam_width=beam_width,
@@ -308,7 +309,7 @@ class TTS():
     max_levels=max_nar_levels,
     sampling_temperature=nar_temp,
     sampling_min_temperature=min_nar_temp,
-    sampling_top_p=top_p, sampling_top_k=top_k,
+    sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
     sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
     disable_tqdm=not tqdm,
@@ -320,7 +321,7 @@ class TTS():
     max_levels=max_nar_levels,
     sampling_temperature=nar_temp,
     sampling_min_temperature=min_nar_temp,
-    sampling_top_p=top_p, sampling_top_k=top_k,
+    sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
     sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
     disable_tqdm=not tqdm,

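The Python-side equivalent through the TTS wrapper; a hedged sketch, since only the sampling keyword arguments are visible in the hunks above (the method name and the text/reference arguments are assumptions):

    from vall_e.inference import TTS

    tts = TTS()  # config/device arguments omitted
    wav, sr = tts.inference(
        text="Hello world.",
        references=["./reference.wav"],
        top_p=1.0, top_k=0, min_p=0.05,  # min_p > 0.0 enables the new filter
    )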
View File

@@ -47,6 +47,7 @@ class AR(Base):
     sampling_min_temperature: float = -1.0,
     sampling_top_k: int = -100,
     sampling_top_p: float = 1.0,
+    sampling_min_p: float = 0.0,
     sampling_repetition_penalty: float = 1.0,
     sampling_repetition_penalty_decay: float = 0.0,
     sampling_length_penalty: float = 0.0,
@@ -202,6 +203,7 @@ class AR(Base):
     min_temperature=sampling_min_temperature,
     top_p=sampling_top_p,
     top_k=sampling_top_k,
+    min_p=sampling_min_p,
     repetition_penalty=sampling_repetition_penalty,
     repetition_penalty_decay=sampling_repetition_penalty_decay,
     length_penalty=sampling_length_penalty,

View File

@@ -54,6 +54,7 @@ class AR_NAR(Base):
     sampling_min_temperature: float = -1.0,
     sampling_top_k: int = -100,
     sampling_top_p: float = 1.0,
+    sampling_min_p: float = 0.0,
     sampling_repetition_penalty: float = 1.0,
     sampling_repetition_penalty_decay: float = 0.0,
     sampling_length_penalty: float = 0.0,
@@ -235,6 +236,7 @@ class AR_NAR(Base):
     min_temperature=sampling_min_temperature,
     top_p=sampling_top_p,
     top_k=sampling_top_k,
+    min_p=sampling_min_p,
     #repetition_penalty=sampling_repetition_penalty,
     #repetition_penalty_decay=sampling_repetition_penalty_decay,
     #length_penalty=sampling_length_penalty,
@@ -314,6 +316,7 @@ class AR_NAR(Base):
     min_temperature=sampling_min_temperature,
     top_p=sampling_top_p,
     top_k=sampling_top_k,
+    min_p=sampling_min_p,
     repetition_penalty=sampling_repetition_penalty,
     repetition_penalty_decay=sampling_repetition_penalty_decay,
     length_penalty=sampling_length_penalty,

View File

@@ -45,6 +45,9 @@ Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more
 from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """

+def clamp(n, lo, hi):
+    return max(lo, min(n, hi))
+
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
     seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -1473,6 +1476,7 @@ class Base(nn.Module):
     min_temperature: float = -1.0, # activates dynamic temperature sampling
     top_k: int = -100,
     top_p: float = 1.0,
+    min_p: float = 0.0,
     # repetition penalty parameters
     repetition_penalty: float = 1.0,
     repetition_penalty_decay: float = 0.0,
@@ -1508,6 +1512,9 @@ class Base(nn.Module):
     if attentions is not None:
         entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ]

+    if attentions is not None:
+        entropix_enabled = True
+
     # this might actually slow things down a bit slightly-er?
     logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
@@ -1523,6 +1530,7 @@ class Base(nn.Module):
     # adjust sample settings
     cfg = EntropixSamplerConfig()
+    entropy[0]["action"] = -1

     # Low Entropy, Low Varentropy: "flowing with unspoken intent"
     if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
         entropy[0]["action"] = 0
@@ -1551,13 +1559,14 @@ class Base(nn.Module):
         attn_uncertainty = attn_ent + attn_vent

         temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
-        top_p = torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0).item()
+        top_p = float(torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0))
         top_k = int(torch.clip(
             torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)),
             min=cfg.top_k_min,
             max=cfg.top_k_max
         ))
-        min_p = torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5)
+        min_p = float(torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5))
+        temperature = clamp( temperature, cfg.temperature_min, cfg.temperature_max )

     def _sample( logits ):
         # perform repetition penalizing
@@ -1569,6 +1578,9 @@ class Base(nn.Module):
     if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
         logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]

+    if min_p > 0.0:
+        logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
+
     # perform top_k/top_p filtering of our logits
     if top_k > 0 or top_p < 1.0:
         logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
@@ -1586,30 +1598,44 @@ class Base(nn.Module):
         return [ Categorical(logits=logit).sample() for logit in logits ]

-    samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]
+    if entropix_enabled:
+        samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]

-    def score_sample(sample):
-        one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
-        log_prob = torch.sum(log_softmax * one_hot)
-        confidence_score = (
-            (1 - ent) * cfg.ada_score_logits_ent +
-            (1 - attn_ent) * cfg.ada_score_attn_ent +
-            (1 - vent) * cfg.ada_score_logits_vent +
-            (1 - attn_vent) * cfg.ada_score_attn_vent +
-            agreement * cfg.ada_score_agree +
-            interaction_strength * cfg.ada_score_int
-        )
-        return log_prob + confidence_score
+        def score_sample(sample):
+            one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
+            log_prob = torch.sum(log_softmax * one_hot)
+            confidence_score = (
+                (1 - ent) * cfg.ada_score_logits_ent +
+                (1 - attn_ent) * cfg.ada_score_attn_ent +
+                (1 - vent) * cfg.ada_score_logits_vent +
+                (1 - attn_vent) * cfg.ada_score_attn_vent +
+                agreement * cfg.ada_score_agree +
+                interaction_strength * cfg.ada_score_int
+            )
+            return log_prob + confidence_score

-    sample_scores = [ score_sample(sample) for sample in samples ]
-    best_sample_idx = torch.argmax(torch.asarray(sample_scores))
-    res = samples[best_sample_idx]
-    scores = sample_scores
-    return Sampled(res, scores, entropy)
+        sample_scores = [ score_sample(sample) for sample in samples ]
+        best_sample_idx = torch.argmax(torch.asarray(sample_scores))
+        res = samples[best_sample_idx]
+        scores = sample_scores
+        return Sampled(res, scores, entropy)

-    temperature = min(1.5, float(temperature))
+    temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
+    min_temperature = temperature
+
+    entropy[0]["temperature"] = temperature
+    entropy[0]["top_k"] = top_k
+    entropy[0]["top_p"] = top_p
+    entropy[0]["min_p"] = min_p
+
+    if not entropix_enabled:
+        temperature = 1.0
+        min_temperature = 1.0
+        top_k = 0
+        top_p = 1.0
+        min_p = 0.0

     # (NAR) disable stop token
     if quant_levels is not None and "ar" in self.capabilities:
@@ -1633,6 +1659,10 @@ class Base(nn.Module):
     if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
         logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]

+    # perform min_p filtering of our logits
+    if min_p > 0.0:
+        logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
+
     # perform top_k/top_p filtering of our logits
     if top_k > 0 or top_p < 1.0:
         logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]

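Taken together, the base sampler now applies its filters in a fixed order: repetition penalty, length penalty, min-p, then top-k/top-p, before drawing from the (possibly temperature-scaled) categorical. A paraphrased sketch of that order, not a verbatim excerpt (the wrapper function is illustrative; min_p_filtering and top_k_top_p_filtering are the repo's helpers and are assumed to be in scope):

    from torch.distributions import Categorical

    def filter_then_sample(logit, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0):
        # (repetition / length penalties would already have been applied here)
        if min_p > 0.0:  # new in this commit
            logit = min_p_filtering(logit, min_p=min_p)
        if top_k > 0 or top_p < 1.0:
            logit = top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p)
        return Categorical(logits=logit / temperature).sample()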
View File

@@ -45,6 +45,7 @@ class NAR(Base):
     sampling_min_temperature: float = -1.0,
     sampling_top_k: int = -100,
     sampling_top_p: float = 1.0,
+    sampling_min_p: float = 0.0,
     sampling_repetition_penalty: float = 1.0,
     sampling_repetition_penalty_decay: float = 0.0,
     sampling_length_penalty: float = 0.0,
@@ -191,6 +192,7 @@ class NAR(Base):
     min_temperature=sampling_min_temperature,
     top_p=sampling_top_p,
     top_k=sampling_top_k,
+    min_p=sampling_min_p,
     repetition_penalty=sampling_repetition_penalty,
     repetition_penalty_decay=sampling_repetition_penalty_decay,
     #length_penalty=sampling_length_penalty,

View File

@@ -50,6 +50,26 @@ def ban_tokens( logits, tokens ):
         logits[:, token] = -float("inf")
     return logits

+# Performs min_p filtering
+# From https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/logits_process.py#L537
+def min_p_filtering( logits, min_p=0.0, min_tokens_to_keep=32 ):
+    if min_p <= 0.0:
+        return logits
+
+    # Convert logits to probabilities
+    probs = torch.softmax(logits, dim=-1)
+    # Get the probability of the top token for each sequence in the batch
+    top_probs, _ = probs.max(dim=-1, keepdim=True)
+    # Calculate the actual min_p threshold by scaling min_p with the top token's probability
+    scaled_min_p = min_p * top_probs
+
+    sorted_indices = torch.argsort(logits, descending=True, dim=-1)
+    sorted_indices_to_remove = torch.gather(probs < scaled_min_p, dim=-1, index=sorted_indices)
+    sorted_indices_to_remove[..., :min_tokens_to_keep] = False
+    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+    return logits.masked_fill(indices_to_remove, -float("inf"))
+
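A quick sanity check of the min_p_filtering added above (batch of one, four-token vocab; min_tokens_to_keep lowered so the filter can actually trigger on so small a vocabulary):

    import torch

    logits = torch.tensor([[2.0, 1.9, 0.0, -3.0]])  # probs ~ [0.49, 0.44, 0.07, 0.003]
    out = min_p_filtering(logits.clone(), min_p=0.1, min_tokens_to_keep=1)
    # threshold = 0.1 * 0.49 ~ 0.049: only the last token falls below it
    assert out[0, 3] == -float("inf") and torch.isfinite(out[0, :3]).all()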
 # Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
 def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
@@ -243,46 +263,48 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
 # to-do: play around with these values
 @dataclass()
 class EntropixSamplerConfig:
-    temp: float = 0.999
+    temp: float = 0.85
     top_p: float = 0.90
-    top_k: int = 32
+    top_k: int = 27
     min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth

-    low_ent_thresh: float = 0.1
-    low_vent_thresh: float = 0.1
-    med_ent_thresh: float = 3.0
-    high_ent_thresh: float = 5.0
-    high_vent_thresh: float = 5.0
+    low_ent_thresh: float = 0.1 # 3.0
+    low_vent_thresh: float = 0.1 # 3.0
+    med_ent_thresh: float = 3.0 # 6.0
+    high_ent_thresh: float = 5.0 # 9.0
+    high_vent_thresh: float = 5.0 # 9.0

     # TODO this is a bit of a nasty mess, but also makes all the hyperparameters visible
     helv_attn_ent_offset: float = 1.3
     helv_attn_ent_coef: float = 0.2

     lehv_interaction_strength_offset: float = 1.2
     lehv_interaction_strength_coef: float = 0.3

     hehv_attn_ent_coef: float = 0.2
     hehv_attn_vent_offset: float = 2.0
     hehv_attn_vent_coef: float = 0.5

     # TODO not convinced this should
     n_adaptive_samples: int = 5

     # Adaptive sampling parameters
     ada_temp_logits: float = 0.3
     ada_temp_attn: float = 0.2
     ada_temp_agree: float = 0.2
     ada_top_p: float = 0.1
     ada_top_k_int: float = 0.3
     ada_top_k_agree: float = 0.2
     ada_min_p: float = 0.5
     ada_score_logits_ent: float = 0.1
     ada_score_attn_ent: float = 0.2
     ada_score_logits_vent: float = 0.3
     ada_score_attn_vent: float = 0.4
     ada_score_agree: float = 0.5
     ada_score_int: float = 0.6

     # extra stuff
-    top_k_min: int = 32
-    top_k_max: int = 128
+    top_k_min: int = 1
+    top_k_max: int = 1024
+    temperature_max: float = 1.25
+    temperature_min: float = 0.5

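With top_k_min/top_k_max widened to 1/1024 and the new temperature_min/temperature_max bounds, the adaptive adjustments in the sampler hunk have room to move instead of pinning at the old 32..128 clamp. A worked sketch of the top-k update under illustrative metric values (the entropix metrics here are made up for the example):

    import torch

    top_k = 27                                  # new default above
    interaction_strength, agreement = 2.0, 0.1  # illustrative entropix metrics
    scale = 1 + 0.3 * interaction_strength - 0.2 * agreement  # ada_top_k_int / ada_top_k_agree
    top_k = int(torch.clip(torch.round(torch.tensor(top_k * scale)), min=1, max=1024))
    # 27 * 1.58 = 42.66 -> 43; values below 32 or above 128 no longer get pinned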
View File

@@ -157,6 +157,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--min-nar-temp", type=float, default=kwargs["min-nar-temp"])
     parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
     parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
+    parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
     parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
     parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
     parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
@@ -196,6 +197,7 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     min_nar_temp=args.min_nar_temp,
     top_p=args.top_p,
     top_k=args.top_k,
+    min_p=args.min_p,
     repetition_penalty=args.repetition_penalty,
     repetition_penalty_decay=args.repetition_penalty_decay,
     length_penalty=args.length_penalty,
@@ -228,6 +230,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
     parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
     parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
+    parser.add_argument("--min-p", type=float, default=kwargs["min-p"])
     parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
     parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
     parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
@@ -266,6 +269,7 @@ def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     min_ar_temp=args.min_ar_temp,
     top_p=args.top_p,
     top_k=args.top_k,
+    min_p=args.min_p,
     repetition_penalty=args.repetition_penalty,
     repetition_penalty_decay=args.repetition_penalty_decay,
     length_penalty=args.length_penalty,
@@ -343,6 +347,7 @@ with ui:
     with gr.Row():
         layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
         layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
+        layout["inference_tts"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
         layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
     with gr.Row():
         layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
@@ -382,6 +387,7 @@ with ui:
     with gr.Row():
         layout["inference_stt"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
         layout["inference_stt"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
+        layout["inference_stt"]["inputs"]["min-p"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.0, step=0.05, label="Min P")
        layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
     with gr.Row():
         layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.25, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")