entropix apparently processes the entire sequence of logits but it falls apart when doing that
commit 84005c5b00
parent c800d28bb8
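What the diff below boils down to: the call site in Base still hands sample_entropix the logits and attention scores for the whole sequence, but the metrics and the actual sampling are now restricted to the last token (calculate_entropix_metrics( logits[-1:, :], attentions[:, :, -1:, :] ) and logit = logits[-1, :]). A minimal, self-contained sketch of that last-token slicing follows; the shapes and the entropy/varentropy formula here are illustrative assumptions, not code from this repository.

    import math
    import torch
    import torch.nn.functional as F

    # Assumed toy shapes: ( seq_len, vocab ) logits and ( layer, heads, seq_len, seq_len ) attentions.
    seq_len, vocab, layers, heads = 8, 32, 4, 2
    logits = torch.randn(seq_len, vocab)
    attentions = torch.randn(layers, heads, seq_len, seq_len)

    # Keep only the last token, mirroring logits[-1:, :] / attentions[:, :, -1:, :] in the diff.
    last_logits = logits[-1:, :]          # ( 1, vocab )
    last_attn = attentions[:, :, -1:, :]  # ( layer, heads, 1, seq_len )

    # Entropix-style entropy/varentropy (in bits) over that single position.
    log_probs = F.log_softmax(last_logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1) / math.log(2)
    varentropy = (probs * (log_probs / math.log(2) + entropy.unsqueeze(-1)) ** 2).sum(dim=-1)

    print(entropy.item(), varentropy.item(), tuple(last_attn.shape))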
@@ -292,6 +292,10 @@ class LlamaAttention_Adapted(LlamaAttention):
             is_causal=is_causal,
         )

+        # cringe
+        if attn_scores is None and output_attentions:
+            attn_scores = attn_output
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, -1)

@@ -1498,6 +1498,26 @@ class Base(nn.Module):
         scores = None
         entropy = None

+        # (AR) entropix sampling
+        # we do it before everything to retain logits for the entire sequence (even though it's still better to pass only the last token)
+        if attentions is not None and quant_levels is None:
+            # move to CPU for speedups
+            seq_lens = [ logit.shape[0] for logit in logits ]
+            logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+            attentions = torch.stack(attentions, dim=1).to(device="cpu") # ( batch, layer, heads, seq_len, seq_len )
+
+            res = [ sample_entropix(
+                logit[:seq_lens[batch], :], # ( seq_len, vocab )
+                attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # ( layer, heads, seq_len, seq_len )
+                temperature,
+                top_k,
+                top_p,
+                min_p,
+            ) for batch, logit in enumerate(logits) ]
+
+            if res:
+                return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
+
         # (NAR) return the entire generated response
         # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
         if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely

@@ -1530,24 +1550,6 @@ class Base(nn.Module):
         if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
             logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
-
-        # (AR) entropix sampling
-        # we do it after the penalizers because entropix's internal sampling doesn't account for them (but does do top_k/top_p/min_p)
-        if attentions is not None and quant_levels is None:
-            # move to CPU for speedups
-            logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
-
-            res = [ sample_entropix(
-                logit,
-                torch.stack(attentions, dim=1)[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, ? ), our attention scores might be padded
-                temperature,
-                top_k,
-                top_p,
-                min_p,
-            ) for batch, logit in enumerate(logits) ]
-
-            if res:
-                return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])

         # perform min_p filtering of our logits
         if min_p > 0.0:
             logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]

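On the call-site change above: per the removed comment the attention scores may come back padded, so each batch item's slice is trimmed to its own sequence length before being handed to sample_entropix. A toy sketch of that slicing, with invented lengths and head counts (not this repository's real dimensions):

    import torch

    # Invented per-item sequence lengths; the batch is padded to the longest one.
    seq_lens = [5, 8]
    layers, heads, padded = 4, 2, max(seq_lens)

    # One ( batch, heads, padded, padded ) tensor per layer, like the attentions the model returns.
    per_layer = [torch.randn(len(seq_lens), heads, padded, padded) for _ in range(layers)]
    stacked = torch.stack(per_layer, dim=1)  # ( batch, layer, heads, padded, padded )

    for batch, seq_len in enumerate(seq_lens):
        attn = stacked[batch, :, :, :seq_len, :seq_len]  # drop right/bottom padding
        print(batch, tuple(attn.shape))  # ( layer, heads, seq_len, seq_len )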
@@ -243,12 +243,12 @@ def calculate_entropix_metrics( logits, attentions=None, dim=-1, use_stats=False
         "logits_varentropy": torch.mean(varentropy).item(),
     }

-    last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, ? )
+    last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, seq_len )
     attention_probs = F.softmax(last_attention_scores, dim=-1)
     if use_stats:
         attn_stats = AttnStats.new( 1, attentions.shape[0], attentions.shape[1], logits.device )
         for idx, attn in enumerate( attentions ):
-            attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, ?)
+            attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, seq_len)
         attn_entropy = attn_stats.entropy
         attn_varentropy = attn_stats.varentropy
     else:

@@ -382,18 +382,20 @@ def _sample_entropix(
     if top_k == 0:
         top_k = logits.shape[-1]

+    logit = logits[-1, :]
+
     temperature = clamp( float(temperature), cfg.temperature_min, cfg.temperature_max )
     top_p = clamp( float(top_p), cfg.top_p_min, cfg.top_p_max )
     top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
     min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )

-    probs = F.softmax(logits / temperature, dim=-1)
+    probs = F.softmax(logit / temperature, dim=-1)

     # Apply min_p sampling
     if min_p > 0.0:
         p_max = float(torch.max(probs, dim=-1, keepdim=True).values)
         indices_to_remove = probs < (min_p * p_max)
-        logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits)
+        logit = torch.where(indices_to_remove, torch.full_like(logit, float('-inf')), logit)

     # Apply top-k sampling
     top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))

@@ -401,21 +403,16 @@ def _sample_entropix(
     probs_idx = torch.flip(top_k_indices, dims=[-1])
     probs_sum = torch.cumsum(probs_sort, dim=-1)
     # Apply top-p sampling
-    mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logits.device), torch.tensor(0.0, device=logits.device))
+    mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logit.device), torch.tensor(0.0, device=logit.device))
     probs_sort = probs_sort * (1 - mask)
     probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdim=True)

     q = Exponential.sample(probs_sort.shape)
-    """
-    next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
-    return torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
-    """
-
     """
     # q = torch.rand(probs_sort.shape, generator=generator, device=probs_sort.device)
     """
     next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
-    next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
+    next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)
     return next_token_g

 def sample_entropix(

@@ -433,14 +430,9 @@ def sample_entropix(
     top_p = cfg.top_p
     """

-    """
-    if generator is None:
-        generator = torch.Generator(device=logits.device).manual_seed(int(time.time()))
-    """
-
-    # logits: ( bsz, vocab )
-    # attentions: ( bsz, layer, heads, seq_len, ? )
-    metrics = calculate_entropix_metrics( logits, attentions )
+    # logits: ( seq_len, vocab )
+    # attentions: ( layer, heads, seq_len, seq_len )
+    metrics = calculate_entropix_metrics( logits[-1:, :], attentions[:, :, -1:, :] )

     ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
     attn_ent, attn_vent = metrics["attn_entropy"], metrics["attn_varentropy"]

@@ -450,7 +442,7 @@ def sample_entropix(
     # Low Entropy, Low Varentropy: "flowing with unspoken intent"
     if ent < cfg.low_ent_thresh and vent < cfg.low_vent_thresh:
         metrics["action"] = 0
-        res = logits.argmax(dim=1)
+        res = logits[-1, :].argmax(dim=1)
     # High Entropy, Low Varentropy: "treading carefully, asking clarifying questions"
     elif ent > cfg.high_ent_thresh and vent < cfg.low_vent_thresh:
         metrics["action"] = 1

@@ -510,6 +502,13 @@ def sample_entropix(

     res = samples[best_sample_idx]

+    """
+    metrics = {
+        "attn_entropy": metrics["attn_entropy"],
+        "attn_varentropy": metrics["attn_varentropy"],
+    }
+    """
+
     """
     metrics["temperature"] = temperature
     metrics["top_k"] = top_k