huge oversight in the attention masking......... (i realized I have not been providing a non-causal mask to non-causal tasks)
This commit is contained in:
parent
24d888c47c
commit
147219a5e0
|
@ -265,6 +265,8 @@ class ModelExperimentalSettings:
|
|||
masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick, "rand" will pick between [0.2, 0.8]
|
||||
ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
|
||||
|
||||
noncausal_masks: bool = False # to correct an oversight with Llama always using causal masks......
|
||||
|
||||
# classifier-free guidance training settings
|
||||
cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
|
||||
cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
|
||||
|
|
|
@ -109,7 +109,6 @@ try:
|
|||
except Exception as e:
|
||||
_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
|
||||
|
||||
"""
|
||||
try:
|
||||
from xformers.ops.fmha import memory_efficient_attention
|
||||
from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask, LowerTriangularMask
|
||||
|
@ -117,7 +116,6 @@ try:
|
|||
AVAILABLE_ATTENTIONS.append("xformers")
|
||||
except Exception as e:
|
||||
_logger.warning(f"Error while importing `xformers`: {str(e)}")
|
||||
"""
|
||||
|
||||
# to-do: find a better way to query for if there's available kernels since these return true regardless
|
||||
if torch.backends.cuda.flash_sdp_enabled():
|
||||
|
@ -246,20 +244,21 @@ class LlamaAttention_Adapted(LlamaAttention):
|
|||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
x_mask = attention_mask
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
x_mask = x_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
if query_states.device.type == "cuda" and x_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
is_causal = True if x_mask is None and q_len > 1 else False
|
||||
|
||||
if mode in ["fused_attn"]:
|
||||
attn_output = fused_attn_func(
|
||||
|
@ -273,7 +272,7 @@ class LlamaAttention_Adapted(LlamaAttention):
|
|||
elif mode in ["default"]:
|
||||
attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
# cringe logic
|
||||
attn_weights = (attn_scores + causal_mask) if attention_mask is not None else (attn_scores)
|
||||
attn_weights = (attn_scores + x_mask) if attention_mask is not None else (attn_scores)
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
|
@ -290,7 +289,7 @@ class LlamaAttention_Adapted(LlamaAttention):
|
|||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
attn_mask=x_mask,
|
||||
dropout_p=dropout_rate,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
@ -458,10 +457,55 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
return self.early_exit_scale * sum([ i for i in range(0, l) ])
|
||||
return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ])
|
||||
|
||||
# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
|
||||
def _update_noncausal_mask(
|
||||
self, attention_mask, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# create noncausal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
|
||||
input_shape = (inputs_embeds.shape[0], inputs_embeds.shape[1])
|
||||
|
||||
def _expand_mask(
|
||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = (
|
||||
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones( input_shape, dtype=torch.bool, device=inputs_embeds.device )
|
||||
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
).to(inputs_embeds.device)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
@ -517,9 +561,14 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
|
||||
if is_causal is not None:
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
|
||||
noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
|
||||
x_mask = torch.stack( [ causal_mask[i, :, :, :] if state else noncausal_mask[i, :, :, :] for i, state in enumerate( is_causal ) ], dim=0 )
|
||||
else:
|
||||
x_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
|
@ -538,7 +587,7 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
x_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
|
@ -550,7 +599,7 @@ class LlamaModel_Adapted(LlamaModel):
|
|||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=x_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
|
|
|
@ -435,6 +435,7 @@ class Base(nn.Module):
|
|||
audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
|
||||
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
||||
interleave = self.config.experimental.interleave if self.config is not None else False
|
||||
noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
|
||||
|
||||
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
|
||||
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
|
||||
|
@ -483,6 +484,11 @@ class Base(nn.Module):
|
|||
self.inject_timestep_embedding = False # results in bad output
|
||||
self.masking_ratio = masking_ratio
|
||||
self.ignore_inputs_for_loss = ignore_inputs_for_loss
|
||||
self.noncausal_masks = noncausal_masks
|
||||
|
||||
# use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
|
||||
if noncausal_masks:
|
||||
attention_backend = "default"
|
||||
|
||||
self.text_emb = Embedding(n_text_tokens, d_model)
|
||||
self.langs_emb = None
|
||||
|
@ -773,6 +779,7 @@ class Base(nn.Module):
|
|||
self,
|
||||
inputs,
|
||||
mask = None,
|
||||
is_causal = None,
|
||||
position_ids = None,
|
||||
|
||||
state = None,
|
||||
|
@ -800,6 +807,7 @@ class Base(nn.Module):
|
|||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
if self.n_experts > 1 and self.training:
|
||||
|
@ -1514,11 +1522,16 @@ class Base(nn.Module):
|
|||
# needs to be done here as we still have our raw inputs
|
||||
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
|
||||
classifier_levels = self.get_input( inputs, name="classifier_level" )
|
||||
casual_levels = [ "AR:0:0", "stt", "len" ]
|
||||
|
||||
# right now limit to new versions because I need to retrain the model for noncausal masks...
|
||||
is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else None
|
||||
|
||||
output = self._forward(
|
||||
inputs=x,
|
||||
mask=mask,
|
||||
state=state,
|
||||
is_causal=is_causal,
|
||||
position_ids=position_ids,
|
||||
output_attentions = output_attentions,
|
||||
output_hidden_states = output_hidden_states,
|
||||
|
|
Loading…
Reference in New Issue
Block a user