That was actually all pointless, since SDPA always had an attention mask fed to it and does not need is_causal to implicitly generate one
parent 4aa685e749
commit ccee5fc11c
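A minimal, self-contained sketch of the rationale in the commit message: when an explicit (here, causal) attention mask is already passed to torch.nn.functional.scaled_dot_product_attention, setting is_causal adds nothing, because the mask itself encodes the causal structure. The tensor shapes and names below are illustrative only and do not come from this repository.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Explicit lower-triangular boolean mask; True means "may attend".
causal_mask = torch.ones(16, 16, dtype=torch.bool).tril()

# Passing the explicit mask is equivalent to asking SDPA to build one via
# is_causal=True (PyTorch does not allow supplying both at once).
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_masked, out_causal, atol=1e-5)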
@@ -223,9 +223,16 @@ class LlamaAttention_Adapted(LlamaAttention):
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         mode = "default" if output_attentions else self.mode
+        non_split_attention = [
+            "default",
+            torch.nn.attention.SDPBackend.MATH,
+            torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+        ]

         # split per batch because other attention mechanisms do not have a conditional is_causal per-batch, only for the entire input
-        if isinstance( is_causal, list ) and mode not in ["default"]:
+        if isinstance( is_causal, list ) and mode not in non_split_attention:
             # initialize lists
             attn_hidden_states = [ None for _ in is_causal ]
             self_attn_weights = [ None for _ in is_causal ]
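A hedged sketch of the gate this hunk introduces, pulled out of the class for clarity. The batch is only split per sample when is_causal varies per sample and the active attention mode is an implementation that can only take a single causality flag for the whole input. The standalone function name is hypothetical; in the diff this check lives inline in forward().

from torch.nn.attention import SDPBackend  # requires a recent PyTorch (2.3+)

# Mirrors the non_split_attention list added in the hunk above.
NON_SPLIT_ATTENTION = [
    "default",
    SDPBackend.MATH,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.FLASH_ATTENTION,
    SDPBackend.CUDNN_ATTENTION,
]

def needs_per_sample_split(is_causal, mode) -> bool:
    # Split only when causality is given per sample (a list) AND the active
    # attention mode cannot express per-sample causality for a whole batch.
    return isinstance(is_causal, list) and mode not in NON_SPLIT_ATTENTION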
@@ -282,35 +289,6 @@ class LlamaAttention_Adapted(LlamaAttention):
             return attn_hidden_states, output_attentions, []
-
-            """
-            h_s = []
-            s_a_w = []
-            p_k_v = []
-
-            for i, state in enumerate(is_causal):
-                hidden_state, self_attn_weight, present_key_value = self.forward(
-                    hidden_states=hidden_states[i].unsqueeze(0),
-                    attention_mask=attention_mask[i].unsqueeze(0),
-                    is_causal=state,
-                    position_ids=position_ids[i].unsqueeze(0),
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=False,
-                    cache_position=cache_position,
-                    position_embeddings=(position_embeddings[0][i].unsqueeze(0), position_embeddings[1][i].unsqueeze(0)) if position_embeddings is not None else None,
-                    **kwargs,
-                )
-                h_s.append(hidden_state)
-                s_a_w.append(self_attn_weight)
-                p_k_v.append(present_key_value)
-
-            return (
-                torch.concat( h_s, dim=0 ),
-                torch.concat( s_a_w, dim=0 ) if s_a_w else None,
-                p_k_v,
-            )
-            """

         dropout_rate = self.attention_dropout if self.training else 0.0
         bsz, q_len, _ = hidden_states.size()
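For context on why the commented-out per-sample loop above could simply be deleted: per-sample causality can be folded into a single batched boolean mask, so one SDPA call covers the whole batch. This is a simplified illustration under that assumption, not code from the repository; shapes and names are made up.

import torch
import torch.nn.functional as F

bsz, heads, q_len, head_dim = 2, 4, 16, 64
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
v = torch.randn(bsz, heads, q_len, head_dim)

is_causal = [True, False]  # per-sample flags, as in the deleted loop

causal = torch.ones(q_len, q_len, dtype=torch.bool).tril()
full = torch.ones(q_len, q_len, dtype=torch.bool)

# (bsz, 1, q_len, q_len): broadcast over heads; True means "may attend".
attn_mask = torch.stack([causal if flag else full for flag in is_causal]).unsqueeze(1)

# One batched call instead of bsz separate per-sample forward() calls.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)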