entropix apparently processes the entire sequence of logits, but sampling falls apart when doing that

mrq 2024-10-13 12:01:12 -05:00
parent c800d28bb8
commit 84005c5b00
3 changed files with 43 additions and 38 deletions
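
For reference, entropix gates its sampling decisions on the entropy and varentropy of the logits. Below is a minimal sketch of those two metrics (an illustration, not this repo's `calculate_entropix_metrics`), showing the difference between averaging them over every position and reading them off the last token only, which is what this commit's change boils down to:

import torch
import torch.nn.functional as F

def logits_entropy_varentropy( logits, dim=-1 ):
	# logits: ( seq_len, vocab ); returns per-position entropy and varentropy
	log_probs = F.log_softmax(logits, dim=dim)
	probs = log_probs.exp()
	entropy = -(probs * log_probs).sum(dim=dim) # ( seq_len, )
	varentropy = (probs * (log_probs + entropy.unsqueeze(dim)) ** 2).sum(dim=dim)
	return entropy, varentropy

logits = torch.randn(12, 256) # toy ( seq_len, vocab )
ent, vent = logits_entropy_varentropy(logits)
print(ent.mean().item(), ent[-1].item()) # sequence-averaged vs last-token entropy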


@@ -292,6 +292,10 @@ class LlamaAttention_Adapted(LlamaAttention):
 			is_causal=is_causal,
 		)
 
+		# cringe
+		if attn_scores is None and output_attentions:
+			attn_scores = attn_output
+
 		attn_output = attn_output.transpose(1, 2).contiguous()
 		attn_output = attn_output.view(bsz, q_len, -1)


@@ -1498,6 +1498,26 @@ class Base(nn.Module):
 		scores = None
 		entropy = None
 
+		# (AR) entropix sampling
+		# we do it before everything to retain logits for the entire sequence (even though it's still better to pass only the last token)
+		if attentions is not None and quant_levels is None:
+			# move to CPU for speedups
+			seq_lens = [ logit.shape[0] for logit in logits ]
+			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+			attentions = torch.stack(attentions, dim=1).to(device="cpu") # ( batch, layer, heads, seq_len, seq_len )
+
+			res = [ sample_entropix(
+				logit[:seq_lens[batch], :], # ( seq_len, vocab )
+				attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # ( layer, heads, seq_len, seq_len )
+				temperature,
+				top_k,
+				top_p,
+				min_p,
+			) for batch, logit in enumerate(logits) ]
+
+			if res:
+				return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
+
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
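
To make the slicing in the added block concrete, here is a standalone sketch of the per-batch cropping, under the assumption that `attentions` arrives as a per-layer tuple of ( batch, heads, seq_len, seq_len ) tensors and `logits` as per-utterance ( seq_len, vocab ) tensors padded to a common length:

import torch

layers, bsz, heads, max_len, vocab = 4, 2, 8, 10, 256
attentions = tuple( torch.randn(bsz, heads, max_len, max_len) for _ in range(layers) )
logits = [ torch.randn(max_len, vocab) for _ in range(bsz) ]
seq_lens = [ 10, 7 ] # second utterance is padded past its real length

stacked = torch.stack(attentions, dim=1) # ( batch, layer, heads, seq_len, seq_len )
for batch, logit in enumerate(logits):
	l = seq_lens[batch]
	per_logits = logit[:l, :] # ( seq_len, vocab ), padding dropped
	per_attn = stacked[batch, :, :, :l, :l] # ( layer, heads, seq_len, seq_len )
	assert per_attn.shape == (layers, heads, l, l)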
@@ -1530,24 +1550,6 @@ class Base(nn.Module):
 		if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
-		# (AR) entropix sampling
-		# we do it after the penalizers because entropix's internal sampling doesn't account for them (but does do top_k/top_p/min_p)
-		if attentions is not None and quant_levels is None:
-			# move to CPU for speedups
-			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
-
-			res = [ sample_entropix(
-				logit,
-				torch.stack(attentions, dim=1)[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # ( layer, heads, seq_len, ? ), our attention scores might be padded
-				temperature,
-				top_k,
-				top_p,
-				min_p,
-			) for batch, logit in enumerate(logits) ]
-
-			if res:
-				return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
-
 		# perform min_p filtering of our logits
 		if min_p > 0.0:
 			logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
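
The `min_p_filtering` helper itself is not shown in this diff; the usual min_p rule it presumably implements drops every token whose probability falls below min_p times that of the most likely token. A sketch:

import torch
import torch.nn.functional as F

def min_p_filter( logits, min_p ):
	# mask tokens with prob < min_p * max-prob by setting their logits to -inf
	probs = F.softmax(logits, dim=-1)
	p_max = probs.max(dim=-1, keepdim=True).values
	return logits.masked_fill(probs < min_p * p_max, float("-inf"))

filtered = min_p_filter(torch.randn(256), min_p=0.05)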


@@ -243,12 +243,12 @@ def calculate_entropix_metrics( logits, attentions=None, dim=-1, use_stats=False ):
 			"logits_varentropy": torch.mean(varentropy).item(),
 		}
 
-	last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, ? )
+	last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, seq_len )
 	attention_probs = F.softmax(last_attention_scores, dim=-1)
 	if use_stats:
 		attn_stats = AttnStats.new( 1, attentions.shape[0], attentions.shape[1], logits.device )
 		for idx, attn in enumerate( attentions ):
-			attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, ?)
+			attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, seq_len)
 		attn_entropy = attn_stats.entropy
 		attn_varentropy = attn_stats.varentropy
 	else:
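
The shapes in the corrected comments line up with how the attention-side metrics are computed: softmax the raw scores over the key axis, then take the entropy at the last query position. A self-contained sketch (illustrative, not the repo's code):

import torch
import torch.nn.functional as F

attentions = torch.randn(4, 8, 10, 10) # ( layer, heads, seq_len, seq_len )
probs = F.softmax(attentions[-1].unsqueeze(0), dim=-1) # last layer: ( 1, heads, seq_len, seq_len )
log_probs = torch.log(probs.clamp_min(1e-12)) # clamp to dodge log(0)
attn_entropy = -(probs * log_probs).sum(dim=-1) # ( 1, heads, seq_len )
print(attn_entropy[:, :, -1].mean().item()) # attention entropy at the last token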
@@ -382,18 +382,20 @@ def _sample_entropix(
 	if top_k == 0:
 		top_k = logits.shape[-1]
 
+	logit = logits[-1, :]
+
 	temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
 	top_p = clamp( float(top_p), cfg.top_p_min, cfg.top_p_max )
 	top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
 	min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )
 
-	probs = F.softmax(logits / temperature, dim=-1)
+	probs = F.softmax(logit / temperature, dim=-1)
 
 	# Apply min_p sampling
 	if min_p > 0.0:
 		p_max = float(torch.max(probs, dim=-1, keepdim=True).values)
 		indices_to_remove = probs < (min_p * p_max)
-		logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits)
+		logit = torch.where(indices_to_remove, torch.full_like(logit, float('-inf')), logit)
 
 	# Apply top-k sampling
 	top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
@@ -401,21 +403,16 @@ def _sample_entropix(
 	probs_idx = torch.flip(top_k_indices, dims=[-1])
 	probs_sum = torch.cumsum(probs_sort, dim=-1)
 
 	# Apply top-p sampling
-	mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logits.device), torch.tensor(0.0, device=logits.device))
+	mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logit.device), torch.tensor(0.0, device=logit.device))
 	probs_sort = probs_sort * (1 - mask)
 	probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdim=True)
 
 	q = Exponential.sample(probs_sort.shape)
-	"""
-	next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
-	return torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
-	"""
-	"""
-	# q = torch.rand(probs_sort.shape, generator=generator, device=probs_sort.device)
-	"""
 	next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
-	next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
+	next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)
 	return next_token_g
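
The `probs_sort / q` line is the exponential-race trick: with E_i ~ Exponential(1), argmax_i p_i / E_i is exactly a Categorical(p) draw (equivalent to Gumbel-max sampling), since E_i / p_i ~ Exponential(p_i) and the minimum of independent exponentials lands on index i with probability p_i. A quick empirical check:

import torch

probs = torch.tensor([0.7, 0.2, 0.1])
exponential = torch.distributions.Exponential(rate=torch.tensor(1.0))
counts = torch.zeros(3)
for _ in range(10000):
	q = exponential.sample(probs.shape)
	counts[torch.argmax(probs / q)] += 1
print(counts / counts.sum()) # converges to ~[0.7, 0.2, 0.1]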
@@ -433,14 +430,9 @@ def sample_entropix(
 	top_p = cfg.top_p
 	"""
 
-	"""
-	if generator is None:
-		generator = torch.Generator(device=logits.device).manual_seed(int(time.time()))
-	"""
-
-	# logits: ( bsz, vocab )
-	# attentions: ( bsz, layer, heads, seq_len, ? )
-	metrics = calculate_entropix_metrics( logits, attentions )
+	# logits: ( seq_len, vocab )
+	# attentions: ( layer, heads, seq_len, seq_len )
+	metrics = calculate_entropix_metrics( logits[-1:, :], attentions[:, :, -1:, :] )
 
 	ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
 	attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
@@ -450,7 +442,7 @@ def sample_entropix(
 	# Low Entropy, Low Varentropy: "flowing with unspoken intent"
 	if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
 		metrics["action"] = 0
-		res = logits.argmax(dim=1)
+		res = logits[-1, :].argmax(dim=1)
 	# High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
 	elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
 		metrics["action"] = 1
@@ -510,6 +502,13 @@ def sample_entropix(
 		res = samples[best_sample_idx]
 	"""
 
+	"""
+	metrics = {
+		"attn_entropy": metrics["attn_entropy"],
+		"attn_varentropy": metrics["attn_varentropy"],
+	}
+	"""
+
 	metrics["temperature"] = temperature
 	metrics["top_k"] = top_k