diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 3193ab8..c70f0ff 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -375,7 +375,7 @@ def example_usage():
 		'n_text_tokens': 256,
 		'n_audio_tokens': 1024,

-		'd_model': 256, # 256, # 1024, # 1536
+		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
 		'n_layers': 12, # 32
 		'n_experts': 1 if not cfg.model else cfg.model.experts,
@@ -468,8 +468,6 @@ def example_usage():
 	engine = Engine(model=model, optimizer=optimizer)
 	engines = Engines({"ar+nar": engine})
 	engines.setup()
-
-	print( model.state_dict().keys() )
 	"""
 	if cfg.optimizations.model_offloading:
diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py
index d7774a7..f973264 100755
--- a/vall_e/models/arch/__init__.py
+++ b/vall_e/models/arch/__init__.py
@@ -44,7 +44,7 @@ except Exception as e:
 	pass

 try:
-	from .mixtral import MixtralModel, MixtralConfig, load_balancing_loss_func
+	from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, load_balancing_loss_func
 	AVAILABLE_ARCHES.append("mixtral")
 except Exception as e:
 	ERROR_ARCHES["mixtral"] = e
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 59309fc..2dc72aa 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -131,6 +131,8 @@ class LlamaAttention_Adapted(LlamaAttention):
 				is_causal=is_causal,
 			)

+		print("attention")
+
 		attn_output = attn_output.transpose(1, 2).contiguous()
 		attn_output = attn_output.view(bsz, q_len, -1)

diff --git a/vall_e/models/arch/mixtral.py b/vall_e/models/arch/mixtral.py
index e02dc13..e4f6c12 100644
--- a/vall_e/models/arch/mixtral.py
+++ b/vall_e/models/arch/mixtral.py
@@ -2,9 +2,11 @@
 import torch
 import torch.nn.functional as F

+from typing import Literal, overload, Optional, Tuple
+from transformers.cache_utils import Cache
 from transformers import MixtralModel, MixtralConfig
-from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock
+from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func, MixtralSparseMoeBlock, MixtralAttention, apply_rotary_pos_emb, repeat_kv

 # This is required because batch sizes > 1 throws errors
 def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -42,4 +44,197 @@ def MixtralSparseMoeBlock_forward(self, hidden_states: torch.Te
 	final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
 	return final_hidden_states, router_logits

-MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
\ No newline at end of file
+MixtralSparseMoeBlock.forward = MixtralSparseMoeBlock_forward
+
+class MixtralAttention_Adapted(MixtralAttention):
+	def __init__(self, *args, **kwargs):
+		if 'mode' in kwargs:
+			self.mode = kwargs['mode']
+			kwargs.pop("mode")
+		else:
+			self.mode = "math"
+
+		if self.mode == "math":
+			self.mode = torch.nn.attention.SDPBackend.MATH
+		elif self.mode == "mem_efficient":
+			self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
+		elif self.mode == "flash":
+			self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
+		elif self.mode == "cudnn":
+			self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+
+		super().__init__(*args, **kwargs)
+
+	# Adapted from MixtralAttention.forward
+	def forward(
+		self,
+		hidden_states: torch.Tensor,
+		attention_mask: Optional[torch.Tensor] = None,
+		position_ids: Optional[torch.LongTensor] = None,
+		past_key_value: Optional[Cache] = None,
+		output_attentions: bool = False,
+		use_cache: bool = False,
+		cache_position: Optional[torch.LongTensor] = None,
+		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+		if output_attentions:
+			# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+			"""
+			logger.warning_once(
+				"MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+				'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+			)
+			"""
+			return super().forward(
+				hidden_states=hidden_states,
+				attention_mask=attention_mask,
+				position_ids=position_ids,
+				past_key_value=past_key_value,
+				output_attentions=output_attentions,
+				use_cache=use_cache,
+			)
+
+		bsz, q_len, _ = hidden_states.size()
+
+		query_states = self.q_proj(hidden_states)
+		key_states = self.k_proj(hidden_states)
+		value_states = self.v_proj(hidden_states)
+
+		query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+		key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+		value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+		kv_seq_len = key_states.shape[-2]
+		if past_key_value is not None:
+			kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+		if position_embeddings is None:
+			cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+		else:
+			cos, sin = position_embeddings
+
+		query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+		if past_key_value is not None:
+			cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+			key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+		key_states = repeat_kv(key_states, self.num_key_value_groups)
+		value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+		causal_mask = attention_mask
+		if attention_mask is not None: # no matter the length, we just slice it
+			causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+		# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+		# Reference: https://github.com/pytorch/pytorch/issues/112577.
+		if query_states.device.type == "cuda" and attention_mask is not None:
+			query_states = query_states.contiguous()
+			key_states = key_states.contiguous()
+			value_states = value_states.contiguous()
+
+		# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+		# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+		# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+		is_causal = True if causal_mask is None and q_len > 1 else False
+
+		#with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
+		with torch.nn.attention.sdpa_kernel(self.mode):
+			attn_output = torch.nn.functional.scaled_dot_product_attention(
+				query_states,
+				key_states,
+				value_states,
+				attn_mask=causal_mask,
+				dropout_p=self.attention_dropout if self.training else 0.0,
+				is_causal=is_causal,
+			)
+
+		attn_output = attn_output.transpose(1, 2).contiguous()
+		attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+		attn_output = self.o_proj(attn_output)
+
+		return attn_output, None, past_key_value
+	"""
+	def forward(
+		self,
+		hidden_states: torch.Tensor,
+		attention_mask: Optional[torch.Tensor] = None,
+		position_ids: Optional[torch.LongTensor] = None,
+		past_key_value: Optional[Cache] = None,
+		output_attentions: bool = False,
+		use_cache: bool = False,
+		cache_position: Optional[torch.LongTensor] = None,
+		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+		**kwargs,
+	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+		if output_attentions:
+			# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+			return super().forward(
+				hidden_states=hidden_states,
+				attention_mask=attention_mask,
+				position_ids=position_ids,
+				past_key_value=past_key_value,
+				output_attentions=output_attentions,
+				use_cache=use_cache,
+				cache_position=cache_position,
+				position_embeddings=position_embeddings,
+			)
+
+		bsz, q_len, _ = hidden_states.size()
+
+		query_states = self.q_proj(hidden_states)
+		key_states = self.k_proj(hidden_states)
+		value_states = self.v_proj(hidden_states)
+
+		query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+		key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+		value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+		if position_embeddings is None:
+			cos, sin = self.rotary_emb(value_states, position_ids)
+		else:
+			cos, sin = position_embeddings
+		query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+		if past_key_value is not None:
+			# sin and cos are specific to RoPE models; cache_position needed for the static cache
+			cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+			key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+		key_states = repeat_kv(key_states, self.num_key_value_groups)
+		value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+		causal_mask = attention_mask
+		if attention_mask is not None:
+			causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+		# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+		# Reference: https://github.com/pytorch/pytorch/issues/112577.
+		if query_states.device.type == "cuda" and causal_mask is not None:
+			query_states = query_states.contiguous()
+			key_states = key_states.contiguous()
+			value_states = value_states.contiguous()
+
+		# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+		# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+		is_causal = True if causal_mask is None and q_len > 1 else False
+
+		#with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
+		with torch.nn.attention.sdpa_kernel(self.mode):
+			attn_output = torch.nn.functional.scaled_dot_product_attention(
+				query_states,
+				key_states,
+				value_states,
+				attn_mask=causal_mask,
+				dropout_p=self.attention_dropout if self.training else 0.0,
+				is_causal=is_causal,
+			)
+
+		attn_output = attn_output.transpose(1, 2).contiguous()
+		attn_output = attn_output.view(bsz, q_len, -1)
+
+		attn_output = self.o_proj(attn_output)
+
+		return attn_output, None, past_key_value
+	"""
\ No newline at end of file
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 4b43e85..2a40f28 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -582,6 +582,8 @@ class Base(nn.Module):
 					attn_implementation=hf_attention,
 					#gradient_checkpointing=self.gradient_checkpointing,
 				))
+				if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
+					self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )

 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
@@ -605,6 +607,8 @@ class Base(nn.Module):
 					attn_implementation=hf_attention,
 					#gradient_checkpointing=self.gradient_checkpointing,
 				))
+				if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
+					self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
 			else:
 				self.model = MixtralModel(MixtralConfig(
 					vocab_size =n_resp_tokens,
@@ -625,6 +629,8 @@ class Base(nn.Module):
 					attn_implementation=hf_attention,
 					#gradient_checkpointing=self.gradient_checkpointing,
 				))
+				if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
+					self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )

 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
 				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
@@ -753,8 +759,6 @@ class Base(nn.Module):
 		if hasattr( self.model, "embeddings" ):
 			del self.model.embeddings

-		if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
-			self.model = ml.replace_attention( self.model, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )

 		if not split_classifiers:
 			self.classifier = nn.Linear(d_model, n_resp_tokens)
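Note on the backend pinning above: `MixtralAttention_Adapted` resolves the configured mode string ("math", "mem_efficient", "flash", "cudnn") to a `torch.nn.attention.SDPBackend` value and runs `scaled_dot_product_attention` under `torch.nn.attention.sdpa_kernel`, replacing the older `torch.backends.cuda.sdp_kernel` context manager that is left commented out. A minimal standalone sketch of that mechanism, outside any model code (assumes PyTorch 2.3+, where `sdpa_kernel` is available and accepts a single backend):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Dummy (batch, heads, seq_len, head_dim) tensors, the layout SDPA expects.
q = k = v = torch.randn(1, 16, 128, 64)

# Pin a single backend, the same way the adapted forward() does via self.mode.
with sdpa_kernel(SDPBackend.MATH):
	out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 16, 128, 64])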
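The `ml.replace_attention` helper that `base.py` now calls on the Llama/Mistral/Mixtral branches is not defined in this diff. Purely as an illustration of the pattern (the name, signature, and behavior below are assumptions, not the repo's implementation), a module swap along these lines would cover what the call sites need: find every `target` attention layer, rebuild it as `klass` with the extra `mode` kwarg, and carry the weights over.

import torch

# Hypothetical stand-in for ml.replace_attention; not the repo's actual helper.
def replace_attention_sketch(model: torch.nn.Module, klass, target, mode: str = "math"):
	# Collect the swaps first so the module tree is never mutated while iterating it.
	swaps = []
	for parent in model.modules():
		for child_name, child in parent.named_children():
			if isinstance(child, target):
				swaps.append((parent, child_name, child))

	for parent, child_name, child in swaps:
		# HF attention layers keep their config and layer index on the instance,
		# so the adapted class can be rebuilt from them plus the extra `mode` kwarg.
		adapted = klass(config=child.config, layer_idx=child.layer_idx, mode=mode)
		adapted.load_state_dict(child.state_dict())
		reference = next(child.parameters())
		adapted.to(device=reference.device, dtype=reference.dtype)
		setattr(parent, child_name, adapted)

	return model

With a helper of that shape, the added `ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )` calls amount to an in-place class swap that leaves the checkpoint's state_dict layout unchanged.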