huge oversight in the attention masking......... (i realized I have not been providing a non-causal mask to non-causal tasks)

mrq 2024-11-22 13:44:43 -06:00
parent 24d888c47c
commit 147219a5e0
3 changed files with 77 additions and 13 deletions
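For context: a causal (lower-triangular) mask only lets position i attend to positions <= i, which is what the AR-style tasks want, but the non-causal tasks should attend over the whole sequence. A minimal standalone sketch (not code from this repo) of the two additive masks SDPA-style attention expects:

import torch

def additive_masks(seq_len: int, dtype=torch.float32):
    # causal: the dtype's minimum above the diagonal, so position i only attends to positions <= i
    causal = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype).triu(diagonal=1)
    # non-causal: all zeros, so every position attends to every other position
    noncausal = torch.zeros(seq_len, seq_len, dtype=dtype)
    return causal, noncausal

In the model both end up as [bsz, 1, tgt_len, src_len] additive masks that also fold in padding; the oversight was that only the causal variant was ever being built.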

View File

@@ -265,6 +265,8 @@ class ModelExperimentalSettings:
masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick, "rand" will pick between [0.2, 0.8]
ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs, since that's what matters; inputs that have loss calculated on them affect the loss for the entire sequence
noncausal_masks: bool = False # to correct an oversight with Llama always using causal masks......
# classifier-free guidance training settings
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
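A usage sketch for the new flag, assuming the experimental settings dataclass is constructed directly (the surrounding config plumbing and module path are not shown in this diff, so treat both as assumptions):

from vall_e.config import ModelExperimentalSettings  # assumed module path; not shown in this diff

# hypothetical construction; noncausal_masks is the field added by this commit
experimental = ModelExperimentalSettings(
    masking_ratio=0.8,
    ignore_inputs_for_loss=True,
    noncausal_masks=True,  # opt in to non-causal masks for non-causal tasks
)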

View File

@@ -109,7 +109,6 @@ try:
except Exception as e:
_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
"""
try:
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask, LowerTriangularMask
@@ -117,7 +116,6 @@ try:
AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e:
_logger.warning(f"Error while importing `xformers`: {str(e)}")
"""
# to-do: find a better way to query for available kernels, since these return true regardless
if torch.backends.cuda.flash_sdp_enabled():
@@ -246,20 +244,21 @@ class LlamaAttention_Adapted(LlamaAttention):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
x_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
x_mask = x_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
if query_states.device.type == "cuda" and x_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
is_causal = True if x_mask is None and q_len > 1 else False
if mode in ["fused_attn"]:
attn_output = fused_attn_func(
@@ -273,7 +272,7 @@ class LlamaAttention_Adapted(LlamaAttention):
elif mode in ["default"]:
attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
# cringe logic
attn_weights = (attn_scores + causal_mask) if attention_mask is not None else (attn_scores)
attn_weights = (attn_scores + x_mask) if attention_mask is not None else (attn_scores)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
@@ -290,7 +289,7 @@ class LlamaAttention_Adapted(LlamaAttention):
query_states,
key_states,
value_states,
attn_mask=causal_mask,
attn_mask=x_mask,
dropout_p=dropout_rate,
is_causal=is_causal,
)
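The net effect of the rename above: the 4D mask handed to scaled_dot_product_attention is no longer assumed to be causal, it is simply whatever additive mask the model built upstream, and is_causal only kicks in when no explicit mask is supplied. A standalone sketch of that dispatch (function and argument names here are illustrative, not the repo's):

import torch
import torch.nn.functional as F

def sdpa_with_optional_mask(query, key, value, x_mask=None, dropout_p=0.0):
    # fall back to SDPA's built-in causal masking only when no explicit mask exists;
    # otherwise pass the (causal or non-causal) additive mask through untouched
    is_causal = x_mask is None and query.shape[-2] > 1
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=x_mask,    # [bsz, 1, tgt_len, src_len] additive mask, or None
        dropout_p=dropout_p,
        is_causal=is_causal,
    )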
@@ -458,10 +457,55 @@ class LlamaModel_Adapted(LlamaModel):
return self.early_exit_scale * sum([ i for i in range(0, l) ])
return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ])
# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
def _update_noncausal_mask(
self, attention_mask, inputs_embeds, past_key_values_length
):
# create noncausal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
input_shape = (inputs_embeds.shape[0], inputs_embeds.shape[1])
def _expand_mask(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
if attention_mask is None:
attention_mask = torch.ones( input_shape, dtype=torch.bool, device=inputs_embeds.device )
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
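What the borrowed helper produces: the [bsz, seq_len] padding mask expanded to an additive [bsz, 1, seq_len, seq_len] mask that is 0 wherever attention is allowed and the dtype's minimum over padded keys, with no triangular structure at all. A self-contained shape check (a sketch mirroring the helper's math, not repo code):

import torch

bsz, seq_len, dtype = 2, 5, torch.float32
padding = torch.ones(bsz, seq_len, dtype=torch.bool)
padding[1, 3:] = False  # pretend the second sample is padded past position 3

expanded = padding[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
inverted = 1.0 - expanded
noncausal = inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

assert noncausal.shape == (bsz, 1, seq_len, seq_len)
assert (noncausal[0] == 0).all()            # fully attended sample: all zeros, no triangle
assert (noncausal[1, 0, :, 3:] != 0).all()  # padded keys are masked out for every query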
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: Optional[list[bool]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -517,9 +561,14 @@ class LlamaModel_Adapted(LlamaModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# because a batch can mix causal and non-causal sequences, generate both masks, then pick which one to use per batch entry
if is_causal is not None:
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
x_mask = torch.stack( [ causal_mask[i, :, :, :] if state else noncausal_mask[i, :, :, :] for i, state in enumerate( is_causal ) ], dim=0 )
else:
x_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
@@ -538,7 +587,7 @@ class LlamaModel_Adapted(LlamaModel):
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
x_mask,
position_ids,
past_key_values,
output_attentions,
@@ -550,7 +599,7 @@ class LlamaModel_Adapted(LlamaModel):
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
attention_mask=x_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
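The stack-based selection in the forward pass above takes, per batch entry, the causal mask slice when that entry's task is autoregressive and the non-causal slice otherwise. An equivalent sketch using torch.where instead of a list comprehension, assuming both masks were already built with the shape [bsz, 1, tgt_len, src_len]:

import torch

def pick_masks(causal_mask, noncausal_mask, is_causal_flags):
    # is_causal_flags: one bool per batch entry (True -> keep the causal mask)
    flags = torch.as_tensor(is_causal_flags, dtype=torch.bool, device=causal_mask.device)
    flags = flags.view(-1, 1, 1, 1)  # broadcast over [bsz, 1, tgt_len, src_len]
    return torch.where(flags, causal_mask, noncausal_mask)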

View File

@@ -435,6 +435,7 @@ class Base(nn.Module):
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False
noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@@ -483,6 +484,11 @@ class Base(nn.Module):
self.inject_timestep_embedding = False # results in bad output
self.masking_ratio = masking_ratio
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
# use the internal attention mechanism for now, because I don't have a better way to handle mixed causal/noncausal masks for other attention backends
if noncausal_masks:
attention_backend = "default"
self.text_emb = Embedding(n_text_tokens, d_model)
self.langs_emb = None
@@ -773,6 +779,7 @@ class Base(nn.Module):
self,
inputs,
mask = None,
is_causal = None,
position_ids = None,
state = None,
@@ -800,6 +807,7 @@ class Base(nn.Module):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
is_causal=is_causal,
)
if self.n_experts > 1 and self.training:
@@ -1514,11 +1522,16 @@ class Base(nn.Module):
# needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" )
causal_levels = [ "AR:0:0", "stt", "len" ]
# right now, limit this to newly-trained models, since the existing model needs retraining for noncausal masks...
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else None
output = self._forward(
inputs=x,
mask=mask,
state=state,
is_causal=is_causal,
position_ids=position_ids,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
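To make the routing at the end concrete: each batch entry carries a classifier level naming its task, and only the AR-style levels ("AR:0:0", "stt", "len") keep a causal mask; everything else gets full attention. A small self-contained sketch of that per-sample flagging (the level strings come from the diff, everything else is illustrative):

CAUSAL_LEVELS = {"AR:0:0", "stt", "len"}

def causal_flags(classifier_levels, noncausal_masks_enabled=True):
    # one flag per batch entry: True -> causal attention, False -> full attention
    if not noncausal_masks_enabled:
        return None  # old behavior: causal masks everywhere
    return [level in CAUSAL_LEVELS for level in classifier_levels]

# e.g. a causal AR step alongside some non-causal task (second level string is hypothetical)
print(causal_flags(["AR:0:0", "NAR:0:0"]))  # [True, False]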