From c99a74e834b0ba62c60990d3a35f85ae0b9ca0a0 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 22 Nov 2024 18:30:24 -0600
Subject: [PATCH] actually generate a causal mask because it seems sometimes it does not actually generate one because it makes assumptions

---
 vall_e/models/arch/llama.py | 82 +++++++++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)

diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index cee3eed..59b7b83 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -14,6 +14,7 @@ from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
 _logger = logging.getLogger(__name__)
 
@@ -630,6 +631,75 @@ class LlamaModel_Adapted(LlamaModel):
 
 		return combined_attention_mask
 
+	def _update_causal_mask(
+		self,
+		attention_mask: torch.Tensor,
+		input_tensor: torch.Tensor,
+		cache_position: torch.Tensor,
+		past_key_values: Cache,
+		output_attentions: bool,
+	):
+		"""
+		if self.config._attn_implementation == "flash_attention_2":
+			if attention_mask is not None and 0.0 in attention_mask:
+				return attention_mask
+			return None
+		"""
+
+		past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+		using_static_cache = isinstance(past_key_values, StaticCache)
+
+		"""
+		# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+		# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+		# to infer the attention mask.
+		# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+		if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+			if AttentionMaskConverter._ignore_causal_mask_sdpa(
+				attention_mask,
+				inputs_embeds=input_tensor,
+				past_key_values_length=past_seen_tokens,
+				is_training=self.training,
+			):
+				return None
+		"""
+
+		dtype, device = input_tensor.dtype, input_tensor.device
+		sequence_length = input_tensor.shape[1]
+		if using_static_cache:
+			target_length = past_key_values.get_max_cache_shape()
+		else:
+			target_length = (
+				attention_mask.shape[-1]
+				if isinstance(attention_mask, torch.Tensor)
+				else past_seen_tokens + sequence_length + 1
+			)
+
+		# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+		causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+			attention_mask,
+			sequence_length=sequence_length,
+			target_length=target_length,
+			dtype=dtype,
+			device=device,
+			cache_position=cache_position,
+			batch_size=input_tensor.shape[0],
+		)
+
+		if (
+			self.config._attn_implementation == "sdpa"
+			and attention_mask is not None
+			and attention_mask.device.type == "cuda"
+			and not output_attentions
+		):
+			# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+			# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+			# Details: https://github.com/pytorch/pytorch/issues/110213
+			min_dtype = torch.finfo(dtype).min
+			causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+		return causal_mask
+
 	def forward(
 		self,
 		input_ids: torch.LongTensor = None,
@@ -692,8 +762,20 @@ class LlamaModel_Adapted(LlamaModel):
 
 		# because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
 		if is_causal is not None:
+			"""
+			causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+				attention_mask,
+				sequence_length=inputs_embeds.shape[1],
+				target_length=attention_mask.shape[-1] if attention_mask is not None else inputs_embeds.shape[1],
+				dtype=inputs_embeds.dtype,
+				device=inputs_embeds.device,
+				cache_position=cache_position,
+				batch_size=inputs_embeds.shape[0],
+			)
+			"""
 			causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
 			noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
+
 			x_mask = torch.stack( [ causal_mask[i, :, :, :] if state else noncausal_mask[i, :, :, :] for i, state in enumerate( is_causal ) ], dim=0 )
 		else:
 			x_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
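For reference, below is a minimal standalone sketch (not part of the patch, and not vall_e or transformers code) of the per-batch selection the forward() change performs: build a 4D causal mask and a 4D non-causal (padding-only) mask from a 2D attention mask, then pick one per sample according to an is_causal flag list. The helper name build_masks, the mask shapes, and the example inputs are assumptions for illustration only.

	# Sketch of per-sample causal / non-causal mask selection, assuming a
	# 2D attention_mask of shape [batch, seq_len] with 1 = attend, 0 = padding.
	import torch

	def build_masks(attention_mask: torch.Tensor, dtype=torch.float32):
		bsz, seq_len = attention_mask.shape
		min_dtype = torch.finfo(dtype).min

		# causal part: large negative value above the diagonal blocks future positions
		causal = torch.full((seq_len, seq_len), min_dtype, dtype=dtype)
		causal = torch.triu(causal, diagonal=1)

		# padding part: block attention to padded key positions
		padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_dtype

		causal_mask = causal[None, None, :, :] + padding            # [bsz, 1, seq, seq]
		noncausal_mask = padding.expand(bsz, 1, seq_len, seq_len).clone()
		return causal_mask, noncausal_mask

	if __name__ == "__main__":
		attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
		is_causal = [True, False]  # per-sample flags, as in the patch's is_causal
		causal_mask, noncausal_mask = build_masks(attention_mask)
		# same selection pattern as the patch: stack per-sample picks back into a batch
		x_mask = torch.stack(
			[causal_mask[i] if state else noncausal_mask[i] for i, state in enumerate(is_causal)],
			dim=0,
		)
		print(x_mask.shape)  # torch.Size([2, 1, 4, 4])

The actual patch instead routes both masks through the overridden _update_causal_mask (which delegates to _prepare_4d_causal_attention_mask_with_cache_position) and the existing _update_noncausal_mask, so cache position, static-cache target lengths, and the SDPA _unmask_unattended fix-up are respected; the sketch only reproduces the selection logic.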