at wit's end trying to output the right attention scores
This commit is contained in:
parent 70cf694cfd
commit d405f243d4
@@ -135,7 +135,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 			self.mode = kwargs['mode']
 			kwargs.pop("mode")
 		else:
-			self.mode = "math"
+			self.mode = "sdpa"
 
 		if self.mode == "math":
			self.mode = torch.nn.attention.SDPBackend.MATH
@@ -145,8 +145,6 @@ class LlamaAttention_Adapted(LlamaAttention):
 			self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
 		elif self.mode == "cudnn":
 			self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
 		else:
 			self.mode = None
 
 		super().__init__(*args, **kwargs)
 
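For context, the strings above map onto torch.nn.attention.SDPBackend values, which are consumed through the sdpa_kernel context manager to pin scaled_dot_product_attention to a single backend. A minimal sketch, assuming PyTorch 2.3+ where this API is available:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

# dummy (bsz, heads, seq_len, head_dim) projections
q = k = v = torch.randn(1, 8, 16, 64)

# force the math backend, mirroring self.mode == SDPBackend.MATH above
with sdpa_kernel(SDPBackend.MATH):
	out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)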
@@ -330,6 +328,8 @@ class LlamaAttention_Adapted(LlamaAttention):
 
 		key_states = repeat_kv(key_states, self.num_key_value_groups)
 		value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+		# to-do: actually find what is our attention scores, since these seem to not vary at all
+		attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if output_attentions else None
 
 		causal_mask = attention_mask
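One plausible reason the scores above "seem to not vary at all" is that they are raw pre-mask, pre-softmax QK^T products, while eager attention reports attn_weights after masking and softmax. A sketch of that normalization, offered as an assumption about the symptom rather than the fix this commit lands on:

import math
import torch

def manual_attn_weights(q, k, mask=None):
	# q, k: (bsz, heads, seq_len, head_dim), already repeat_kv-expanded
	scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.shape[-1])
	if mask is not None:
		scores = scores + mask # additive mask, -inf on masked positions
	return torch.softmax(scores, dim=-1) # rows now sum to 1 and vary per query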
@@ -1502,9 +1502,11 @@ class Base(nn.Module):
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
-			logits = [ logit[-l:] for logit, l in zip(logits, map(len, prev_list)) ]
+			seq_lens = [ len( prev ) for prev in prev_list ]
+			logits = [ logit[-l:] for logit, l in zip(logits, seq_lens) ]
 		# (AR chunkwise) return the last chunkwise piece
 		elif self.causal:
+			seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ]
 			logits = [ logit[-self.causal_size:] for logit in logits ]
 
 		# (NAR) disable stop token
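The slicing intent, on toy shapes (illustrative values only):

import torch

logits = [ torch.randn(10, 1024) ] # one batch entry: (seq_len, vocab)
seq_lens = [ 4 ] # length of the portion actually being predicted

trimmed = [ logit[-l:] for logit, l in zip(logits, seq_lens) ]
assert trimmed[0].shape == (4, 1024) # keep only the last N positions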
@@ -1537,12 +1539,12 @@ class Base(nn.Module):
 
 			res = [ sample_entropix(
 				logit,
-				attentions[-1], # original code just uses the last attention scores
+				torch.stack(attentions, dim=1)[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, ? ), our attention scores might be padded
 				temperature,
 				top_k,
 				top_p,
 				min_p,
-			) for logit in logits ]
+			) for batch, logit in enumerate(logits) ]
 
 			if res:
 				return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])
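What the stacked indexing produces, on dummy tensors (shapes assumed from the inline comment, not verified against the model):

import torch

layers, bsz, heads, seq_len = 4, 2, 8, 10
# HF-style outputs: one (bsz, heads, seq_len, seq_len) tensor per layer
attentions = [ torch.randn(bsz, heads, seq_len, seq_len) for _ in range(layers) ]
seq_lens = [ 7, 10 ] # per-batch unpadded lengths

stacked = torch.stack(attentions, dim=1) # (bsz, layers, heads, seq_len, seq_len)
batch = 0
cropped = stacked[batch, :, :, :seq_lens[batch], :seq_lens[batch]]
assert cropped.shape == (layers, heads, 7, 7) # padded rows/columns dropped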
@@ -2,6 +2,7 @@ import math
 import torch
+import torch.nn.functional as F
 import numpy as np
 import time
 
 from torch import Tensor, einsum, nn
@@ -229,38 +230,88 @@ def dry_sampling( logits, previous=None, factor=0.0, base=1.75, allowed_length=2
 LN_2 = 0.69314718056 # ln(2) = 1.0 / LOG2_E
 
 # Grabbed from https://github.com/xjdr-alt/entropix/blob/main/entropix/sampler.py
 # Right now I only care about quantifying these two, I'll figure out how to best apply this to the model
-def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
+def calculate_entropix_metrics( logits, attentions=None, dim=-1, use_stats=False ):
 	"""Calculate the entropy and varentropy of the probability distribution using logsoftmax."""
-	log_probs = torch.nn.functional.log_softmax(logits, dim=dim)
+	log_probs = F.log_softmax(logits, dim=dim)
 	probs = torch.exp(log_probs)
 	entropy = -torch.sum(probs * log_probs, dim=dim) / LN_2 # Convert to base-2
-	varentropy = torch.sum(probs * (log_probs / LN_2 + entropy[..., None])**2, dim=dim)
+	varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1))**2, dim=dim)
 
-	if attention_scores is None:
+	if attentions is None:
 		return {
 			"logits_entropy": torch.mean(entropy).item(),
 			"logits_varentropy": torch.mean(varentropy).item(),
 		}
 
-	attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
-	attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clip(attention_probs, 1e-10, 1.0)), dim=-1)
-	attn_varentropy = torch.var(attn_entropy, dim=1)
+	last_attention_scores = attentions[-1].unsqueeze(0) # ( bsz, heads, seq_len, ? )
+	attention_probs = F.softmax(last_attention_scores, dim=-1)
+	if use_stats:
+		attn_stats = AttnStats.new( 1, attentions.shape[0], attentions.shape[1], logits.device )
+		for idx, attn in enumerate( attentions ):
+			attn_stats.update( attn.unsqueeze(0)[:, :, -1, :], idx ) # (bsz, heads, last_token, ?)
+		attn_entropy = attn_stats.entropy
+		attn_varentropy = attn_stats.varentropy
+	else:
+		attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
+		attn_varentropy = torch.var(attn_entropy, dim=1)
 
 	# Add a small epsilon to avoid NaN when all values are the same
 	attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy)
 	mean_attention = torch.mean(attention_probs, dim=1)
-	agreement = torch.mean(torch.abs(attention_probs - mean_attention[:, None, :]), dim=(1, 2))
+	agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
 
-	interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
+	interaction_strength = torch.mean(torch.abs(last_attention_scores), dim=(1, 2, 3))
 	return {
 		"logits_entropy": torch.mean(entropy).item(),
 		"logits_varentropy": torch.mean(varentropy).item(),
 		"attn_entropy": torch.mean(attn_entropy).item(),
 		"attn_varentropy": torch.mean(attn_varentropy).item(),
 		"agreement": torch.mean(agreement).item(),
-		"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)).item(),
+		"interaction_strength": interaction_strength.item(), # torch.mean(interaction_strength).item(),
 		"action": -1
 	}
 
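A quick usage sketch for the logits-only path of the function above (dummy values; the attention path additionally needs per-layer scores):

import torch

logits = torch.randn(1, 1024) # (bsz, vocab)
metrics = calculate_entropix_metrics( logits )
print( metrics["logits_entropy"], metrics["logits_varentropy"] ) # scalars, in bits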
+from typing import NamedTuple
+class AttnStats(NamedTuple):
+	entropy: torch.Tensor # (bsz, n_layers, num_heads)
+	varentropy: torch.Tensor # (bsz, n_layers, num_heads)
+	n_layers: int
+	n_heads: int
+
+	@classmethod
+	def new(cls, bsz: int, n_layers: int, n_heads: int, device = "cuda") -> 'AttnStats':
+		return cls(
+			entropy=torch.zeros((bsz, n_layers, n_heads), dtype=torch.float32, device=device),
+			varentropy=torch.zeros((bsz, n_layers, n_heads), dtype=torch.float32, device=device),
+			n_layers=n_layers,
+			n_heads=n_heads
+		)
+
+	@property
+	def avg_entropy(self):
+		return self.entropy.sum(dim=-1, keepdim=False) # Average across heads
+
+	@property
+	def avg_varentropy(self):
+		return self.varentropy.sum(dim=-1, keepdim=False) # Average across heads
+
+	@property
+	def std_error(self):
+		return torch.sqrt(torch.mean(self.varentropy)) / (self.n_heads * self.n_layers)
+
+	def update(self, scores: torch.Tensor, layer_idx: int):
+		# scores shape: (bsz, n_heads, seqlen, n_words)
+		probs = torch.nn.functional.softmax(scores, dim=-1)
+		new_entropy = -torch.sum(torch.where(probs > 0, probs * torch.log(probs), torch.tensor(0.0)), dim=-1)
+		new_varentropy = torch.sum(probs * (torch.log(probs) + new_entropy.unsqueeze(-1))**2, dim=-1)
+
+		# Update entropy and varentropy tensors
+		self.entropy[:, layer_idx, :] = new_entropy
+		self.varentropy[:, layer_idx, :] = new_varentropy
+
+		return self
 
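How the accumulator above is meant to be driven, sketched on dummy per-layer scores (CPU device picked here; the class defaults to CUDA):

import torch

n_layers, n_heads, seq_len = 4, 8, 10
stats = AttnStats.new( 1, n_layers, n_heads, device="cpu" )
for idx in range( n_layers ):
	scores = torch.randn( 1, n_heads, seq_len ) # one layer's last-token scores
	stats = stats.update( scores, idx )
print( stats.entropy.shape ) # torch.Size([1, 4, 8])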
 # to-do: play around with these values
 @dataclass()
 class EntropixSamplerConfig:
@@ -336,27 +387,36 @@ def _sample_entropix(
 	top_k = clamp( int(top_k), cfg.top_k_min, cfg.top_k_max )
 	min_p = clamp( float(min_p), cfg.min_p_min, cfg.min_p_max )
 
-	probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
+	probs = F.softmax(logits / temperature, dim=-1)
 
 	# Apply min_p sampling
 	if min_p > 0.0:
-		p_max = float(torch.max(probs, dim=-1, keepdims=True).values)
+		p_max = float(torch.max(probs, dim=-1, keepdim=True).values)
 		indices_to_remove = probs < (min_p * p_max)
 		logits = torch.where(indices_to_remove, torch.full_like(logits, float('-inf')), logits)
 
 	# Apply top-k sampling
-	top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
+	top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
 	probs_sort = torch.flip(top_k_probs, dims=[-1])
 	probs_idx = torch.flip(top_k_indices, dims=[-1])
 	probs_sum = torch.cumsum(probs_sort, dim=-1)
 
 	# Apply top-p sampling
-	mask = torch.where(probs_sum - probs_sort > top_p, 1.0, 0.0)
+	mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=logits.device), torch.tensor(0.0, device=logits.device))
 	probs_sort = probs_sort * (1 - mask)
-	probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdims=True)
+	probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdim=True)
 
-	next_token = torch.argmax(probs_sort / Exponential.sample(probs_sort.shape), dim=-1, keepdim=True)
-	return torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
+	q = Exponential.sample(probs_sort.shape)
+	"""
+	# q = torch.rand(probs_sort.shape, generator=generator, device=probs_sort.device)
+	"""
+	next_token = torch.argmax(probs_sort / q, dim=-1, keepdim=True)
+	next_token_g = torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
+	return next_token_g
 
 def sample_entropix(
 	logits,
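The argmax(probs_sort / q) line is the exponential-race form of the Gumbel-max trick: dividing probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability proportional to probs[i]. A standalone check (names here are illustrative, not the module's Exponential helper):

import torch

probs = torch.tensor([0.1, 0.2, 0.7])
exp_dist = torch.distributions.Exponential(torch.tensor(1.0))

counts = torch.zeros(3)
for _ in range(10000):
	q = exp_dist.sample(probs.shape)
	counts[torch.argmax(probs / q)] += 1
print(counts / counts.sum()) # approaches [0.1, 0.2, 0.7]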
@@ -373,6 +433,13 @@ def sample_entropix(
 	top_p = cfg.top_p
 	"""
 
+	"""
+	if generator is None:
+		generator = torch.Generator(device=logits.device).manual_seed(int(time.time()))
+	"""
+
+	# logits: ( bsz, vocab )
+	# attentions: ( bsz, layer, heads, seq_len, ? )
 	metrics = calculate_entropix_metrics( logits, attentions )
 
 	ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
@@ -407,7 +474,7 @@ def sample_entropix(
 	else:
 		metrics["action"] = 4
 
-	log_softmax = torch.nn.functional.log_softmax(logits, dim=-1)
+	log_softmax = F.log_softmax(logits, dim=-1)
 	logits_uncertainty = ent + vent
 	attn_uncertainty = attn_ent + attn_vent
 
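In the upstream entropix sampler these combined uncertainties feed the adaptive branch by scaling the base sampling parameters; a rough paraphrase of that idea (coefficient names and values are assumptions taken from upstream, not this repo's EntropixSamplerConfig):

def adaptive_temperature( base, logits_uncertainty, attn_uncertainty, agreement,
	ada_logits=0.3, ada_attn=0.2, ada_agree=0.2 ):
	# hypothetical coefficients: higher uncertainty heats sampling, head agreement cools it
	return base * (1 + ada_logits * logits_uncertainty + ada_attn * attn_uncertainty - ada_agree * agreement)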
@@ -419,7 +486,7 @@ def sample_entropix(
 	samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ]
 
 	def score_sample(sample):
-		one_hot = torch.nn.functional.one_hot( sample, logits.shape[-1] )
+		one_hot = F.one_hot( sample, logits.shape[-1] )
 		log_prob = torch.sum(log_softmax * one_hot)
 
 		confidence_score = (