Actually generate a causal mask: it seems one is sometimes not generated because of assumptions the upstream helper makes about the attention implementation (those early-return paths are commented out here).
parent ccee5fc11c
commit c99a74e834
@@ -14,6 +14,7 @@ from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

_logger = logging.getLogger(__name__)
@@ -630,6 +631,75 @@ class LlamaModel_Adapted(LlamaModel):

        return combined_attention_mask

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        """

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        """
        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None
        """

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
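For reference, the 4D mask that `_prepare_4d_causal_attention_mask_with_cache_position` hands back can be sketched without any of the transformers machinery. The snippet below is not part of the commit; `sketch_4d_causal_mask` is an illustrative name, and it assumes the simplest case (a 2D padding mask, no static cache, `target_length` equal to the key/value length).

import torch

def sketch_4d_causal_mask(attention_mask_2d, sequence_length, target_length, cache_position, dtype=torch.float32):
    # Roughly what the 4D additive mask looks like: shape (batch, 1, query_len, kv_len),
    # 0.0 where attention is allowed and dtype-min where it is blocked.
    min_dtype = torch.finfo(dtype).min
    batch_size = attention_mask_2d.shape[0]
    causal = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    # A query at absolute position cache_position[i] may attend to key positions 0..cache_position[i].
    kv_positions = torch.arange(target_length)
    causal = causal.masked_fill(kv_positions[None, :] <= cache_position[:, None], 0.0)
    causal = causal[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    # Fold in the 2D padding mask: padded key positions are blocked for every query.
    padding = attention_mask_2d[:, None, None, :target_length] == 0
    return causal.masked_fill(padding, min_dtype)

# Example: batch of 2, prompt of 4 tokens, empty cache, second sample left-padded.
mask2d = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]])
mask4d = sketch_4d_causal_mask(mask2d, sequence_length=4, target_length=4, cache_position=torch.arange(4))
print(mask4d.shape)  # torch.Size([2, 1, 4, 4])

With a non-empty cache, cache_position would start past the cached tokens and target_length would cover past plus current tokens, which is why the real method computes target_length from past_seen_tokens and the cache type instead of taking it for granted.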
@@ -692,8 +762,20 @@ class LlamaModel_Adapted(LlamaModel):

        # because we can attend to both a causal and a non-causal sequence, generate both masks then pick which to use per batch
        if is_causal is not None:
            """
            causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=inputs_embeds.shape[1],
                target_length=attention_mask.shape[-1] if attention_mask is not None else inputs_embeds.shape[1],
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
                cache_position=cache_position,
                batch_size=inputs_embeds.shape[0],
            )
            """
            causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
            noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)

            x_mask = torch.stack([causal_mask[i, :, :, :] if state else noncausal_mask[i, :, :, :] for i, state in enumerate(is_causal)], dim=0)
        else:
            x_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
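The per-batch selection in the hunk above can also be looked at in isolation. Below is a minimal sketch with made-up shapes (batch of 3, length 5) and a hypothetical is_causal list; the torch.where variant at the end is not what the commit uses, just an equivalent loop-free formulation.

import torch

# Toy shapes: batch of 3, 5 query positions, 5 key/value positions.
B, L = 3, 5
min_val = torch.finfo(torch.float32).min
causal_mask = torch.full((B, 1, L, L), min_val).triu(1)   # block keys j > i
noncausal_mask = torch.zeros((B, 1, L, L))                 # every position may attend everywhere
is_causal = [True, False, True]                            # hypothetical per-sample flags

# Same pattern as the forward() above: pick one 4D mask slice per sample, then restack.
x_mask = torch.stack(
    [causal_mask[i] if state else noncausal_mask[i] for i, state in enumerate(is_causal)],
    dim=0,
)
print(x_mask.shape)  # torch.Size([3, 1, 5, 5])

# Loop-free equivalent if the flags are a bool tensor.
flags = torch.tensor(is_causal).view(B, 1, 1, 1)
assert torch.equal(x_mask, torch.where(flags, causal_mask, noncausal_mask))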