From ed6b7a690f17573e7e58c7ce15b1566f70b6be05 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 13 Oct 2024 00:26:46 -0500
Subject: [PATCH] ugh.........

---
 vall_e/models/arch/llama.py | 112 ++++++------------------------------
 1 file changed, 19 insertions(+), 93 deletions(-)

diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index ea7807f..fc27fe9 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -147,79 +147,6 @@ class LlamaAttention_Adapted(LlamaAttention):
 			self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
 
 		super().__init__(*args, **kwargs)
-
-	# Adapted from LlamaAttention.forward, this doesn't seem to give great output......
-	def _forward(
-		self,
-		hidden_states: torch.Tensor,
-		attention_mask: Optional[torch.Tensor] = None,
-		position_ids: Optional[torch.LongTensor] = None,
-		past_key_value: Optional[Cache] = None,
-		output_attentions: bool = False,
-		use_cache: bool = False,
-		cache_position: Optional[torch.LongTensor] = None,
-		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
-		**kwargs,
-	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-		dropout_rate = self.attention_dropout if self.training else 0.0
-		bsz, q_len, _ = hidden_states.size()
-
-		query_states = self.q_proj(hidden_states)
-		key_states = self.k_proj(hidden_states)
-		value_states = self.v_proj(hidden_states)
-
-		query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-		key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-		value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-		if position_embeddings is None:
-			cos, sin = self.rotary_emb(value_states, position_ids)
-		else:
-			cos, sin = position_embeddings
-		query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-		if past_key_value is not None:
-			# sin and cos are specific to RoPE models; cache_position needed for the static cache
-			cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-			key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-		key_states = repeat_kv(key_states, self.num_key_value_groups)
-		value_states = repeat_kv(value_states, self.num_key_value_groups)
-		attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-		attn_scores = attn_weights
-
-		if attention_mask is not None: # no matter the length, we just slice it
-			causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-			attn_weights = attn_weights + causal_mask
-
-		# upcast attention to fp32
-		attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-		attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-		attn_output = torch.matmul(attn_weights, value_states)
-
-		if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-			raise ValueError(
-				f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-				f" {attn_output.size()}"
-			)
-
-		attn_output = attn_output.transpose(1, 2).contiguous()
-
-		attn_output = attn_output.reshape(bsz, q_len, -1)
-
-		if self.config.pretraining_tp > 1:
-			attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-			o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-			attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-		else:
-			attn_output = self.o_proj(attn_output)
-
-		if not output_attentions:
-			attn_weights = None
-			attn_scores = None
-
-		return attn_output, attn_scores, past_key_value
-
 
 	# Adapted from LlamaAttention.forward
 	def forward(
@@ -234,19 +161,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
 		**kwargs,
 	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-		if not self.mode:
-			# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-			return self._forward(
-				hidden_states=hidden_states,
-				attention_mask=attention_mask,
-				position_ids=position_ids,
-				past_key_value=past_key_value,
-				output_attentions=output_attentions,
-				use_cache=use_cache,
-				cache_position=cache_position,
-				position_embeddings=position_embeddings,
-			)
-
+		mode = "default" if output_attentions else self.mode
 		dropout_rate = self.attention_dropout if self.training else 0.0
 		bsz, q_len, _ = hidden_states.size()
 
@@ -269,10 +184,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 			cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 			key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-		if self.mode in ["xformers", "flash_attn"]:
-			if output_attentions:
-				attn_scores = torch.matmul(query_states, repeat_kv(key_states, self.num_key_value_groups).transpose(2, 3)) / math.sqrt(self.head_dim)
-
+		if mode in ["xformers", "flash_attn"]:
 			# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
 			# to be able to avoid many of these transpose/reshape/view.
 			query_states = query_states.transpose(1, 2)
@@ -301,7 +213,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 				value_states = value_states.to(target_dtype)
 			"""
 
-			if self.mode == "flash_attn":
+			if mode == "flash_attn":
 				attn_output = flash_attn_func(
 					query_states,
 					key_states,
@@ -312,7 +224,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 				)
 
 				attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
-			elif self.mode == "xformers":
+			elif mode == "xformers":
 				attn_output = memory_efficient_attention(
 					query_states,
 					key_states,
@@ -347,7 +259,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 		# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
 		is_causal = True if causal_mask is None and q_len > 1 else False
 
-		if self.mode in ["fused_attn"]:
+		if mode in ["fused_attn"]:
 			attn_output = fused_attn_func(
 				query_states,
 				key_states,
@@ -356,6 +268,20 @@ class LlamaAttention_Adapted(LlamaAttention):
 				softmax_scale=1.0 / math.sqrt(self.head_dim),
 				dropout_p=dropout_rate,
 			)
+		elif mode in ["default"]:
+			attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+			# cringe logic
+			attn_weights = (attn_scores + causal_mask) if attention_mask is not None else (attn_scores)
+			# upcast attention to fp32
+			attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+			attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+			attn_output = torch.matmul(attn_weights, value_states)
+
+			if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+				raise ValueError(
+					f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+					f" {attn_output.size()}"
+				)
 		else:
 			with torch.nn.attention.sdpa_kernel(self.mode):
 				attn_output = torch.nn.functional.scaled_dot_product_attention(
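
The patch above removes the old `_forward` fallback and instead folds a "default" branch into `forward()`, forcing it via `mode = "default" if output_attentions else self.mode`, presumably because the fused kernels (flash_attn, xformers, fused_attn, SDPA) never materialize the softmaxed attention weights. As a minimal standalone sketch of what that "default" branch computes, assuming q/k/v of shape [batch, num_heads, seq_len, head_dim] with key/value heads already expanded to match the query heads and an additive causal/padding mask (this is an illustration, not code from llama.py):

	# manual_attention: hypothetical helper mirroring the "default" branch above
	import math
	import torch
	import torch.nn.functional as F

	def manual_attention(q, k, v, causal_mask=None, dropout_p=0.0, training=False):
		# raw scores, scaled by sqrt(head_dim)
		scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.shape[-1])
		# additive mask, as in the patched branch
		weights = scores if causal_mask is None else scores + causal_mask
		# softmax upcast to fp32 for stability, then cast back to the input dtype
		weights = F.softmax(weights, dim=-1, dtype=torch.float32).to(q.dtype)
		weights = F.dropout(weights, p=dropout_p, training=training)
		# weighted sum over values; `weights` is what output_attentions can surface
		return torch.matmul(weights, v), weights

Usage would look like `attn_output, attn_weights = manual_attention(q, k, v, causal_mask)`, trading the memory savings of a fused kernel for access to the full attention-weight matrix.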