entropix apparently processes the entire sequence of logits, but it falls apart when doing that
This commit is contained in:
parent c800d28bb8
commit 84005c5b00
@@ -292,6 +292,10 @@ class LlamaAttention_Adapted(LlamaAttention):
 				is_causal=is_causal,
 			)
 
+		# cringe
+		if attn_scores is None and output_attentions:
+			attn_scores = attn_output
+
 		attn_output = attn_output.transpose(1, 2).contiguous()
 		attn_output = attn_output.view(bsz, q_len, -1)
 
@@ -1498,6 +1498,26 @@ class Base(nn.Module):
 		scores = None
 		entropy = None
 
+		# (AR) entropix sampling
+		# we do it before everything to retain logits for the entire sequence (even though it's still better to pass only the last token)
+		if attentions is not None and quant_levels is None:
+			# move to CPU for speedups
+			seq_lens = [ logit.shape[0] for logit in logits ]
+			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+			attentions = torch.stack(attentions, dim=1).to(device="cpu") # ( batch, layer, heads, seq_len, seq_len )
+
+			res = [ sample_entropix(
+				logit[:seq_lens[batch], :], # ( seq_len, vocab )
+				attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, seq_len )
+				temperature,
+				top_k,
+				top_p,
+				min_p,
+			) for batch, logit in enumerate(logits) ]
+
+			if res:
+				return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
+
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
@@ -1530,24 +1550,6 @@ class Base(nn.Module):
 		if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
-		# (AR) entropix sampling
-		# we do it after the penalizers because entropix's internal sampling doesn't account for them (but does do top_k/top_p/min_p)
-		if attentions is not None and quant_levels is None:
-			# move to CPU for speedups
-			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
-
-			res = [ sample_entropix(
-				logit,
-				torch.stack(attentions, dim=1)[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, ? ), our attention scores might be padded
-				temperature,
-				top_k,
-				top_p,
-				min_p,
-			) for batch, logit in enumerate(logits) ]
-
-			if res:
-				return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])
-
 		# perform min_p filtering of our logits
 		if min_p > 0.0:
 			logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
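Editor's note: a minimal shape sketch of what the relocated call site now hands sample_entropix, using dummy tensors (sizes and variable names here are illustrative, not from the repo). Per-layer attention maps are stacked into ( batch, layer, heads, seq_len, seq_len ) and each batch entry is sliced to its true length, so padding never reaches the sampler:

import torch

bsz, layers, heads, max_len, vocab = 2, 4, 8, 10, 32   # hypothetical sizes
seq_lens = [ 10, 7 ]                                    # second utterance is padded

# one ( bsz, heads, max_len, max_len ) map per layer, as output_attentions would return
attentions = [ torch.rand(bsz, heads, max_len, max_len) for _ in range(layers) ]
logits = [ torch.rand(max_len, vocab) for _ in range(bsz) ]

stacked = torch.stack(attentions, dim=1)                # ( bsz, layer, heads, max_len, max_len )
for batch, logit in enumerate(logits):
	l = seq_lens[batch]
	print( logit[:l, :].shape )                         # ( seq_len, vocab )
	print( stacked[batch, :, :, :l, :l].shape )         # ( layer, heads, seq_len, seq_len )
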
@@ -243,12 +243,12 @@ def calculate_entropix_metrics( logits, attentions=None, dim=-1, use_stats=False
 		"logits_varentropy": torch.mean(varentropy).item(),
 	}
 
-	last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, ? )
+	last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, seq_len )
 	attention_probs = F.softmax(last_attention_scores, dim=-1)
 	if use_stats:
 		attn_stats = AttnStats.new( 1, attentions.shape[0], attentions.shape[1], logits.device )
 		for idx, attn in enumerate( attentions ):
-			attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, ?)
+			attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, seq_len)
 		attn_entropy = attn_stats.entropy
 		attn_varentropy = attn_stats.varentropy
 	else:
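Editor's note: for context, the "logits_entropy" / "logits_varentropy" values used above are, in the usual entropix formulation, the Shannon entropy of the softmax distribution and the spread of per-token surprisal around that entropy. A minimal sketch in natural log (entropix itself rescales to bits; this is an illustration, not code from this commit):

import torch
import torch.nn.functional as F

def logits_entropy_varentropy( logits, dim=-1 ):
	# logits: ( ..., vocab ) -> per-position entropy H and varentropy E[(log p + H)^2]
	log_probs = F.log_softmax(logits, dim=dim)
	probs = log_probs.exp()
	entropy = -(probs * log_probs).sum(dim=dim)
	varentropy = (probs * (log_probs + entropy.unsqueeze(dim)) ** 2).sum(dim=dim)
	return entropy, varentropy
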
@@ -382,18 +382,20 @@ def _sample_entropix(
 	if top_k == 0:
 		top_k = logits.shape[-1]
 
+	logit = logits[-1, :]
+
 	temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
 	top_p = clamp( float(top_p), cfg.top_p_min, cfg.top_p_max )
 	top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
 	min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )
 
-	probs = F.softmax(logits / temperature, dim=-1)
+	probs = F.softmax(logit / temperature, dim=-1)
 
 	# Apply min_p sampling
 	if min_p > 0.0:
 		p_max = float(torch.max(probs, dim=-1, keepdim=True).values)
 		indices_to_remove = probs < (min_p * p_max)
-		logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits)
+		logit = torch.where(indices_to_remove, torch.full_like(logit, float('-inf')), logit)
 
 	# Apply top-k sampling
 	top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
@@ -401,21 +403,16 @@ def _sample_entropix(
 	probs_idx = torch.flip(top_k_indices, dims=[-1])
 	probs_sum = torch.cumsum(probs_sort, dim=-1)
 	# Apply top-p sampling
-	mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logits.device), torch.tensor(0.0, device=logits.device))
+	mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logit.device), torch.tensor(0.0, device=logit.device))
 	probs_sort = probs_sort * (1 - mask)
 	probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdim=True)
 
 	q = Exponential.sample(probs_sort.shape)
-	"""
-	next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
-	return torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
-	"""
-
 	"""
 	# q = torch.rand(probs_sort.shape, generator=generator, device=probs_sort.device)
 	"""
 	next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
-	next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
+	next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)
 	return next_token_g
 
 def sample_entropix(
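Editor's note: the argmax(probs_sort / q) draw with q ~ Exponential(1) above is the exponential-race form of the Gumbel-max trick: dividing each probability by an independent Exp(1) sample and taking the argmax picks index i with probability proportional to probs[i], so it samples from the already top-k/top-p-filtered distribution without a multinomial call. A standalone sketch (illustrative only, not from this repo):

import torch

def exponential_race_sample( probs, generator=None ):
	# probs: ( ..., vocab ), rows sum to 1; returns indices drawn with P(i) = probs[..., i]
	q = torch.empty_like(probs).exponential_(1.0, generator=generator)
	return torch.argmax(probs / q, dim=-1, keepdim=True)

# sanity check: empirical frequencies approach the target distribution
probs = torch.tensor([0.1, 0.2, 0.7])
draws = torch.stack([ exponential_race_sample(probs) for _ in range(10000) ])
print( torch.bincount(draws.flatten(), minlength=3) / 10000.0 )
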
@@ -433,14 +430,9 @@ def sample_entropix(
 	top_p = cfg.top_p
 	"""
 
-	"""
-	if generator is None:
-		generator = torch.Generator(device=logits.device).manual_seed(int(time.time()))
-	"""
-
-	# logits: ( bsz, vocab )
-	# attentions: ( bsz, layer, heads, seq_len, ? )
-	metrics = calculate_entropix_metrics( logits, attentions )
+	# logits: ( seq_len, vocab )
+	# attentions: ( layer, heads, seq_len, seq_len )
+	metrics = calculate_entropix_metrics( logits[-1:, :], attentions[:, :, -1:, :] )
 
 	ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
 	attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]
@@ -450,7 +442,7 @@ def sample_entropix(
 	# Low Entropy, Low Varentropy: "flowing with unspoken intent"
 	if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
 		metrics["action"] = 0
-		res = logits.argmax(dim=1)
+		res = logits[-1, :].argmax(dim=1)
 	# High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
 	elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
 		metrics["action"] = 1
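Editor's note: the branches above follow entropix's entropy/varentropy quadrant heuristic: each step is classified by how uncertain the logits are (entropy) and how uneven that uncertainty is (varentropy), and the sampler reacts accordingly. A generic sketch of the dispatch, using the cfg threshold names that appear in this file (the actions taken per quadrant vary between implementations, including this one):

def entropix_action( ent, vent, cfg ):
	# low entropy, low varentropy: confident and stable -> greedy argmax
	if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
		return 0
	# high entropy, low varentropy: uniformly uncertain -> soften / raise temperature
	if ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
		return 1
	# low entropy, high varentropy: a few competing peaks -> explore or branch
	if vent > cfg.high_vent_thresh and ent < cfg.high_ent_thresh:
		return 2
	# high entropy, high varentropy: chaotic -> resample more aggressively
	return 3
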
@@ -510,6 +502,13 @@ def sample_entropix(
 
 	res = samples[best_sample_idx]
 
+	"""
+	metrics = {
+		"attn_entropy": metrics["attn_entropy"],
+		"attn_varentropy": metrics["attn_varentropy"],
+	}
+	"""
+
 	"""
 	metrics["temperature"] = temperature
 	metrics["top_k"] = top_k