what has science done
This commit is contained in:
parent 147219a5e0 · commit 4aa685e749
@@ -149,11 +149,71 @@ class LlamaAttention_Adapted(LlamaAttention):
        super().__init__(*args, **kwargs)

    # extracts the inputs in a batch that match the requested causality
    def split_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        is_causal: Optional[list] = None,
        target_causal_state: Optional[bool] = True,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
        **kwargs,
    ):
        indices = [ i for i, state in enumerate( is_causal ) if state == target_causal_state ]

        # no matching inputs in the batch
        if not indices:
            return indices, None, None, None

        # the entire batch is homogeneous
        if len( indices ) == hidden_states.shape[0]:
            output_hidden_states, output_self_attn_weights, output_present_key_values = self.forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                is_causal=target_causal_state,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=False,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            return indices, output_hidden_states, output_self_attn_weights, output_present_key_values

        input_hidden_states = torch.stack( [ hidden_states[i] for i in indices ] )
        input_attention_mask = torch.stack( [ attention_mask[i] for i in indices ] ) if attention_mask is not None else None
        input_position_ids = torch.stack( [ position_ids[i] for i in indices ] ) if position_ids is not None else None
        input_position_embeddings = (
            torch.stack( [ position_embeddings[0][i] for i in indices ] ),
            torch.stack( [ position_embeddings[1][i] for i in indices ] ),
        ) if position_embeddings is not None else None

        output_hidden_states, output_self_attn_weights, output_present_key_values = self.forward(
            hidden_states=input_hidden_states,
            attention_mask=input_attention_mask,
            is_causal=target_causal_state,
            position_ids=input_position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=False,
            cache_position=cache_position,
            position_embeddings=input_position_embeddings,
            **kwargs,
        )
        return indices, output_hidden_states, output_self_attn_weights, output_present_key_values
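A minimal, self-contained sketch of the index-based gather/scatter that split_forward relies on (plain tensors only; the names below are illustrative and not part of this commit):

    import torch

    hidden_states = torch.randn(4, 8, 16)      # (batch, seq_len, hidden_dim)
    is_causal = [True, False, True, False]     # per-sample causality flags
    target_causal_state = True

    # gather only the rows whose flag matches the requested causality
    indices = [i for i, state in enumerate(is_causal) if state == target_causal_state]
    subset = torch.stack([hidden_states[i] for i in indices])   # shape (2, 8, 16)

    # results computed on the subset are later scattered back to their original
    # batch positions, which is what forward() does with split_forward's outputs
    outputs = [None] * len(is_causal)
    for i, idx in enumerate(indices):
        outputs[idx] = subset[i]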
    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        is_causal: bool = True,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
@@ -163,6 +223,94 @@ class LlamaAttention_Adapted(LlamaAttention):
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        mode = "default" if output_attentions else self.mode

        # split the batch by causality, because the other attention backends only take a single is_causal flag for the entire input rather than one flag per sample
        if isinstance( is_causal, list ) and mode not in ["default"]:
            # initialize lists
            attn_hidden_states = [ None for _ in is_causal ]
            self_attn_weights = [ None for _ in is_causal ]
            present_key_values = [ None for _ in is_causal ]

            # process the causal inputs in the batch
            causal_indices, causal_hidden_states, causal_self_attn_weights, causal_present_key_values = self.split_forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                is_causal=is_causal,
                target_causal_state=True,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=False,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

            # process the non-causal inputs in the batch
            non_causal_indices, non_causal_hidden_states, non_causal_self_attn_weights, non_causal_present_key_values = self.split_forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                is_causal=is_causal,
                target_causal_state=False,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=False,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

            # insert the causal outputs back into the batch
            for i, idx in enumerate( causal_indices ):
                attn_hidden_states[idx] = causal_hidden_states[i]

                if output_attentions:
                    self_attn_weights[idx] = causal_self_attn_weights[i]

            # insert the non-causal outputs back into the batch
            for i, idx in enumerate( non_causal_indices ):
                attn_hidden_states[idx] = non_causal_hidden_states[i]

                if output_attentions:
                    self_attn_weights[idx] = non_causal_self_attn_weights[i]

            # combine the lists back into tensors
            attn_hidden_states = torch.stack( attn_hidden_states, dim=0 )
            if output_attentions:
                self_attn_weights = torch.stack( self_attn_weights, dim=0 )

            return attn_hidden_states, self_attn_weights if output_attentions else None, []

            """
            h_s = []
            s_a_w = []
            p_k_v = []

            for i, state in enumerate(is_causal):
                hidden_state, self_attn_weight, present_key_value = self.forward(
                    hidden_states=hidden_states[i].unsqueeze(0),
                    attention_mask=attention_mask[i].unsqueeze(0),
                    is_causal=state,
                    position_ids=position_ids[i].unsqueeze(0),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=False,
                    cache_position=cache_position,
                    position_embeddings=(position_embeddings[0][i].unsqueeze(0), position_embeddings[1][i].unsqueeze(0)) if position_embeddings is not None else None,
                    **kwargs,
                )
                h_s.append(hidden_state)
                s_a_w.append(self_attn_weight)
                p_k_v.append(present_key_value)

            return (
                torch.concat( h_s, dim=0 ),
                torch.concat( s_a_w, dim=0 ) if s_a_w else None,
                p_k_v,
            )
            """

        dropout_rate = self.attention_dropout if self.training else 0.0
        bsz, q_len, _ = hidden_states.size()
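For context, the reason forward() splits the batch at all is that the fused / SDPA attention paths accept only one is_causal flag per call. A rough standalone sketch of that pattern, using only public PyTorch APIs (names are illustrative, not from this codebase):

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(4, 2, 8, 16)       # (batch, heads, seq_len, head_dim)
    is_causal = [True, False, True, False]

    out = [None] * len(is_causal)
    for flag in (True, False):
        idx = [i for i, state in enumerate(is_causal) if state == flag]
        if not idx:
            continue
        # one attention call per causality state, since is_causal applies to the whole call
        sub = F.scaled_dot_product_attention(q[idx], k[idx], v[idx], is_causal=flag)
        for j, i in enumerate(idx):
            out[i] = sub[j]

    out = torch.stack(out, dim=0)              # reassembled in the original batch order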
@@ -221,7 +369,7 @@ class LlamaAttention_Adapted(LlamaAttention):
                query_states,
                key_states,
                value_states,
-                causal=True,
+                causal=is_causal,
                softmax_scale=1.0 / math.sqrt(self.head_dim),
                dropout_p=dropout_rate,
            )
@@ -232,7 +380,7 @@ class LlamaAttention_Adapted(LlamaAttention):
                query_states,
                key_states,
                value_states,
-                attn_bias = LowerTriangularMask() if attention_mask is None or attention_mask[0, 0, 0, 1] == 0 else None,
+                attn_bias = LowerTriangularMask(),
                scale = 1.0 / math.sqrt(self.head_dim),
                p=dropout_rate
            )
@@ -258,14 +406,14 @@ class LlamaAttention_Adapted(LlamaAttention):

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if x_mask is None and q_len > 1 else False
+        # is_causal = True if x_mask is None and q_len > 1 else False

        if mode in ["fused_attn"]:
            attn_output = fused_attn_func(
                query_states,
                key_states,
                value_states,
-                causal=True,
+                causal=is_causal,
                softmax_scale=1.0 / math.sqrt(self.head_dim),
                dropout_p=dropout_rate,
            )
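Side note: the softmax_scale passed above is the conventional 1/sqrt(head_dim) attention scaling, i.e. the same value SDPA applies by default. A quick standalone check (plain PyTorch, not from the repo):

    import math
    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 2, 8, 16)
    head_dim = q.shape[-1]

    # passing the scale explicitly matches SDPA's default of 1 / sqrt(head_dim)
    a = F.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(head_dim))
    b = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(a, b)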
@@ -284,6 +432,7 @@ class LlamaAttention_Adapted(LlamaAttention):
                    f" {attn_output.size()}"
                )
        else:
            is_causal = True if x_mask is None and q_len > 1 else False
            with torch.nn.attention.sdpa_kernel(self.mode):
                attn_output = torch.nn.functional.scaled_dot_product_attention(
                    query_states,
@@ -332,6 +481,7 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        is_causal: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
@@ -371,6 +521,7 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            is_causal=is_causal,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
@@ -588,6 +739,7 @@ class LlamaModel_Adapted(LlamaModel):
                decoder_layer.__call__,
                hidden_states,
                x_mask,
                is_causal,
                position_ids,
                past_key_values,
                output_attentions,
@@ -600,6 +752,7 @@ class LlamaModel_Adapted(LlamaModel):
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=x_mask,
                is_causal=is_causal,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
@@ -487,8 +487,10 @@ class Base(nn.Module):
        self.noncausal_masks = noncausal_masks

        # use the internal attention mechanism for now, because I don't have a better way to handle mixed causal/non-causal masks for the other attention backends
        """
        if noncausal_masks:
            attention_backend = "default"
        """

        self.text_emb = Embedding(n_text_tokens, d_model)
        self.langs_emb = None
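The commented-out fallback above concerns mixed causal/non-causal masks. Purely as an illustration of that idea (not how this repository handles it), such a batch can also be served by a single SDPA call if the causality is baked into an explicit per-sample boolean attention mask:

    import torch
    import torch.nn.functional as F

    bsz, heads, seq_len, head_dim = 4, 2, 8, 16
    q = k = v = torch.randn(bsz, heads, seq_len, head_dim)
    is_causal = [True, False, True, False]

    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))   # lower-triangular mask
    full = torch.ones(seq_len, seq_len, dtype=torch.bool)                 # attend everywhere
    attn_mask = torch.stack([causal if c else full for c in is_causal])   # (batch, seq, seq)
    attn_mask = attn_mask[:, None, :, :]                                  # broadcast over heads

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)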