entropix tweaks (it doesn't output garbage but it loves to go for silence)
This commit is contained in:
parent
d0ab7d755a
commit
d6f7c86a5c
|
@ -1507,135 +1507,22 @@ class Base(nn.Module):
|
||||||
elif self.causal:
|
elif self.causal:
|
||||||
logits = [ logit[-self.causal_size:] for logit in logits ]
|
logits = [ logit[-self.causal_size:] for logit in logits ]
|
||||||
|
|
||||||
# calculate entropies
|
# entropix sampling
|
||||||
# I would love to shove it in samplers.py but we modify our sampler settings
|
|
||||||
if attentions is not None:
|
if attentions is not None:
|
||||||
entropy = [ calculate_entropix_metrics( logit, attn ) for logit, attn in zip(logits, attentions) ]
|
# move to CPU for speedups
|
||||||
|
|
||||||
if attentions is not None:
|
|
||||||
entropix_enabled = True
|
|
||||||
|
|
||||||
# this might actually slow things down a bit slightly-er?
|
|
||||||
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
|
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
|
||||||
|
|
||||||
# to-do: not make it hardcoded to bsz=1
|
res = [ sample_entropix(
|
||||||
metrics = entropy[0]
|
logit,
|
||||||
logit = logits[0]
|
attentions[-1], #torch.stack(attentions, dim=1),
|
||||||
|
temperature,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
min_p,
|
||||||
|
) for logit in logits ]
|
||||||
|
|
||||||
ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
|
if res:
|
||||||
attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
|
return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])
|
||||||
agreement = metrics["agreement"]
|
|
||||||
interaction_strength = metrics["interaction_strength"]
|
|
||||||
|
|
||||||
# adjust sample settings
|
|
||||||
cfg = EntropixSamplerConfig()
|
|
||||||
|
|
||||||
entropy[0]["action"] = -1
|
|
||||||
# Low Entropy, Low Varentropy: "flowing with unspoken intent"
|
|
||||||
if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
|
|
||||||
entropy[0]["action"] = 0
|
|
||||||
temperature *= 0
|
|
||||||
# High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
|
|
||||||
elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
|
|
||||||
entropy[0]["action"] = 1
|
|
||||||
# sample with slightly higher temperature
|
|
||||||
temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent # Increase temperature based on attention entropy
|
|
||||||
# Low Entropy, High Varentropy: "exploring forks in the path"
|
|
||||||
elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
|
|
||||||
entropy[0]["action"] = 2
|
|
||||||
temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength # Increase temperature based on interaction strength
|
|
||||||
top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement)))) # Increase top_k when agreement is low
|
|
||||||
# High Entropy, High Varentropy: "resampling in the mist"
|
|
||||||
elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
|
|
||||||
entropy[0]["action"] = 3
|
|
||||||
# Use high temperature and adjusted top_p based on attention metrics
|
|
||||||
temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent # Increase temperature based on attention varentropy
|
|
||||||
top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent) # Decrease top_p when attention entropy is high
|
|
||||||
# Middle ground: use adaptive sampling
|
|
||||||
else:
|
|
||||||
entropy[0]["action"] = 4
|
|
||||||
log_softmax = torch.nn.functional.log_softmax(logit)
|
|
||||||
logits_uncertainty = ent + vent
|
|
||||||
attn_uncertainty = attn_ent + attn_vent
|
|
||||||
|
|
||||||
temperature = temperature * float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
|
|
||||||
top_p = float(torch.clip(top_p * (1 + cfg.ada_top_p * attn_vent), min=0.1, max=1.0))
|
|
||||||
top_k = int(torch.clip(
|
|
||||||
torch.round(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)),
|
|
||||||
min=cfg.top_k_min,
|
|
||||||
max=cfg.top_k_max
|
|
||||||
))
|
|
||||||
min_p = float(torch.clip(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty), 0.01, 0.5))
|
|
||||||
temperature = clamp( temperature, cfg.temperature_min, cfg.temperature_max )
|
|
||||||
|
|
||||||
def _sample( logits ):
|
|
||||||
# perform repetition penalizing
|
|
||||||
if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
|
|
||||||
# to-do: figure out a faster way to handle tolist()
|
|
||||||
logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
|
|
||||||
|
|
||||||
# (AR) perform length penalizing
|
|
||||||
if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
|
|
||||||
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
|
|
||||||
|
|
||||||
if min_p > 0.0:
|
|
||||||
logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
|
|
||||||
|
|
||||||
# perform top_k/top_p filtering of our logits
|
|
||||||
if top_k > 0 or top_p < 1.0:
|
|
||||||
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
|
|
||||||
|
|
||||||
# trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
|
|
||||||
# epsilon float comparison because I don't trust Python
|
|
||||||
if abs(temperature - min_temperature) >= 0.001:
|
|
||||||
logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
|
|
||||||
else:
|
|
||||||
logits = [ logit / temperature for logit in logits ]
|
|
||||||
|
|
||||||
# do DRY sampling
|
|
||||||
if dry_multiplier > 0.0:
|
|
||||||
logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
|
|
||||||
|
|
||||||
return [ Categorical(logits=logit).sample() for logit in logits ]
|
|
||||||
|
|
||||||
if entropix_enabled:
|
|
||||||
samples = [ _sample([ logit.clone() for logit in logits ]) for _ in range(cfg.n_adaptive_samples) ]
|
|
||||||
|
|
||||||
def score_sample(sample):
|
|
||||||
one_hot = torch.nn.functional.one_hot(sample[0], logit.shape[-1])
|
|
||||||
log_prob = torch.sum(log_softmax * one_hot)
|
|
||||||
|
|
||||||
confidence_score = (
|
|
||||||
(1 - ent) * cfg.ada_score_logits_ent +
|
|
||||||
(1 - attn_ent) * cfg.ada_score_attn_ent +
|
|
||||||
(1 - vent) * cfg.ada_score_logits_vent +
|
|
||||||
(1 - attn_vent) * cfg.ada_score_attn_vent +
|
|
||||||
agreement * cfg.ada_score_agree +
|
|
||||||
interaction_strength * cfg.ada_score_int
|
|
||||||
)
|
|
||||||
return log_prob + confidence_score
|
|
||||||
|
|
||||||
sample_scores = [ score_sample(sample) for sample in samples ]
|
|
||||||
best_sample_idx = torch.argmax(torch.asarray(sample_scores))
|
|
||||||
|
|
||||||
res = samples[best_sample_idx]
|
|
||||||
scores = sample_scores
|
|
||||||
return Sampled(res, scores, entropy)
|
|
||||||
|
|
||||||
temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
|
|
||||||
min_temperature = temperature
|
|
||||||
|
|
||||||
entropy[0]["temperature"] = temperature
|
|
||||||
entropy[0]["top_k"] = top_k
|
|
||||||
entropy[0]["top_p"] = top_p
|
|
||||||
entropy[0]["min_p"] = min_p
|
|
||||||
|
|
||||||
if not entropix_enabled:
|
|
||||||
temperature = 1.0
|
|
||||||
min_temperature = 1.0
|
|
||||||
top_k = 0
|
|
||||||
top_p = 1.0
|
|
||||||
min_p = 0.0
|
|
||||||
|
|
||||||
# (NAR) disable stop token
|
# (NAR) disable stop token
|
||||||
if quant_levels is not None and "ar" in self.capabilities:
|
if quant_levels is not None and "ar" in self.capabilities:
|
||||||
|
|
|
@ -258,6 +258,7 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
|
||||||
"attn_varentropy": torch.mean(attn_varentropy),
|
"attn_varentropy": torch.mean(attn_varentropy),
|
||||||
"agreement": torch.mean(agreement),
|
"agreement": torch.mean(agreement),
|
||||||
"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)),
|
"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)),
|
||||||
|
"action": -1
|
||||||
}
|
}
|
||||||
|
|
||||||
# to-do: play around with these values
|
# to-do: play around with these values
|
||||||
|
@ -304,7 +305,135 @@ class EntropixSamplerConfig:
|
||||||
ada_score_int: float = 0.6
|
ada_score_int: float = 0.6
|
||||||
|
|
||||||
# extra stuff
|
# extra stuff
|
||||||
top_k_min: int = 1
|
|
||||||
top_k_max: int = 1024
|
|
||||||
temperature_max: float = 1.25
|
temperature_max: float = 1.25
|
||||||
temperature_min: float = 0.5
|
temperature_min: float = 0.5
|
||||||
|
top_k_min: int = 1
|
||||||
|
top_k_max: int = 1024
|
||||||
|
top_p_min: int = 0.1
|
||||||
|
top_p_max: int = 1.0
|
||||||
|
min_p_min: int = 0.01
|
||||||
|
min_p_max: int = 0.5
|
||||||
|
|
||||||
|
# Exponential distribution with rate 1.0, instantiated once and reused.
Exponential = torch.distributions.exponential.Exponential(1.0)
|
||||||
|
def _sample_entropix(
|
||||||
|
logits,
|
||||||
|
temperature=1.0,
|
||||||
|
top_k=0,
|
||||||
|
top_p=1.0,
|
||||||
|
min_p=0.0,
|
||||||
|
cfg=EntropixSamplerConfig(),
|
||||||
|
):
|
||||||
|
def clamp(n, lo, hi):
|
||||||
|
return max(lo, min(n, hi))
|
||||||
|
|
||||||
|
if top_k == 0:
|
||||||
|
top_k = logits.shape[-1]
|
||||||
|
|
||||||
|
temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
|
||||||
|
top_p = clamp( float(top_p), cfg.top_p_min, cfg.top_p_max )
|
||||||
|
top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
|
||||||
|
min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )
|
||||||
|
|
||||||
|
probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
|
||||||
|
|
||||||
|
# Apply min_p sampling
|
||||||
|
if min_p > 0.0:
|
||||||
|
p_max = float(torch.max(probs, dim=-1, keepdims=True).values)
|
||||||
|
indices_to_remove = probs < (min_p * p_max)
|
||||||
|
logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits)
|
||||||
|
|
||||||
|
# Apply top-k sampling
|
||||||
|
top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
|
||||||
|
probs_sort = torch.flip(top_k_probs, dims=[-1])
|
||||||
|
probs_idx = torch.flip(top_k_indices, dims=[-1])
|
||||||
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
|
||||||
|
# Apply top-p sampling
|
||||||
|
mask = torch.where(probs_sum - probs_sort > top_p, 1.0, 0.0)
|
||||||
|
probs_sort = probs_sort * (1 - mask)
|
||||||
|
probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdims=True)
|
||||||
|
|
||||||
|
next_token = torch.argmax(probs_sort / Exponential.sample(probs_sort.shape), dim=-1, keepdim=True)
|
||||||
|
return torch.take_along_dim(probs_idx, next_token, dim=-1)
|
||||||
|
|
||||||
|
def sample_entropix(
    logits,
    attentions,
    temperature=1.0,
    top_k=32,
    top_p=1.0,
    min_p=0.0,
    cfg=None,
):
    """Pick a sampling strategy from entropy metrics and sample with it.

    `calculate_entropix_metrics(logits, attentions)` supplies the logits
    entropy/varentropy, attention entropy/varentropy, agreement and
    interaction strength. These are compared against the thresholds in `cfg`
    to select one of five actions, recorded in `metrics["action"]`:

        0: low entropy, low varentropy   -> greedy argmax
        1: high entropy, low varentropy  -> temperature scaled by attention entropy
        2: low entropy, high varentropy  -> temperature scaled by interaction
                                            strength, top_k widened by disagreement
        3: high entropy, high varentropy -> temperature scaled by attention
                                            varentropy, top_p lowered by attention entropy
        4: otherwise                     -> draw `cfg.n_adaptive_samples`
                                            candidates with adapted settings
                                            and keep the best-scoring one

    Args:
        logits: tensor of unnormalized scores; the last dimension is sampled.
        attentions: attention scores forwarded to `calculate_entropix_metrics`.
        temperature, top_k, top_p, min_p: base sampler settings that the
            chosen action adjusts before calling `_sample_entropix`.
        cfg: sampler configuration; a new `EntropixSamplerConfig()` is used
            when omitted.

    Returns:
        Tuple `(res, metrics)`: the sampled index tensor and the metrics dict
        with its "action" entry set.
    """
    if cfg is None:
        cfg = EntropixSamplerConfig()

    metrics = calculate_entropix_metrics(logits, attentions)

    ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
    attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
    agreement = metrics["agreement"]
    interaction_strength = metrics["interaction_strength"]

    # Low Entropy, Low Varentropy: "flowing with unspoken intent"
    if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
        metrics["action"] = 0
        # NOTE(review): this drops the reduced dimension, whereas every other
        # branch returns a tensor whose last dimension has size 1 -- confirm
        # which shape the caller expects.
        res = logits.argmax(dim=1)
    # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
    elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
        metrics["action"] = 1
        # Increase temperature based on attention entropy
        temperature *= cfg.helv_attn_ent_offset + cfg.helv_attn_ent_coef * attn_ent
        res = _sample_entropix(logits, temperature, top_k, top_p, min_p, cfg=cfg)
    # Low Entropy, High Varentropy: "exploring forks in the path"
    elif ent < cfg.high_ent_thresh and vent > cfg.high_vent_thresh:
        metrics["action"] = 2
        # Increase temperature based on interaction strength
        temperature *= cfg.lehv_interaction_strength_offset + cfg.lehv_interaction_strength_coef * interaction_strength
        # Increase top_k when agreement is low
        top_k = max(5, int(top_k * (1 + 0.5 * (1 - agreement))))
        res = _sample_entropix(logits, temperature, top_k, top_p, min_p, cfg=cfg)
    # High Entropy, High Varentropy: "resampling in the mist"
    elif ent > cfg.med_ent_thresh and vent > cfg.high_vent_thresh:
        metrics["action"] = 3
        # Increase temperature based on attention varentropy
        temperature *= cfg.hehv_attn_vent_offset + cfg.hehv_attn_vent_coef * attn_vent
        # Decrease top_p when attention entropy is high
        top_p = max(0.5, top_p - cfg.hehv_attn_ent_coef * attn_ent)
        res = _sample_entropix(logits, temperature, top_k, top_p, min_p, cfg=cfg)
    # Middle ground: use adaptive sampling
    else:
        metrics["action"] = 4

        # Normalize over the last dimension explicitly; the implicit-dim
        # form of log_softmax is deprecated.
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        logits_uncertainty = ent + vent
        attn_uncertainty = attn_ent + attn_vent

        temperature *= float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
        top_p = float(top_p * (1 + cfg.ada_top_p * attn_vent))
        top_k = int(round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement))))
        min_p = float(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty))

        samples = [
            _sample_entropix(logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg)
            for _ in range(cfg.n_adaptive_samples)
        ]

        # This term does not depend on the candidate, so compute it once.
        confidence_score = (
            (1 - ent) * cfg.ada_score_logits_ent +
            (1 - attn_ent) * cfg.ada_score_attn_ent +
            (1 - vent) * cfg.ada_score_logits_vent +
            (1 - attn_vent) * cfg.ada_score_attn_vent +
            agreement * cfg.ada_score_agree +
            interaction_strength * cfg.ada_score_int
        )

        def score_sample(sample):
            # Look up each row's own log-probability for its sampled index.
            # A one-hot product would broadcast a [rows, 1, vocab] mask
            # against [rows, vocab] log-probs and add cross-row terms.
            log_prob = log_probs.gather(-1, sample).sum()
            return log_prob + confidence_score

        sample_scores = [score_sample(sample) for sample in samples]
        best_sample_idx = max(range(len(samples)), key=lambda i: float(sample_scores[i]))

        res = samples[best_sample_idx]

    # The adjusted temperature / top_k / top_p / min_p are not written into
    # `metrics`; only "action" is recorded.
    return res, metrics
|
Loading…
Reference in New Issue
Block a user