actually generate a causal mask, since it seems one is sometimes not generated because of assumptions the stock implementation makes (its fast paths can return None instead of a mask)
This commit is contained in:
parent ccee5fc11c
commit c99a74e834
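Note on the motivation: the stock HuggingFace `_update_causal_mask` has fast paths (flash-attention-2, and SDPA via `AttentionMaskConverter._ignore_causal_mask_sdpa`) that can return `None` instead of a materialized mask. The adapted model later indexes the mask per sample to mix causal and non-causal attention, which breaks on `None`. A minimal repro sketch (hypothetical toy values, not code from this repo):

import torch

causal_mask = None                        # what the stock fast paths may return under SDPA / flash-attention
noncausal_mask = torch.zeros(2, 1, 4, 4)  # a materialized additive mask, shape (batch, 1, query, key)
is_causal = [True, False]

try:
	x_mask = torch.stack(
		[ causal_mask[i] if state else noncausal_mask[i] for i, state in enumerate( is_causal ) ],
		dim=0,
	)
except TypeError as e:
	# 'NoneType' object is not subscriptable -- hence the override below always builds a mask
	print("cannot select per-sample masks:", e)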
@@ -14,6 +14,7 @@ from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
 _logger = logging.getLogger(__name__)
 
@@ -630,6 +631,75 @@ class LlamaModel_Adapted(LlamaModel):
 
 		return combined_attention_mask
 
+	def _update_causal_mask(
+		self,
+		attention_mask: torch.Tensor,
+		input_tensor: torch.Tensor,
+		cache_position: torch.Tensor,
+		past_key_values: Cache,
+		output_attentions: bool,
+	):
+		"""
+		if self.config._attn_implementation == "flash_attention_2":
+			if attention_mask is not None and 0.0 in attention_mask:
+				return attention_mask
+			return None
+		"""
+
+		past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+		using_static_cache = isinstance(past_key_values, StaticCache)
+
+		"""
+		# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+		# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+		# to infer the attention mask.
+		# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+		if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+			if AttentionMaskConverter._ignore_causal_mask_sdpa(
+				attention_mask,
+				inputs_embeds=input_tensor,
+				past_key_values_length=past_seen_tokens,
+				is_training=self.training,
+			):
+				return None
+		"""
+
+		dtype, device = input_tensor.dtype, input_tensor.device
+		sequence_length = input_tensor.shape[1]
+		if using_static_cache:
+			target_length = past_key_values.get_max_cache_shape()
+		else:
+			target_length = (
+				attention_mask.shape[-1]
+				if isinstance(attention_mask, torch.Tensor)
+				else past_seen_tokens + sequence_length + 1
+			)
+
+		# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+		causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+			attention_mask,
+			sequence_length=sequence_length,
+			target_length=target_length,
+			dtype=dtype,
+			device=device,
+			cache_position=cache_position,
+			batch_size=input_tensor.shape[0],
+		)
+
+		if (
+			self.config._attn_implementation == "sdpa"
+			and attention_mask is not None
+			and attention_mask.device.type == "cuda"
+			and not output_attentions
+		):
+			# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+			# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+			# Details: https://github.com/pytorch/pytorch/issues/110213
+			min_dtype = torch.finfo(dtype).min
+			causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+		return causal_mask
+
 	def forward(
 		self,
 		input_ids: torch.LongTensor = None,
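For reference, a minimal standalone sketch of the kind of 4D additive mask the helper above produces (shapes and semantics only; this is a simplified assumption, not the transformers implementation): 0.0 where a query may attend, the dtype minimum where it may not, with `cache_position` accounting for tokens already in the KV cache.

import torch

def sketch_causal_mask(batch_size, sequence_length, target_length, cache_position, dtype, device):
	# (batch, 1, query_len, key_len) additive mask: 0.0 = attend, finfo(dtype).min = masked
	min_dtype = torch.finfo(dtype).min
	mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
	# a query at cache position p may attend to key positions <= p
	mask = mask * (torch.arange(target_length, device=device) > cache_position.reshape(-1, 1))
	return mask[None, None, :, :].expand(batch_size, 1, -1, -1)

# e.g. decoding 3 new tokens with 2 already in the KV cache:
print(sketch_causal_mask(1, 3, 5, torch.arange(2, 5), torch.float32, "cpu"))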
@@ -692,8 +762,20 @@ class LlamaModel_Adapted(LlamaModel):
 		# because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
 		if is_causal is not None:
+			"""
+			causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+				attention_mask,
+				sequence_length=inputs_embeds.shape[1],
+				target_length=attention_mask.shape[-1] if attention_mask is not None else inputs_embeds.shape[1],
+				dtype=inputs_embeds.dtype,
+				device=inputs_embeds.device,
+				cache_position=cache_position,
+				batch_size=inputs_embeds.shape[0],
+			)
+			"""
 			causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
 			noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
 
 			x_mask = torch.stack( [ causal_mask[i, :, :, :] if state else noncausal_mask[i, :, :, :] for i, state in enumerate( is_causal ) ], dim=0 )
 		else:
 			x_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
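Usage sketch of the per-sample selection this hunk feeds (assumed toy shapes; `torch.where` over a broadcast boolean is shown as a vectorized equivalent of the list comprehension, not as what the repo does):

import torch

batch, q, k = 2, 4, 4
causal_mask = torch.triu(torch.full((q, k), float("-inf")), diagonal=1).expand(batch, 1, q, k)
noncausal_mask = torch.zeros(batch, 1, q, k)
is_causal = torch.tensor([True, False])  # per-sample flag, as in forward() above

x_mask = torch.where(is_causal.view(-1, 1, 1, 1), causal_mask, noncausal_mask)
assert torch.equal(
	x_mask,
	torch.stack([ causal_mask[i] if state else noncausal_mask[i] for i, state in enumerate( is_causal ) ], dim=0),
)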