entropix tweaks (it doesn't output garbage but it loves to go for silence)

This commit is contained in:
mrq 2024-10-12 09:46:18 -05:00
parent d0ab7d755a
commit d6f7c86a5c
2 changed files with 143 additions and 127 deletions

View File

@ -1507,135 +1507,22 @@ class Base(nn.Module):
elif self.causal:
logits = [ logit[-self.causal_size:] for logit in logits ]
# calculate entropies
# I would love to shove it in samplers.py but we modify our sampler settings
# entropix sampling
if attentions is not None:
entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ]
if attentions is not None:
entropix_enabled = True
# this might actually slow things down a bit slightly-er?
# move to CPU for speedups
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
# to-do: not make it hardcoded to bsz=1
metrics = entropy[0]
logit = logits[0]
res = [ sample_entropix(
logit,
attentions[-1], #torch.stack(attentions, dim=1),
temperature,
top_k,
top_p,
min_p,
) for logit in logits ]
ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
agreement = metrics["agreement"]
interaction_strength = metrics["interaction_strength"]
# adjust sample settings
cfg = EntropixSamplerConfig()
entropy[0]["action"] = -1
# Low Entropy, Low Varentropy: "flowing with unspoken intent"
if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
entropy[0]["action"] = 0
temperature *= 0
# High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
entropy[0]["action"] = 1
# sample with slightly higher temperature
temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy
# Low Entropy, High Varentropy: "exploring forks in the path"
elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
entropy[0]["action"] = 2
temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength
top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low
# High Entropy, High Varentropy: "resampling in the mist"
elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
entropy[0]["action"] = 3
# Use high temperature and adjusted top_p based on attention metrics
temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy
top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high
# Middle ground: use adaptive sampling
else:
entropy[0]["action"] = 4
log_softmax = torch.nn.functional.log_softmax(logit)
logits_uncertainty = ent + vent
attn_uncertainty = attn_ent + attn_vent
temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
top_p = float(torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0))
top_k = int(torch.clip(
torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)),
min=cfg.top_k_min,
max=cfg.top_k_max
))
min_p = float(torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5))
temperature = clamp( temperature, cfg.temperature_min, cfg.temperature_max )
def _sample( logits ):
# perform repetition penalizing
if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
# to-do: figure out a faster way to handle tolist()
logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
if min_p > 0.0:
logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
# perform top_k/top_p filtering of our logits
if top_k > 0 or top_p < 1.0:
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
# epsilon float comparison because I don't trust Python
if abs(temperature - min_temperature) >= 0.001:
logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
else:
logits = [ logit / temperature for logit in logits ]
# do DRY sampling
if dry_multiplier > 0.0:
logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
return [ Categorical(logits=logit).sample() for logit in logits ]
if entropix_enabled:
samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]
def score_sample(sample):
one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
log_prob = torch.sum(log_softmax * one_hot)
confidence_score = (
(1 - ent) * cfg.ada_score_logits_ent +
(1 - attn_ent) * cfg.ada_score_attn_ent +
(1 - vent) * cfg.ada_score_logits_vent +
(1 - attn_vent) * cfg.ada_score_attn_vent +
agreement * cfg.ada_score_agree +
interaction_strength * cfg.ada_score_int
)
return log_prob + confidence_score
sample_scores = [ score_sample(sample) for sample in samples ]
best_sample_idx = torch.argmax(torch.asarray(sample_scores))
res = samples[best_sample_idx]
scores = sample_scores
return Sampled(res, scores, entropy)
temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
min_temperature = temperature
entropy[0]["temperature"] = temperature
entropy[0]["top_k"] = top_k
entropy[0]["top_p"] = top_p
entropy[0]["min_p"] = min_p
if not entropix_enabled:
temperature = 1.0
min_temperature = 1.0
top_k = 0
top_p = 1.0
min_p = 0.0
if res:
return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])
# (NAR) disable stop token
if quant_levels is not None and "ar" in self.capabilities:

View File

