diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 69aa8d6..c701dd8 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -149,11 +149,71 @@ class LlamaAttention_Adapted(LlamaAttention):
 		super().__init__(*args, **kwargs)
 
+	# extracts the batch entries matching the requested causality and runs attention on just that sub-batch
+	def split_forward(
+		self,
+		hidden_states: torch.Tensor,
+		attention_mask: Optional[torch.Tensor] = None,
+		is_causal: Optional[list] = None,
+		target_causal_state: Optional[bool] = True,
+		position_ids: Optional[torch.LongTensor] = None,
+		past_key_value: Optional[Cache] = None,
+		output_attentions: bool = False,
+		use_cache: bool = False,
+		cache_position: Optional[torch.LongTensor] = None,
+		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+		**kwargs,
+	):
+		indices = [ i for i, state in enumerate( is_causal ) if state == target_causal_state ]
+
+		# no matching inputs in batch
+		if not indices:
+			return indices, None, None, None
+
+		# entire batch is homogeneous, so forward it unchanged
+		if len( indices ) == hidden_states.shape[0]:
+			output_hidden_states, output_self_attn_weights, output_present_key_values = self.forward(
+				hidden_states=hidden_states,
+				attention_mask=attention_mask,
+				is_causal=target_causal_state,
+				position_ids=position_ids,
+				past_key_value=past_key_value,
+				output_attentions=output_attentions,
+				use_cache=False,
+				cache_position=cache_position,
+				position_embeddings=position_embeddings,
+				**kwargs,
+			)
+			return indices, output_hidden_states, output_self_attn_weights, output_present_key_values
+
+		input_hidden_states = torch.stack( [ hidden_states[i] for i in indices ] )
+		input_attention_mask = torch.stack( [ attention_mask[i] for i in indices ] ) if attention_mask is not None else None
+		input_position_ids = torch.stack( [ position_ids[i] for i in indices ] ) if position_ids is not None else None
+		input_position_embeddings = (
+			torch.stack( [ position_embeddings[0][i] for i in indices ] ),
+			torch.stack( [ position_embeddings[1][i] for i in indices ] ),
+		) if position_embeddings is not None else None
+
+		output_hidden_states, output_self_attn_weights, output_present_key_values = self.forward(
+			hidden_states=input_hidden_states,
+			attention_mask=input_attention_mask,
+			is_causal=target_causal_state,
+			position_ids=input_position_ids,
+			past_key_value=past_key_value,
+			output_attentions=output_attentions,
+			use_cache=False,
+			cache_position=cache_position,
+			position_embeddings=input_position_embeddings,
+			**kwargs,
+		)
+		return indices, output_hidden_states, output_self_attn_weights, output_present_key_values
+
 	# Adapted from LlamaAttention.forward
 	def forward(
 		self,
 		hidden_states: torch.Tensor,
 		attention_mask: Optional[torch.Tensor] = None,
+		is_causal: bool = True,
 		position_ids: Optional[torch.LongTensor] = None,
 		past_key_value: Optional[Cache] = None,
 		output_attentions: bool = False,
@@ -163,6 +223,94 @@ class LlamaAttention_Adapted(LlamaAttention):
 		**kwargs,
 	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 		mode = "default" if output_attentions else self.mode
+
+		# split the batch by causality, since the other attention backends only accept a single is_causal flag for the whole input rather than one per sequence
+		if isinstance( is_causal, list ) and mode not in ["default"]:
+			# initialize per-entry output lists
+			attn_hidden_states = [ None for _ in is_causal ]
+			self_attn_weights = [ None for _ in is_causal ]
+			present_key_values = [ None for _ in is_causal ]
+
+			# process causal inputs in a batch
+			causal_indices, causal_hidden_states, causal_self_attn_weights, causal_present_key_values = self.split_forward(
+				hidden_states=hidden_states,
+				attention_mask=attention_mask,
+				is_causal=is_causal,
+				target_causal_state=True,
+				position_ids=position_ids,
+				past_key_value=past_key_value,
+				output_attentions=output_attentions,
+				use_cache=False,
+				cache_position=cache_position,
+				position_embeddings=position_embeddings,
+				**kwargs,
+			)
+
+			# process non-causal inputs in a batch
+			non_causal_indices, non_causal_hidden_states, non_causal_self_attn_weights, non_causal_present_key_values = self.split_forward(
+				hidden_states=hidden_states,
+				attention_mask=attention_mask,
+				is_causal=is_causal,
+				target_causal_state=False,
+				position_ids=position_ids,
+				past_key_value=past_key_value,
+				output_attentions=output_attentions,
+				use_cache=False,
+				cache_position=cache_position,
+				position_embeddings=position_embeddings,
+				**kwargs,
+			)
+
+			# insert causal outputs back into the batch
+			for i, idx in enumerate( causal_indices ):
+				attn_hidden_states[idx] = causal_hidden_states[i]
+
+				if output_attentions:
+					self_attn_weights[idx] = causal_self_attn_weights[i]
+
+			# insert non-causal outputs back into the batch
+			for i, idx in enumerate( non_causal_indices ):
+				attn_hidden_states[idx] = non_causal_hidden_states[i]
+
+				if output_attentions:
+					self_attn_weights[idx] = non_causal_self_attn_weights[i]
+
+			# recombine the lists into batch tensors
+			attn_hidden_states = torch.stack( attn_hidden_states, dim=0 )
+			if output_attentions:
+				self_attn_weights = torch.stack( self_attn_weights, dim=0 )
+
+			return attn_hidden_states, self_attn_weights if output_attentions else None, []
+
+			"""
+			h_s = []
+			s_a_w = []
+			p_k_v = []
+
+			for i, state in enumerate(is_causal):
+				hidden_state, self_attn_weight, present_key_value = self.forward(
+					hidden_states=hidden_states[i].unsqueeze(0),
+					attention_mask=attention_mask[i].unsqueeze(0),
+					is_causal=state,
+					position_ids=position_ids[i].unsqueeze(0),
+					past_key_value=past_key_value,
+					output_attentions=output_attentions,
+					use_cache=False,
+					cache_position=cache_position,
+					position_embeddings=(position_embeddings[0][i].unsqueeze(0), position_embeddings[1][i].unsqueeze(0)) if position_embeddings is not None else None,
+					**kwargs,
+				)
+				h_s.append(hidden_state)
+				s_a_w.append(self_attn_weight)
+				p_k_v.append(present_key_value)
+
+			return (
+				torch.concat( h_s, dim=0 ),
+				torch.concat( s_a_w, dim=0 ) if s_a_w else None,
+				p_k_v,
+			)
+			"""
+
 		dropout_rate = self.attention_dropout if self.training else 0.0
 		bsz, q_len, _ = hidden_states.size()
 
@@ -221,7 +369,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 				query_states,
 				key_states,
 				value_states,
-				causal=True,
+				causal=is_causal,
 				softmax_scale=1.0 / math.sqrt(self.head_dim),
 				dropout_p=dropout_rate,
 			)
@@ -232,7 +380,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 				query_states,
 				key_states,
 				value_states,
-				attn_bias = LowerTriangularMask() if attention_mask is None or attention_mask[0, 0, 0, 1] == 0 else None,
+				attn_bias = LowerTriangularMask(),
 				scale = 1.0 / math.sqrt(self.head_dim),
 				p=dropout_rate
 			)
@@ -258,14 +406,14 @@ class LlamaAttention_Adapted(LlamaAttention):
 		# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
 		# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-		is_causal = True if x_mask is None and q_len > 1 else False
+		# is_causal = True if x_mask is None and q_len > 1 else False
 
 		if mode in ["fused_attn"]:
 			attn_output = fused_attn_func(
 				query_states,
 				key_states,
 				value_states,
-				causal=True,
+				causal=is_causal,
 				softmax_scale=1.0 / math.sqrt(self.head_dim),
 				dropout_p=dropout_rate,
 			)
@@ -284,6 +432,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 					f" {attn_output.size()}"
 				)
 		else:
+			is_causal = True if x_mask is None and q_len > 1 else False
 			with torch.nn.attention.sdpa_kernel(self.mode):
 				attn_output = torch.nn.functional.scaled_dot_product_attention(
 					query_states,
@@ -332,6 +481,7 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
 		self,
 		hidden_states: torch.Tensor,
 		attention_mask: Optional[torch.Tensor] = None,
+		is_causal: Optional[torch.Tensor] = None,
 		position_ids: Optional[torch.LongTensor] = None,
 		past_key_value: Optional[Cache] = None,
 		output_attentions: Optional[bool] = False,
@@ -371,6 +521,7 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer):
 		hidden_states, self_attn_weights, present_key_value = self.self_attn(
 			hidden_states=hidden_states,
 			attention_mask=attention_mask,
+			is_causal=is_causal,
 			position_ids=position_ids,
 			past_key_value=past_key_value,
 			output_attentions=output_attentions,
@@ -588,6 +739,7 @@ class LlamaModel_Adapted(LlamaModel):
 				decoder_layer.__call__,
 				hidden_states,
 				x_mask,
+				is_causal,
 				position_ids,
 				past_key_values,
 				output_attentions,
@@ -600,6 +752,7 @@ class LlamaModel_Adapted(LlamaModel):
 			layer_outputs = decoder_layer(
 				hidden_states,
 				attention_mask=x_mask,
+				is_causal=is_causal,
 				position_ids=position_ids,
 				past_key_value=past_key_values,
 				output_attentions=output_attentions,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b43fa07..5ece310 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -487,8 +487,10 @@ class Base(nn.Module):
 		self.noncausal_masks = noncausal_masks
 
 		# use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
+		"""
 		if noncausal_masks:
 			attention_backend = "default"
+		"""
 
 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
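The patch threads a per-sequence `is_causal` list from `LlamaModel_Adapted` down into `LlamaAttention_Adapted`, whose new `split_forward` partitions the batch by causal state, runs one batched attention call per state, and scatters the outputs back into their original positions. The sketch below illustrates that split-and-scatter pattern in isolation with plain PyTorch SDPA; it is not code from this patch, and the helper name, tensor shapes, and example values are assumptions made for illustration.

import torch
import torch.nn.functional as F

def attention_split_by_causality(q, k, v, is_causal):
	# q, k, v: [batch, heads, seq_len, head_dim]; is_causal: one bool per batch entry.
	# SDPA (like flash/fused backends) only takes a single causal flag per call, so run
	# one batched call per causal state and scatter the results back to their original slots.
	out = torch.empty_like(q)
	for causal_state in (True, False):
		indices = [ i for i, state in enumerate( is_causal ) if state == causal_state ]
		if not indices:
			continue # no entries with this causality in the batch
		idx = torch.tensor( indices, device=q.device )
		out[idx] = F.scaled_dot_product_attention( q[idx], k[idx], v[idx], is_causal=causal_state )
	return out

# example: a batch of three sequences where only the middle one attends bidirectionally
q = k = v = torch.randn(3, 4, 16, 8)
out = attention_split_by_causality( q, k, v, is_causal=[True, False, True] )
print( out.shape ) # torch.Size([3, 4, 16, 8])

Compared with the per-item loop kept in the docstring inside the patched forward, this costs at most two attention calls per layer regardless of batch size, at the price of gathering the sub-batches and scattering the results back.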