that was actually all pointless, since SDPA always has an attention mask fed to it and does not need is_causal to implicitly generate one

mrq 2024-11-22 16:51:50 -06:00
parent 4aa685e749
commit ccee5fc11c


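An aside on the reasoning in the commit message, as a minimal sketch that is not from this repo (it assumes a recent PyTorch 2.x): once scaled_dot_product_attention is given an explicit attn_mask, is_causal is redundant, because causality can be encoded in the mask itself and can even differ per batch element, so there is no need to split the batch for the SDPA paths.

import torch
import torch.nn.functional as F

# Sketch only: tensor shapes and the per-sample masks are illustrative.
bsz, heads, q_len, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
v = torch.randn(bsz, heads, q_len, head_dim)

# Boolean mask, True = attend. Sample 0 is causal, sample 1 attends everywhere,
# so causality is handled per batch element without splitting the batch.
causal = torch.tril(torch.ones(q_len, q_len, dtype=torch.bool))
full = torch.ones(q_len, q_len, dtype=torch.bool)
attn_mask = torch.stack([causal, full]).unsqueeze(1)  # (bsz, 1, q_len, q_len), broadcasts over heads

# is_causal stays False; SDPA raises an error if both attn_mask and is_causal are set.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
print(out.shape)  # torch.Size([2, 4, 8, 16])

This is the same idea the diff below relies on: the attention mask already handed to SDPA carries the causal structure, so the per-batch is_causal handling was never needed there.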
@@ -223,9 +223,16 @@ class LlamaAttention_Adapted(LlamaAttention):
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         mode = "default" if output_attentions else self.mode
+        non_split_attention = [
+            "default",
+            torch.nn.attention.SDPBackend.MATH,
+            torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+        ]
         # split per batch because other attention mechanisms do not have a conditional is_causal per-batch, only for the entire input
-        if isinstance( is_causal, list ) and mode not in ["default"]:
+        if isinstance( is_causal, list ) and mode not in non_split_attention:
             # initialize lists
             attn_hidden_states = [ None for _ in is_causal ]
             self_attn_weights = [ None for _ in is_causal ]
@@ -282,35 +289,6 @@ class LlamaAttention_Adapted(LlamaAttention):
             return attn_hidden_states, output_attentions, []
-        """
-        h_s = []
-        s_a_w = []
-        p_k_v = []
-        for i, state in enumerate(is_causal):
-            hidden_state, self_attn_weight, present_key_value = self.forward(
-                hidden_states=hidden_states[i].unsqueeze(0),
-                attention_mask=attention_mask[i].unsqueeze(0),
-                is_causal=state,
-                position_ids=position_ids[i].unsqueeze(0),
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=False,
-                cache_position=cache_position,
-                position_embeddings=(position_embeddings[0][i].unsqueeze(0), position_embeddings[1][i].unsqueeze(0)) if position_embeddings is not None else None,
-                **kwargs,
-            )
-            h_s.append(hidden_state)
-            s_a_w.append(self_attn_weight)
-            p_k_v.append(present_key_value)
-        return (
-            torch.concat( h_s, dim=0 ),
-            torch.concat( s_a_w, dim=0 ) if s_a_w else None,
-            p_k_v,
-        )
-        """
         dropout_rate = self.attention_dropout if self.training else 0.0
         bsz, q_len, _ = hidden_states.size()
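For reference, the SDPBackend members listed in non_split_attention are the enum values torch uses to pick an SDPA kernel. A hedged sketch, not from this repo (sdpa_kernel and SDPBackend.CUDNN_ATTENTION require a recent PyTorch), of how those backends are normally pinned while still passing an explicit mask:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)
mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))  # plain causal mask

# Restrict dispatch to a subset of SDPA backends. Backends that cannot honor an
# explicit attn_mask (e.g. flash) are skipped and the call falls back to the rest.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)

Whichever backend ends up selected, the mask-based call behaves the same, which is why the diff treats "default" and every SDPBackend mode as not needing the per-batch split.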