output attention scores for SDPA/flash, since naive attention seems broken

mrq 2024-10-12 12:09:17 -05:00
parent 541e45263c
commit 70cf694cfd
2 changed files with 21 additions and 32 deletions
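A minimal sketch of why the scores have to be recomputed (illustrative only, not this repository's code): PyTorch's fused scaled_dot_product_attention / flash kernels return only the attention output and never the per-head weights, so a caller that asks for attentions has to rebuild Q·Kᵀ/√d itself, which is what the hunks below do.

import math
import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# The fused kernel yields the output only; no attention weights come back.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Raw (pre-softmax) scores rebuilt by hand, the same quantity this commit returns.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.shape[-1])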

View File

@@ -64,6 +64,7 @@ def main():
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--min-p", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
@@ -121,8 +122,15 @@ def main():
comparison_kwargs["enabled"] = True
comparison_kwargs["suffix"] = "_entropix"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
comparison_kwargs["before"]["entropix_sampling"] = True
comparison_kwargs["before"]["ar_temp"] = 0.666
comparison_kwargs["before"]["top_k"] = 27
comparison_kwargs["before"]["top_p"] = 0.9
comparison_kwargs["after"]["entropix_sampling"] = False
comparison_kwargs["after"]["ar_temp"] = args.ar_temp
comparison_kwargs["after"]["top_k"] = args.top_k
comparison_kwargs["after"]["top_p"] = args.top_p
# read html template
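A hedged sketch of how before/after override dictionaries like these are typically consumed (the run_pair helper and the tts handle are hypothetical, not this script's actual API): the same inference call runs twice, once per override set, so the demo page can render both outputs under the configured titles.

def run_pair(tts, base_kwargs: dict, comparison_kwargs: dict) -> dict:
    # Merge each override set over the shared kwargs and keep both results.
    results = {}
    for label in ("before", "after"):
        merged = {**base_kwargs, **comparison_kwargs[label]}
        results[label] = tts.inference(**merged)
    return results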
@@ -135,7 +143,7 @@ def main():
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
top_p=args.top_p, top_k=args.top_k,
top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,

View File

@@ -150,7 +150,7 @@ class LlamaAttention_Adapted(LlamaAttention):
super().__init__(*args, **kwargs)
# Adapted from LlamaAttention.forward
# Adapted from LlamaAttention.forward; this doesn't seem to give great output...
def _forward(
self,
hidden_states: torch.Tensor,
@@ -163,41 +163,18 @@ class LlamaAttention_Adapted(LlamaAttention):
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
dropout_rate = self.attention_dropout if self.training else 0.0
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
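For reference, a sketch of the naive attention tail that a manual path like _forward follows after RoPE is applied (standard eager Llama-style attention, not copied from this file): scores, mask, softmax, dropout, then the weighted sum over values.

import math
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, attention_mask=None, dropout_p=0.0):
    # q, k, v: (batch, heads, seq, head_dim); k and v already expanded via repeat_kv.
    scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.shape[-1])
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : k.shape[-2]]
    weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    weights = F.dropout(weights, p=dropout_p)
    return torch.matmul(weights, v), weights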
@@ -259,7 +236,7 @@ class LlamaAttention_Adapted(LlamaAttention):
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions or not self.mode:
if not self.mode:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
return self._forward(
hidden_states=hidden_states,
@@ -295,6 +272,9 @@ class LlamaAttention_Adapted(LlamaAttention):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if self.mode in ["xformers", "flash_attn"]:
if output_attentions:
attn_scores = torch.matmul(query_states, repeat_kv(key_states, self.num_key_value_groups).transpose(2, 3)) / math.sqrt(self.head_dim)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
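The score matmul above first tiles the key heads up to the query head count because the model uses grouped-query attention; the repeat_kv helper it calls does roughly the following (paraphrased from the upstream transformers helper):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)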
@@ -346,10 +326,11 @@ class LlamaAttention_Adapted(LlamaAttention):
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return attn_output, attn_scores, past_key_value
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if output_attentions else None
causal_mask = attention_mask
if attention_mask is not None:
@@ -391,4 +372,4 @@ class LlamaAttention_Adapted(LlamaAttention):
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return attn_output, attn_scores, past_key_value
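End to end, the change means a caller that requests attentions from the SDPA or flash paths now gets the raw, pre-softmax score matrices back instead of None. Presumably the point is to let samplers inspect them; entropix-style sampling (toggled in the demo-script hunk above) keys off attention entropy. A self-contained sketch of that computation, assuming the (batch, heads, q_len, kv_len) layout produced by the matmuls above:

import torch

# Stand-in for one layer's returned attn_scores; assumed layout (batch, heads, q_len, kv_len).
scores = torch.randn(1, 8, 16, 16)

# Per-head, per-position attention entropy, the kind of signal an
# entropix-style sampler uses to decide how aggressively to sample.
entropy = torch.distributions.Categorical(logits=scores).entropy()  # (1, 8, 16)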