huge oversight in the attention masking (I realized I have not been providing a non-causal mask to non-causal tasks)
commit 147219a5e0 (parent 24d888c47c)
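For context (editorial note, not part of the diff): the Llama backbone always received a causal mask, so non-causal tasks were silently restricted to left-to-right attention. A minimal sketch of the two additive mask styles the change distinguishes, using plain PyTorch (illustration only, not code from this commit):

    import torch

    seq_len, dtype = 4, torch.float32
    min_val = torch.finfo(dtype).min

    # causal: position t may only attend to positions <= t
    causal = torch.full((seq_len, seq_len), min_val, dtype=dtype).triu(1)

    # non-causal: every non-padded position may attend to every other position;
    # only padding (here: the last token) is masked out
    padding = torch.tensor([[1., 1., 1., 0.]])                 # [bsz, seq_len]
    noncausal = (1.0 - padding)[:, None, None, :] * min_val    # [bsz, 1, 1, seq_len]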
@@ -265,6 +265,8 @@ class ModelExperimentalSettings:
     masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick, "rand" will pick between [0.2, 0.8]
     ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
 
+    noncausal_masks: bool = False # to correct an oversight with Llama always using causal masks......
+
     # classifier-free guidance training settings
     cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
     cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
@@ -109,7 +109,6 @@ try:
 except Exception as e:
     _logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
 
-"""
 try:
     from xformers.ops.fmha import memory_efficient_attention
     from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask, LowerTriangularMask
@@ -117,7 +116,6 @@ try:
     AVAILABLE_ATTENTIONS.append("xformers")
 except Exception as e:
     _logger.warning(f"Error while importing `xformers`: {str(e)}")
-"""
 
 # to-do: find a better way to query for if there's available kernels since these return true regardless
 if torch.backends.cuda.flash_sdp_enabled():
@@ -246,20 +244,21 @@ class LlamaAttention_Adapted(LlamaAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        causal_mask = attention_mask
+        x_mask = attention_mask
+
         if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+            x_mask = x_mask[:, :, :, : key_states.shape[-2]]
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and x_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
 
         # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
+        is_causal = True if x_mask is None and q_len > 1 else False
 
         if mode in ["fused_attn"]:
             attn_output = fused_attn_func(
@@ -273,7 +272,7 @@ class LlamaAttention_Adapted(LlamaAttention):
         elif mode in ["default"]:
             attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
             # cringe logic
-            attn_weights = (attn_scores + causal_mask) if attention_mask is not None else (attn_scores)
+            attn_weights = (attn_scores + x_mask) if attention_mask is not None else (attn_scores)
             # upcast attention to fp32
             attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
             attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
@@ -290,7 +289,7 @@ class LlamaAttention_Adapted(LlamaAttention):
                 query_states,
                 key_states,
                 value_states,
-                attn_mask=causal_mask,
+                attn_mask=x_mask,
                 dropout_p=dropout_rate,
                 is_causal=is_causal,
             )
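A quick sketch (editorial, with made-up tensor shapes) of why the rename matters for the SDPA path: `scaled_dot_product_attention` should receive either an explicit `attn_mask` or `is_causal=True`, not both, which is why the adapted attention only flips `is_causal` on when `x_mask` is None:

    import torch
    import torch.nn.functional as F

    bsz, heads, q_len, head_dim = 1, 2, 4, 8
    q = torch.randn(bsz, heads, q_len, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # no explicit mask: let the kernel enforce causality itself
    out_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

    # explicit additive mask: may be causal or non-causal per sample, is_causal stays False
    x_mask = torch.zeros(bsz, 1, q_len, q_len)   # all zeros = fully bidirectional
    out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=x_mask, is_causal=False)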
@@ -458,10 +457,55 @@ class LlamaModel_Adapted(LlamaModel):
             return self.early_exit_scale * sum([ i for i in range(0, l) ])
         return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ])
 
+    # shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
+    def _update_noncausal_mask(
+        self, attention_mask, inputs_embeds, past_key_values_length
+    ):
+        # create noncausal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+
+        input_shape = (inputs_embeds.shape[0], inputs_embeds.shape[1])
+
+        def _expand_mask(
+            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+        ):
+            """
+            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+            """
+            bsz, src_len = mask.size()
+            tgt_len = tgt_len if tgt_len is not None else src_len
+
+            expanded_mask = (
+                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+            )
+
+            inverted_mask = 1.0 - expanded_mask
+
+            return inverted_mask.masked_fill(
+                inverted_mask.to(torch.bool), torch.finfo(dtype).min
+            )
+
+        if attention_mask is None:
+            attention_mask = torch.ones( input_shape, dtype=torch.bool, device=inputs_embeds.device )
+
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        expanded_attn_mask = _expand_mask(
+            attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+        ).to(inputs_embeds.device)
+        combined_attention_mask = (
+            expanded_attn_mask
+            if combined_attention_mask is None
+            else expanded_attn_mask + combined_attention_mask
+        )
+
+        return combined_attention_mask
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
+        is_causal: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
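To make the borrowed helper above concrete (a standalone rendition for illustration, not additional repo code): a `[bsz, seq_len]` padding mask is expanded to a `[bsz, 1, tgt_len, src_len]` additive mask that is 0 wherever attention is allowed and dtype-min at padded positions, with no causal triangle anywhere:

    import torch

    def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None):
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        inverted = 1.0 - expanded
        return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

    pad_mask = torch.tensor([[1, 1, 1, 0, 0]])   # one sample, last two positions are padding
    full = expand_mask(pad_mask, torch.float32)
    # full.shape == (1, 1, 5, 5); every row is [0, 0, 0, dtype-min, dtype-min]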
@@ -517,9 +561,14 @@ class LlamaModel_Adapted(LlamaModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
+        # because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
+        if is_causal is not None:
+            causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
+            noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
+            x_mask = torch.stack( [ causal_mask[i, :, :, :] if state else noncausal_mask[i, :, :, :] for i, state in enumerate( is_causal ) ], dim=0 )
+        else:
+            x_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
 
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
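The per-sample selection above is the core of the fix: both a causal and a non-causal 4D mask are built for the whole batch, then each batch entry keeps whichever matches its task. A toy version (editorial, with dummy shapes):

    import torch

    bsz, seq_len = 3, 4
    min_val = torch.finfo(torch.float32).min

    causal_mask = torch.full((seq_len, seq_len), min_val).triu(1).expand(bsz, 1, seq_len, seq_len)
    noncausal_mask = torch.zeros(bsz, 1, seq_len, seq_len)   # no padding in this toy batch

    is_causal = [True, False, True]   # e.g. an AR sample, a NAR sample, another AR sample
    x_mask = torch.stack(
        [causal_mask[i] if state else noncausal_mask[i] for i, state in enumerate(is_causal)],
        dim=0,
    )
    # x_mask.shape == (3, 1, 4, 4)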
@@ -538,7 +587,7 @@ class LlamaModel_Adapted(LlamaModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
-                    causal_mask,
+                    x_mask,
                     position_ids,
                     past_key_values,
                     output_attentions,
@@ -550,7 +599,7 @@ class LlamaModel_Adapted(LlamaModel):
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=causal_mask,
+                    attention_mask=x_mask,
                     position_ids=position_ids,
                     past_key_value=past_key_values,
                     output_attentions=output_attentions,
@@ -435,6 +435,7 @@ class Base(nn.Module):
         audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
         unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
         interleave = self.config.experimental.interleave if self.config is not None else False
+        noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
 
         masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
         ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@@ -483,6 +484,11 @@ class Base(nn.Module):
         self.inject_timestep_embedding = False # results in bad output
         self.masking_ratio = masking_ratio
         self.ignore_inputs_for_loss = ignore_inputs_for_loss
+        self.noncausal_masks = noncausal_masks
+
+        # use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
+        if noncausal_masks:
+            attention_backend = "default"
 
         self.text_emb = Embedding(n_text_tokens, d_model)
         self.langs_emb = None
@@ -773,6 +779,7 @@ class Base(nn.Module):
         self,
         inputs,
         mask = None,
+        is_causal = None,
         position_ids = None,
 
         state = None,
@@ -800,6 +807,7 @@ class Base(nn.Module):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=True,
+            is_causal=is_causal,
         )
 
         if self.n_experts > 1 and self.training:
@@ -1514,11 +1522,16 @@ class Base(nn.Module):
         # needs to be done here as we still have our raw inputs
         position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
         classifier_levels = self.get_input( inputs, name="classifier_level" )
+        casual_levels = [ "AR:0:0", "stt", "len" ]
+
+        # right now limit to new versions because I need to retrain the model for noncausal masks...
+        is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else None
 
         output = self._forward(
             inputs=x,
             mask=mask,
             state=state,
+            is_causal=is_causal,
             position_ids=position_ids,
             output_attentions = output_attentions,
             output_hidden_states = output_hidden_states,
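Finally, how the per-sample `is_causal` list gets derived (a toy rendition; level names other than the three in the diff are invented for the example): only the autoregressive ("AR:0:0"), speech-to-text ("stt"), and length ("len") levels keep a causal mask, everything else is treated as non-causal, and the whole path is gated behind `noncausal_masks` so existing checkpoints keep their old behavior until retrained:

    causal_levels = [ "AR:0:0", "stt", "len" ]

    classifier_levels = [ "AR:0:0", "NAR:0:1", "len", "NAR:0:2" ]   # hypothetical batch
    is_causal = [ level in causal_levels for level in classifier_levels ]
    # is_causal == [True, False, True, False]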