That was actually all pointless, since SDPA always has an attention mask fed to it and does not need is_causal to implicitly generate one
This commit is contained in:
parent 4aa685e749
commit ccee5fc11c
@@ -223,9 +223,16 @@ class LlamaAttention_Adapted(LlamaAttention):
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         mode = "default" if output_attentions else self.mode
+        non_split_attention = [
+            "default",
+            torch.nn.attention.SDPBackend.MATH,
+            torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+        ]

         # split per batch because other attention mechanisms do not have a conditional is_causal per-batch, only for the entire input
-        if isinstance( is_causal, list ) and mode not in ["default"]:
+        if isinstance( is_causal, list ) and mode not in non_split_attention:
             # initialize lists
             attn_hidden_states = [ None for _ in is_causal ]
             self_attn_weights = [ None for _ in is_causal ]
@@ -282,35 +289,6 @@ class LlamaAttention_Adapted(LlamaAttention):

             return attn_hidden_states, output_attentions, []

-        """
-        h_s = []
-        s_a_w = []
-        p_k_v = []
-
-        for i, state in enumerate(is_causal):
-            hidden_state, self_attn_weight, present_key_value = self.forward(
-                hidden_states=hidden_states[i].unsqueeze(0),
-                attention_mask=attention_mask[i].unsqueeze(0),
-                is_causal=state,
-                position_ids=position_ids[i].unsqueeze(0),
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=False,
-                cache_position=cache_position,
-                position_embeddings=(position_embeddings[0][i].unsqueeze(0), position_embeddings[1][i].unsqueeze(0)) if position_embeddings is not None else None,
-                **kwargs,
-            )
-            h_s.append(hidden_state)
-            s_a_w.append(self_attn_weight)
-            p_k_v.append(present_key_value)
-
-        return (
-            torch.concat( h_s, dim=0 ),
-            torch.concat( s_a_w, dim=0 ) if s_a_w else None,
-            p_k_v,
-        )
-        """
-
         dropout_rate = self.attention_dropout if self.training else 0.0
         bsz, q_len, _ = hidden_states.size()

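To illustrate the claim in the commit message: PyTorch's torch.nn.functional.scaled_dot_product_attention does not need is_causal once it is handed an explicit attention mask that already encodes causality. A minimal sketch, not part of this commit, using arbitrary placeholder shapes:

import torch
import torch.nn.functional as F

# Arbitrary placeholder shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# Explicit boolean causal mask: True means "may attend".
causal_mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))

# Passing the mask explicitly (is_causal left at its default of False) ...
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
# ... matches letting SDPA generate the causal mask implicitly from is_causal.
out_implicit = F.scaled_dot_product_attention(q, k, v, is_causal=True)

assert torch.allclose(out_masked, out_implicit, atol=1e-6)

Since the adapted attention layer always receives such a mask, the per-sample is_causal bookkeeping that the deleted commented-out block handled adds nothing.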
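For context on the new non_split_attention list: besides the string "default", its entries are members of torch.nn.attention.SDPBackend, the enum recent PyTorch releases use to pin scaled_dot_product_attention to a specific kernel. The sketch below is an assumption-laden illustration (the attend helper and its signature are hypothetical, not the repository's code) of how such a mode value can gate the per-batch split that the diff's comment describes:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Mirrors the list added in the diff: modes that accept a single attention mask
# for the whole batch, so no per-sample split is needed.
non_split_attention = [
    "default",
    SDPBackend.MATH,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.FLASH_ATTENTION,
    SDPBackend.CUDNN_ATTENTION,
]

def attend(q, k, v, mask, mode, is_causal):  # hypothetical helper, illustration only
    if isinstance(is_causal, list) and mode not in non_split_attention:
        # A per-sample is_causal list only matters for mechanisms outside the
        # list above; those would have to process the batch one sample at a time.
        raise NotImplementedError("per-sample fallback elided in this sketch")
    if mode == "default":
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # Pin SDPA to the requested backend for the duration of the call.
    with sdpa_kernel(mode):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)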