entropix tweaks (it doesn't output garbage but it loves to go for silence)
parent d0ab7d755a
commit d6f7c86a5c
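The "loves to go for silence" behaviour likely comes from the low-entropy / low-varentropy branch: when both metrics sit under their thresholds the sampler collapses to greedy argmax (temperature *= 0 in the old inline path, logits.argmax() in sample_entropix below), and in an AR audio model the single most probable token at that point is often the stop/silence token. A minimal sketch of that failure mode (hypothetical numbers, not repo code; index 0 stands in for the stop token):

import torch

logits = torch.tensor([4.0, 1.0, 0.5])    # peaked distribution, index 0 = stop/silence
probs = torch.softmax(logits, dim=-1)
entropy = -(probs * probs.log()).sum()    # low entropy -> "flowing with unspoken intent" branch
token = int(logits.argmax(dim=-1))        # greedy pick -> always 0, i.e. silence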
@@ -1507,135 +1507,22 @@ class Base(nn.Module):
elif self.causal:
    logits = [ logit[-self.causal_size:] for logit in logits ]

# calculate entropies
# I would love to shove it in samplers.py but we modify our sampler settings
# entropix sampling
if attentions is not None:
    entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ]

if attentions is not None:
    entropix_enabled = True

    # this might actually slow things down a bit slightly-er?
    # move to CPU for speedups
    logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]

    # to-do: not make it hardcoded to bsz=1
    metrics = entropy[0]
    logit = logits[0]
    res = [ sample_entropix(
        logit,
        attentions[-1], #torch.stack(attentions, dim=1),
        temperature,
        top_k,
        top_p,
        min_p,
    ) for logit in logits ]

    ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
    attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
    agreement = metrics["agreement"]
    interaction_strength = metrics["interaction_strength"]

    # adjust sample settings
    cfg = EntropixSamplerConfig()

    entropy[0]["action"] = -1
    # Low Entropy, Low Varentropy: "flowing with unspoken intent"
    if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
        entropy[0]["action"] = 0
        temperature *= 0
    # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
    elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
        entropy[0]["action"] = 1
        # sample with slightly higher temperature
        temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy
    # Low Entropy, High Varentropy: "exploring forks in the path"
    elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
        entropy[0]["action"] = 2
        temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength
        top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low
    # High Entropy, High Varentropy: "resampling in the mist"
    elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
        entropy[0]["action"] = 3
        # Use high temperature and adjusted top_p based on attention metrics
        temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy
        top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high
    # Middle ground: use adaptive sampling
    else:
        entropy[0]["action"] = 4
        log_softmax = torch.nn.functional.log_softmax(logit)
        logits_uncertainty = ent + vent
        attn_uncertainty = attn_ent + attn_vent

        temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
        top_p = float(torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0))
        top_k = int(torch.clip(
            torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)),
            min=cfg.top_k_min,
            max=cfg.top_k_max
        ))
        min_p = float(torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5))
        temperature = clamp( temperature, cfg.temperature_min, cfg.temperature_max )

    def _sample( logits ):
        # perform repetition penalizing
        if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
            # to-do: figure out a faster way to handle tolist()
            logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]

        # (AR) perform length penalizing
        if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
            logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]

        if min_p > 0.0:
            logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]

        # perform top_k/top_p filtering of our logits
        if top_k > 0 or top_p < 1.0:
            logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]

        # trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
        # epsilon float comparison because I don't trust Python
        if abs(temperature - min_temperature) >= 0.001:
            logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
        else:
            logits = [ logit / temperature for logit in logits ]

        # do DRY sampling
        if dry_multiplier > 0.0:
            logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]

        return [ Categorical(logits=logit).sample() for logit in logits ]

    if entropix_enabled:
        samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]

        def score_sample(sample):
            one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
            log_prob = torch.sum(log_softmax * one_hot)

            confidence_score = (
                (1 - ent) * cfg.ada_score_logits_ent +
                (1 - attn_ent) * cfg.ada_score_attn_ent +
                (1 - vent) * cfg.ada_score_logits_vent +
                (1 - attn_vent) * cfg.ada_score_attn_vent +
                agreement * cfg.ada_score_agree +
                interaction_strength * cfg.ada_score_int
            )
            return log_prob + confidence_score

        sample_scores = [ score_sample(sample) for sample in samples ]
        best_sample_idx = torch.argmax(torch.asarray(sample_scores))

        res = samples[best_sample_idx]
        scores = sample_scores
        return Sampled(res, scores, entropy)

    temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
    min_temperature = temperature

    entropy[0]["temperature"] = temperature
    entropy[0]["top_k"] = top_k
    entropy[0]["top_p"] = top_p
    entropy[0]["min_p"] = min_p

if not entropix_enabled:
    temperature = 1.0
    min_temperature = 1.0
    top_k = 0
    top_p = 1.0
    min_p = 0.0
if res:
    return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])

# (NAR) disable stop token
if quant_levels is not None and "ar" in self.capabilities:
@@ -258,6 +258,7 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
        "attn_varentropy": torch.mean(attn_varentropy),
        "agreement": torch.mean(agreement),
        "interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)),
        "action": -1
    }

# to-do: play around with these values
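For reference, the logits entropy and varentropy driving these metrics follow the usual entropix definitions: entropy of the token distribution, and varentropy as the variance of the surprisal under that distribution. A self-contained sketch, not the repo's calculate_entropix_metrics (which additionally derives attention entropy/varentropy, agreement, and interaction strength from attention_scores); the entropix reference divides by ln 2 to report bits, so thresholds in EntropixSamplerConfig have to match whichever base is used:

import torch

def logits_entropy_varentropy( logits: torch.Tensor, dim: int = -1 ):
    # H = -sum p * log p ; varentropy = sum p * (log p + H)^2 = Var(-log p)
    log_probs = torch.nn.functional.log_softmax(logits, dim=dim)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=dim)
    varentropy = (probs * (log_probs + entropy.unsqueeze(dim)) ** 2).sum(dim=dim)
    return entropy, varentropy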
@@ -304,7 +305,135 @@ class EntropixSamplerConfig:
    ada_score_int: float = 0.6

    # extra stuff
    temperature_max: float = 1.25
    temperature_min: float = 0.5
    top_k_min: int = 1
    top_k_max: int = 1024
    top_p_min: float = 0.1
    top_p_max: float = 1.0
    min_p_min: float = 0.01
    min_p_max: float = 0.5

Exponential = torch.distributions.exponential.Exponential(1.0)

def _sample_entropix(
    logits,
    temperature=1.0,
    top_k=0,
    top_p=1.0,
    min_p=0.0,
    cfg=EntropixSamplerConfig(),
):
    def clamp(n, lo, hi):
        return max(lo, min(n, hi))

    if top_k == 0:
        top_k = logits.shape[-1]

    temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
    top_p = clamp( float(top_p), cfg.top_p_min, cfg.top_p_max )
    top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
    min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )

    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)

    # Apply min_p sampling
    if min_p > 0.0:
        p_max = float(torch.max(probs, dim=-1, keepdims=True).values)
        indices_to_remove = probs < (min_p * p_max)
        logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits)

    # Apply top-k sampling
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
    probs_sort = torch.flip(top_k_probs, dims=[-1])
    probs_idx = torch.flip(top_k_indices, dims=[-1])
    probs_sum = torch.cumsum(probs_sort, dim=-1)

    # Apply top-p sampling
    mask = torch.where(probs_sum - probs_sort > top_p, 1.0, 0.0)
    probs_sort = probs_sort * (1 - mask)
    probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdims=True)

    next_token = torch.argmax(probs_sort / Exponential.sample(probs_sort.shape), dim=-1, keepdim=True)
    return torch.take_along_dim(probs_idx, next_token, dim=-1)
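The argmax(probs / E) draw above is the exponential-race form of the Gumbel-max trick: with E_i ~ Exp(1), argmax(p_i / E_i) = argmax(log p_i - log E_i), and -log E_i is standard Gumbel noise, so index i is selected with probability proportional to p_i (restricted here to whatever survives the min_p/top-k/top-p masking). A quick standalone check, not repo code:

import torch

probs = torch.tensor([0.7, 0.2, 0.1])
Exp = torch.distributions.exponential.Exponential(1.0)
draws = torch.stack([ torch.argmax(probs / Exp.sample(probs.shape)) for _ in range(20000) ])
freqs = torch.bincount(draws, minlength=3).float() / draws.numel()
# freqs comes out close to probs, e.g. roughly [0.70, 0.20, 0.10]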

def sample_entropix(
    logits,
    attentions,
    temperature=1.0,
    top_k=32,
    top_p=1.0,
    min_p=0.0,
    cfg=EntropixSamplerConfig(),
):
    metrics = calculate_entropix_metrics( logits, attentions )

    ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
    attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
    agreement = metrics["agreement"]
    interaction_strength = metrics["interaction_strength"]

    # Low Entropy, Low Varentropy: "flowing with unspoken intent"
    if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
        metrics["action"] = 0
        res = logits.argmax(dim=1)
    # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
    elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
        metrics["action"] = 1
        # sample with slightly higher temperature
        temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy
        res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg )
    # Low Entropy, High Varentropy: "exploring forks in the path"
    elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
        metrics["action"] = 2
        temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength
        top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low
        res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg )
    # High Entropy, High Varentropy: "resampling in the mist"
    elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
        metrics["action"] = 3
        # Use high temperature and adjusted top_p based on attention metrics
        temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy
        top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high
        res = _sample_entropix( logits, temperature, top_k, top_p, min_p, cfg=cfg )
    # Middle ground: use adaptive sampling
    else:
        metrics["action"] = 4

        log_softmax = torch.nn.functional.log_softmax(logits)
        logits_uncertainty = ent + vent
        attn_uncertainty = attn_ent + attn_vent

        temperature *= float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
        top_p = float(top_p * (1 + cfg.ada_top_p * attn_vent))
        top_k = int(round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement))))
        min_p = float(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty))

        samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ]

        def score_sample(sample):
            one_hot = torch.nn.functional.one_hot( sample, logits.shape[-1] )
            log_prob = torch.sum(log_softmax * one_hot)

            confidence_score = (
                (1 - ent) * cfg.ada_score_logits_ent +
                (1 - attn_ent) * cfg.ada_score_attn_ent +
                (1 - vent) * cfg.ada_score_logits_vent +
                (1 - attn_vent) * cfg.ada_score_attn_vent +
                agreement * cfg.ada_score_agree +
                interaction_strength * cfg.ada_score_int
            )
            return log_prob + confidence_score

        sample_scores = [ score_sample(sample) for sample in samples ]
        best_sample_idx = torch.argmax(torch.asarray(sample_scores))

        res = samples[best_sample_idx]

    """
    metrics["temperature"] = temperature
    metrics["top_k"] = top_k
    metrics["top_p"] = top_p
    metrics["min_p"] = min_p
    """

    return res, metrics