output attention scores for SDPA/flash, since naive attention seems broken
This commit is contained in:
parent 541e45263c
commit 70cf694cfd
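In the SDPA and flash-attention paths the fused kernel never materializes the attention matrix, so this commit recomputes the scores with an explicit matmul whenever they are requested, while the fused kernel still produces the actual attention output. A minimal sketch of that pattern (illustrative names, not the repo's code):

import math
import torch
import torch.nn.functional as F

def sdpa_with_scores(q, k, v, attn_mask=None, output_attentions=False):
	# q, k, v: [batch, heads, seq_len, head_dim]
	scores = None
	if output_attentions:
		# pre-softmax scores, mirroring the attn_scores computed in the diff below
		scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.shape[-1])
	# the fused kernel still computes the actual attention output
	out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
	return out, scores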
@@ -64,6 +64,7 @@ def main():

	parser.add_argument("--top-p", type=float, default=1.0)
	parser.add_argument("--top-k", type=int, default=0)
	parser.add_argument("--min-p", type=float, default=0.0)
	parser.add_argument("--repetition-penalty", type=float, default=1.0)
	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
	parser.add_argument("--length-penalty", type=float, default=0.0)
@@ -121,8 +122,15 @@ def main():
		comparison_kwargs["enabled"] = True
		comparison_kwargs["suffix"] = "_entropix"
		comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]

		comparison_kwargs["before"]["entropix_sampling"] = True
		comparison_kwargs["before"]["ar_temp"] = 0.666
		comparison_kwargs["before"]["top_k"] = 27
		comparison_kwargs["before"]["top_p"] = 0.9
		comparison_kwargs["after"]["entropix_sampling"] = False
		comparison_kwargs["after"]["ar_temp"] = args.ar_temp
		comparison_kwargs["after"]["top_k"] = args.top_k
		comparison_kwargs["after"]["top_p"] = args.top_p


	# read html template
@@ -135,7 +143,7 @@ def main():
		max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
		ar_temp=args.ar_temp, nar_temp=args.nar_temp,
		min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
		top_p=args.top_p, top_k=args.top_k,
		top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
		length_penalty=args.length_penalty,
		beam_width=args.beam_width,
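The new --min-p flag above is threaded into these inference kwargs. Presumably it is the usual min-p filter, which drops tokens whose probability falls below min_p times the top token's probability; a rough sketch under that assumption:

import torch

def min_p_filter(logits: torch.Tensor, min_p: float = 0.0) -> torch.Tensor:
	# Keep only tokens whose probability is at least min_p * p(top token);
	# min_p = 0.0 disables the filter, matching the argparse default above.
	if min_p <= 0.0:
		return logits
	probs = torch.softmax(logits, dim=-1)
	threshold = min_p * probs.max(dim=-1, keepdim=True).values
	return logits.masked_fill(probs < threshold, float("-inf"))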
@@ -150,7 +150,7 @@ class LlamaAttention_Adapted(LlamaAttention):

		super().__init__(*args, **kwargs)

	# Adapted from LlamaAttention.forward
	# Adapted from LlamaAttention.forward, this doesn't seem to give great output......
	def _forward(
		self,
		hidden_states: torch.Tensor,
@@ -163,41 +163,18 @@ class LlamaAttention_Adapted(LlamaAttention):
		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
		**kwargs,
	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
		dropout_rate = self.attention_dropout if self.training else 0.0
		bsz, q_len, _ = hidden_states.size()

		if self.config.pretraining_tp > 1:
			key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
			query_slices = self.q_proj.weight.split(
				(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
			)
			key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
			value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

			query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
			query_states = torch.cat(query_states, dim=-1)

			key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
			key_states = torch.cat(key_states, dim=-1)

			value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
			value_states = torch.cat(value_states, dim=-1)

		else:
			query_states = self.q_proj(hidden_states)
			key_states = self.k_proj(hidden_states)
			value_states = self.v_proj(hidden_states)
		query_states = self.q_proj(hidden_states)
		key_states = self.k_proj(hidden_states)
		value_states = self.v_proj(hidden_states)

		query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
		key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
		value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

		if position_embeddings is None:
			logger.warning_once(
				"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
				"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
				"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
				"removed and `position_embeddings` will be mandatory."
			)
			cos, sin = self.rotary_emb(value_states, position_ids)
		else:
			cos, sin = position_embeddings
@@ -259,7 +236,7 @@ class LlamaAttention_Adapted(LlamaAttention):
		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
		**kwargs,
	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
		if output_attentions or not self.mode:
		if not self.mode:
			# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
			return self._forward(
				hidden_states=hidden_states,
@@ -295,6 +272,9 @@ class LlamaAttention_Adapted(LlamaAttention):
			key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

		if self.mode in ["xformers", "flash_attn"]:
			if output_attentions:
				attn_scores = torch.matmul(query_states, repeat_kv(key_states, self.num_key_value_groups).transpose(2, 3)) / math.sqrt(self.head_dim)

			# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
			# to be able to avoid many of these transpose/reshape/view.
			query_states = query_states.transpose(1, 2)
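repeat_kv in the attn_scores line above expands the grouped key/value heads so they line up with the query heads before the matmul; it is essentially the standard transformers helper, roughly:

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
	# [batch, kv_heads, seq, head_dim] -> [batch, kv_heads * n_rep, seq, head_dim]
	if n_rep == 1:
		return x
	b, h, s, d = x.shape
	return x[:, :, None, :, :].expand(b, h, n_rep, s, d).reshape(b, h * n_rep, s, d)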
@@ -346,10 +326,11 @@ class LlamaAttention_Adapted(LlamaAttention):
			attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

			attn_output = self.o_proj(attn_output)
			return attn_output, None, past_key_value
			return attn_output, attn_scores, past_key_value

		key_states = repeat_kv(key_states, self.num_key_value_groups)
		value_states = repeat_kv(value_states, self.num_key_value_groups)
		attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if output_attentions else None

		causal_mask = attention_mask
		if attention_mask is not None:
@@ -391,4 +372,4 @@ class LlamaAttention_Adapted(LlamaAttention):

		attn_output = self.o_proj(attn_output)

		return attn_output, None, past_key_value
		return attn_output, attn_scores, past_key_value
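Returning attn_scores (pre-softmax logits rather than softmax weights) is presumably what the entropix comparison in the webui above consumes; one typical statistic derived from them is per-head attention entropy, for example:

import torch

def attention_entropy(scores: torch.Tensor) -> torch.Tensor:
	# scores: pre-softmax attention logits [batch, heads, q_len, k_len];
	# entropy over the key axis is the kind of signal entropix-style sampling inspects.
	probs = torch.softmax(scores, dim=-1)
	return -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)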
|
Loading…
Reference in New Issue
Block a user