@ -258,6 +258,7 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
"attn_varentropy": torch.mean(attn_varentropy),
"agreement": torch.mean(agreement),
"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)),
"action": -1
}
# to-do: play around with these values
@ -304,7 +305,135 @@ class EntropixSamplerConfig:
ada_score_int: float = 0.6
# extra stuff
temperature_max: float = 1.25
temperature_min: float = 0.5
top_k_min: int = 1
top_k_max: int = 1024
temperature_max: float = 1.25
temperature_min: float = 0.5
top_p_min: int = 0.1
top_p_max: int = 1.0
min_p_min: int = 0.01
min_p_max: int = 0.5
Exponential = torch.distributions.exponential.Exponential(1.0)
def _sample_entropix(
logits,
temperature=1.0,
top_k=0,
top_p=1.0,
min_p=0.0,
cfg=EntropixSamplerConfig(),
):
def clamp(n, lo, hi):
return max(lo, min(n, hi))
if top_k == 0:
top_k = logits.shape[-1]
temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
top_p = clamp( float(top_p), cfg.top_p_min, cfg.top_p_max )
top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )
probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
# Apply min_p sampling
if min_p > 0.0:
p_max = float(torch.max(probs, dim=-1, keepdims=True).values)
indices_to_remove = probs < (min_p * p_max)
logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits)
# Apply top-k sampling
top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
probs_sort = torch.flip(top_k_probs, dims=[-1])
probs_idx = torch.flip(top_k_indices, dims=[-1])
probs_sum = torch.cumsum(probs_sort, dim=-1)
# Apply top-p sampling
mask = torch.where(probs_sum - probs_sort > top_p, 1.0, 0.0)
probs_sort = probs_sort * (1 - mask)
probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdims=True)
next_token = torch.argmax(probs_sort / Exponential.sample(probs_sort.shape), dim=-1, keepdim=True)
return torch.take_along_dim(probs_idx, next_token, dim=-1)
def sample_entropix(
logits,
attentions,
temperature=1.0,
top_k=32,
top_p=1.0,
min_p=0.0,
cfg=EntropixSamplerConfig(),
):
metrics = calculate_entropix_metrics( logits, attentions )
ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
agreement = metrics["agreement"]
interaction_strength = metrics["interaction_strength"]
# Low Entropy, Low Varentropy: "flowing with unspoken intent"
if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
metrics["action"] = 0
res = logits.argmax(dim=1)
# High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
metrics["action"] = 1
# sample with slightly higher temperature
temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy
res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg )
# Low Entropy, High Varentropy: "exploring forks in the path"
elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
metrics["action"] = 2
temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength
top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low
res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg )
# High Entropy, High Varentropy: "resampling in the mist"
elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
metrics["action"] = 3
# Use high temperature and adjusted top_p based on attention metrics
temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy
top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high
res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg )
# Middle ground: use adaptive sampling
else:
metrics["action"] = 4
log_softmax = torch.nn.functional.log_softmax(logits)
logits_uncertainty = ent + vent
attn_uncertainty = attn_ent + attn_vent
temperature *= float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
top_p = float(top_p * (1 + cfg.ada_top_p * attn_vent))
top_k = int(round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement))))
min_p = float(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty))
samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ]
def score_sample(sample):
one_hot = torch.nn.functional.one_hot( sample, logits.shape[-1] )
log_prob = torch.sum(log_softmax * one_hot)
confidence_score = (
(1 - ent) * cfg.ada_score_logits_ent +
(1 - attn_ent) * cfg.ada_score_attn_ent +
(1 - vent) * cfg.ada_score_logits_vent +
(1 - attn_vent) * cfg.ada_score_attn_vent +
agreement * cfg.ada_score_agree +
interaction_strength * cfg.ada_score_int
)
return log_prob + confidence_score
sample_scores = [ score_sample(sample) for sample in samples ]
best_sample_idx = torch.argmax(torch.asarray(sample_scores))
res = samples[best_sample_idx]
"""
metrics["temperature"] = temperature
metrics["top_k"] = top_k
metrics["top_p"] = top_p
metrics["min_p"] = min_p
"""
return res, metrics