From 147219a5e0be2ac58d667ba5ba8aa5fe03a847ab Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 22 Nov 2024 13:44:43 -0600
Subject: [PATCH] huge oversight in the attention masking......... (i realized I have not been providing a non-causal mask to non-causal tasks)

---
 vall_e/config.py            |  2 +
 vall_e/models/arch/llama.py | 75 ++++++++++++++++++++++++++++++-------
 vall_e/models/base.py       | 13 +++++++
 3 files changed, 77 insertions(+), 13 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 8b0da14..e491b42 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -265,6 +265,8 @@ class ModelExperimentalSettings:
 	masking_ratio: str | float = 0.8 # sets a masking ratio, "random" will randomly pick, "rand" will pick between [0.2, 0.8]
 	ignore_inputs_for_loss: bool = True # only calculate the loss on the outputs since thats what matters, as the inputs that do have loss calculated upon affects the loss for the entire sequence
 
+	noncausal_masks: bool = False # to correct an oversight with Llama always using causal masks......
+
 	# classifier-free guidance training settings
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
 	cfg_text_dropout_p: float = 0.0 # 0.0 # probability to drop out input audio prompt during training
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 607bbfb..69aa8d6 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -109,7 +109,6 @@ try:
 except Exception as e:
 	_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
 
-"""
 try:
 	from xformers.ops.fmha import memory_efficient_attention
 	from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask, LowerTriangularMask
@@ -117,7 +116,6 @@ try:
 	AVAILABLE_ATTENTIONS.append("xformers")
 except Exception as e:
 	_logger.warning(f"Error while importing `xformers`: {str(e)}")
-"""
 
 # to-do: find a better way to query for if there's available kernels since these return true regardless
 if torch.backends.cuda.flash_sdp_enabled():
@@ -246,20 +244,21 @@ class LlamaAttention_Adapted(LlamaAttention):
 		key_states = repeat_kv(key_states, self.num_key_value_groups)
 		value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-		causal_mask = attention_mask
+		x_mask = attention_mask
+
 		if attention_mask is not None:
-			causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+			x_mask = x_mask[:, :, :, : key_states.shape[-2]]
 
 		# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
 		# Reference: https://github.com/pytorch/pytorch/issues/112577.
-		if query_states.device.type == "cuda" and causal_mask is not None:
+		if query_states.device.type == "cuda" and x_mask is not None:
 			query_states = query_states.contiguous()
 			key_states = key_states.contiguous()
 			value_states = value_states.contiguous()
 
 		# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
 		# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-		is_causal = True if causal_mask is None and q_len > 1 else False
+		is_causal = True if x_mask is None and q_len > 1 else False
 
 		if mode in ["fused_attn"]:
 			attn_output = fused_attn_func(
@@ -273,7 +272,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 		elif mode in ["default"]:
 			attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 			# cringe logic
-			attn_weights = (attn_scores + causal_mask) if attention_mask is not None else (attn_scores)
+			attn_weights = (attn_scores + x_mask) if attention_mask is not None else (attn_scores)
 			# upcast attention to fp32
 			attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 			attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
@@ -290,7 +289,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 				query_states,
 				key_states,
 				value_states,
-				attn_mask=causal_mask,
+				attn_mask=x_mask,
 				dropout_p=dropout_rate,
 				is_causal=is_causal,
 			)
@@ -458,10 +457,55 @@ class LlamaModel_Adapted(LlamaModel):
 			return self.early_exit_scale * sum([ i for i in range(0, l) ])
 		return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ])
 
+	# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
+	def _update_noncausal_mask(
+		self, attention_mask, inputs_embeds, past_key_values_length
+	):
+		# create noncausal mask
+		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+		combined_attention_mask = None
+
+		input_shape = (inputs_embeds.shape[0], inputs_embeds.shape[1])
+
+		def _expand_mask(
+			mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+		):
+			"""
+			Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+			"""
+			bsz, src_len = mask.size()
+			tgt_len = tgt_len if tgt_len is not None else src_len
+
+			expanded_mask = (
+				mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+			)
+
+			inverted_mask = 1.0 - expanded_mask
+
+			return inverted_mask.masked_fill(
+				inverted_mask.to(torch.bool), torch.finfo(dtype).min
+			)
+
+		if attention_mask is None:
+			attention_mask = torch.ones( input_shape, dtype=torch.bool, device=inputs_embeds.device )
+
+		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+		expanded_attn_mask = _expand_mask(
+			attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+		).to(inputs_embeds.device)
+		combined_attention_mask = (
+			expanded_attn_mask
+			if combined_attention_mask is None
+			else expanded_attn_mask + combined_attention_mask
+		)
+
+		return combined_attention_mask
+
 	def forward(
 		self,
 		input_ids: torch.LongTensor = None,
 		attention_mask: Optional[torch.Tensor] = None,
+		is_causal: Optional[torch.Tensor] = None,
 		position_ids: Optional[torch.LongTensor] = None,
 		past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
 		inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -517,9 +561,14 @@ class LlamaModel_Adapted(LlamaModel):
 		if position_ids is None:
 			position_ids = cache_position.unsqueeze(0)
 
-		causal_mask = self._update_causal_mask(
-			attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-		)
+		# because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
+		if is_causal is not None:
+			causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
+			noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
+			x_mask = torch.stack( [ causal_mask[i, :, :, :] if state else noncausal_mask[i, :, :, :] for i, state in enumerate( is_causal ) ], dim=0 )
+		else:
+			x_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
+
 		hidden_states = inputs_embeds
 
 		# create position embeddings to be shared across the decoder layers
@@ -538,7 +587,7 @@ class LlamaModel_Adapted(LlamaModel):
 				layer_outputs = self._gradient_checkpointing_func(
 					decoder_layer.__call__,
 					hidden_states,
-					causal_mask,
+					x_mask,
 					position_ids,
 					past_key_values,
 					output_attentions,
@@ -550,7 +599,7 @@ class LlamaModel_Adapted(LlamaModel):
 			else:
 				layer_outputs = decoder_layer(
 					hidden_states,
-					attention_mask=causal_mask,
+					attention_mask=x_mask,
 					position_ids=position_ids,
 					past_key_value=past_key_values,
 					output_attentions=output_attentions,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index dd1ca93..b43fa07 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -435,6 +435,7 @@ class Base(nn.Module):
 		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 		interleave = self.config.experimental.interleave if self.config is not None else False
+		noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
 
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@@ -483,6 +484,11 @@ class Base(nn.Module):
 		self.inject_timestep_embedding = False # results in bad output
 		self.masking_ratio = masking_ratio
 		self.ignore_inputs_for_loss = ignore_inputs_for_loss
+		self.noncausal_masks = noncausal_masks
+
+		# use internal attention mechanism for now because I don't have a better way to handle mixed causal/noncausal masks for other attention backends
+		if noncausal_masks:
+			attention_backend = "default"
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -773,6 +779,7 @@ class Base(nn.Module):
 		self,
 		inputs,
 		mask = None,
+		is_causal = None,
 		position_ids = None,
 		state = None,
 
@@ -800,6 +807,7 @@ class Base(nn.Module):
 			output_attentions=output_attentions,
 			output_hidden_states=output_hidden_states,
 			return_dict=True,
+			is_causal=is_causal,
 		)
 
 		if self.n_experts > 1 and self.training:
@@ -1514,11 +1522,16 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 		classifier_levels = self.get_input( inputs, name="classifier_level" )
+		causal_levels = [ "AR:0:0", "stt", "len" ]
+
+		# right now limit to new versions because I need to retrain the model for noncausal masks...
+		is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else None
 
 		output = self._forward(
 			inputs=x,
 			mask=mask,
 			state=state,
+			is_causal=is_causal,
 			position_ids=position_ids,
 			output_attentions = output_attentions,
 			output_hidden_states = output_hidden_states,
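
For intuition, the net effect of the patch can be sketched outside the model: each batch item gets either the usual causal (lower-triangular) additive bias or a padding-only bidirectional bias, and the per-item results are stacked along the batch dimension much like the torch.stack( [...] ) selection in LlamaModel_Adapted.forward. This is a minimal illustrative sketch only; the helper build_mixed_mask and the tensor names below are assumptions for the example, not part of the patch.

# Illustrative sketch (not from the patch): build a [bsz, 1, seq_len, seq_len] additive
# attention bias where each batch item is either causal or fully bidirectional.
import torch

def build_mixed_mask( padding_mask: torch.Tensor, is_causal: list, dtype=torch.float32 ) -> torch.Tensor:
	bsz, seq_len = padding_mask.shape
	min_val = torch.finfo(dtype).min

	# non-causal: only padded positions are masked out (additive large-negative), as in _update_noncausal_mask
	noncausal = (1.0 - padding_mask[:, None, None, :].to(dtype)) * min_val
	noncausal = noncausal.expand(bsz, 1, seq_len, seq_len)

	# causal: additionally mask out future positions with a lower-triangular constraint
	future = torch.triu(torch.full((seq_len, seq_len), min_val, dtype=dtype), diagonal=1)
	causal = (noncausal + future[None, None, :, :]).clamp(min=min_val)

	# pick per batch item, mirroring the per-batch causal/noncausal selection above
	return torch.stack([ causal[i] if state else noncausal[i] for i, state in enumerate(is_causal) ], dim=0)

# example: two samples, the second has one padded position; only the first (an "AR:0:0"-style task) is causal
padding_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
print(build_mixed_mask(padding_mask, is_causal=[True, False]).shape) # torch.Size([2, 1, 4, 4])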