From 3337c69e5ac26c296f9dbef00635a0ea64edb76a Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 11 May 2024 17:14:05 -0500
Subject: [PATCH] leverage between xformers and `torch.backends.cuda.sdp_kernel` for attention

---
 vall_e/config.py        |  2 +-
 vall_e/models/ar_nar.py |  4 ++-
 vall_e/models/base.py   | 76 ++++++++++++++++++++++++++++++-----------
 vall_e/utils/wrapper.py |  3 +-
 4 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 8185d47..128cbf1 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -210,7 +210,7 @@ class Model:
 	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
 	p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
-	attention: str = "eager" # or flash_attention_2
+	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
 
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index d9a3127..848100d 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -365,9 +365,11 @@ def example_usage():
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
-		'n_layers': 12, # 32
+		'n_layers': 8, # 32
 		'n_experts': 1,
+
+		'p_dropout': 0.0,
+
 		'l_padding': 8 if cfg.optimizations.fp8 else 0,
 
 		'config': cfg.model
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 8f0d172..6ee5862 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -135,15 +135,39 @@ except Exception as e:
 	print("Error importing `mixtral` arch:", e)
 
-LLAMA_ATTENTIONS = {}
+AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
+
+try:
+	from xformers.ops import LowerTriangularMask
+	from xformers.ops.fmha import memory_efficient_attention
+
+	AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+	print("Error while importing `xformers`", e)
+
+try:
+	from transformers.utils import is_flash_attn_2_available
+
+	if is_flash_attn_2_available():
+		AVAILABLE_ATTENTIONS.append("flash")
+except Exception as e:
+	raise e
+
 try:
 	from transformers.cache_utils import Cache
 	from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
-	from xformers.ops import LowerTriangularMask
-	from xformers.ops.fmha import memory_efficient_attention
-
-	class LLamaXformersAttention(LlamaAttention):
+	class Llama_Attention(LlamaAttention):
+		def __init__(self, *args, **kwargs):
+			if 'mode' in kwargs:
+				self.mode = kwargs['mode']
+				kwargs.pop("mode")
+			else:
+				self.mode = "math"
+
+			super().__init__(*args, **kwargs)
+
 		def forward(
 			self,
 			hidden_states: torch.Tensor,
@@ -181,11 +205,14 @@ try:
 
 			dropout_rate = self.attention_dropout if self.training else 0.0
 
-			# copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
-			if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+			if self.mode == "xformers":
+				if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+					attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+				else:
+					attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
 			else:
-				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
+				with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
+					attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
 
 			attn_weights = None
 
@@ -194,7 +221,7 @@ try:
 			return attn_output, attn_weights, past_key_value
 
 except Exception as e:
-	print("Error creating `LLamaXformersAttention`:", e)
+	print("Error creating modified `LLamaAttention`:", e)
 
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
@@ -449,12 +476,21 @@ class Base(nn.Module):
 		self.sep = nn.Parameter(torch.randn(d_model))
 
 		# ick, there has to be a better way
-		attention = self.config.attention if self.config is not None else None
-		use_xformers = False
+		hf_attention = self.config.attention if self.config is not None else None
+
+		if self.config.attention == "auto":
+			if "flash" in AVAILABLE_ATTENTIONS:
+				self.config.attention = "flash"
+			elif "xformers" in AVAILABLE_ATTENTIONS:
+				self.config.attention = "xformers"
+			else:
+				self.config.attention = "mem_efficient"
+
+		if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
+			hf_attention = None
+			if self.config.attention not in AVAILABLE_ATTENTIONS:
+				raise ValueError(f"Requesting attention `{self.config.attention}` but it is not available. Currently available: {AVAILABLE_ATTENTIONS}")
 
-		if attention == "xformers":
-			use_xformers = True
-			attention = None
 
 		if self.arch_type == "transformer":
 			self.sin_emb = SinusoidalEmbedding(d_model)
@@ -480,7 +516,7 @@ class Base(nn.Module):
 					hidden_act="gelu",
 					is_encoder_decoder=False,
 					is_decoder=True,
-					attn_implementation=attention,
+					attn_implementation=hf_attention,
 					#gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
@@ -500,7 +536,7 @@ class Base(nn.Module):
 					is_decoder=True,
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
-					attn_implementation=attention,
+					attn_implementation=hf_attention,
 					#gradient_checkpointing=self.activation_checkpointing,
 				))
 
@@ -526,7 +562,7 @@ class Base(nn.Module):
 					hidden_act="gelu",
 					is_encoder_decoder=False,
 					is_decoder=True,
-					attn_implementation=attention,
+					attn_implementation=hf_attention,
 					#gradient_checkpointing=self.activation_checkpointing,
 				))
 			else:
@@ -546,7 +582,7 @@ class Base(nn.Module):
 					is_decoder=True,
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
-					attn_implementation=attention,
+					attn_implementation=hf_attention,
 					#gradient_checkpointing=self.activation_checkpointing,
 				))
 
@@ -621,8 +657,8 @@ class Base(nn.Module):
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
-		if use_xformers:
-			self.model = ml.replace_attention( self.model, klass=LLamaXformersAttention, target=LlamaAttention )
+		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
+			self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
 
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
 
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index fcd1275..275d5a4 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -183,7 +183,7 @@ def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbos
 	return model
 
 # cannot feasibly do default arguments here sad
-def replace_attention( model, klass, target, verbose=False ):
+def replace_attention( model, klass, target, mode="math", verbose=False ):
 	device = next(model.parameters()).device
 	dtype = next(model.parameters()).dtype
 	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
@@ -199,6 +199,7 @@ def replace_attention( model, klass, target, verbose=False ):
 		kwargs = dict(
 			config = m.config,
 			layer_idx = m.layer_idx,
+			mode = mode,
 		)
 		# overwrite
 		setattr(
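
For reference, the backend selection this patch wires into the model can be exercised on its own. Below is a minimal, self-contained sketch (not part of the patch; the helper name `sdpa_with_mode` and the tensor shapes are illustrative assumptions): it restricts `torch.nn.functional.scaled_dot_product_attention` to a single backend using the same `torch.backends.cuda.sdp_kernel` flags and the "flash" / "math" / "mem_efficient" mode strings used in the modified `Llama_Attention.forward` above.

import torch

def sdpa_with_mode(q, k, v, mask=None, mode="math"):
	# Leave only the backend named by `mode` enabled, mirroring the branch
	# added to `Llama_Attention.forward` in vall_e/models/base.py.
	with torch.backends.cuda.sdp_kernel(
		enable_flash=(mode == "flash"),
		enable_math=(mode == "math"),
		enable_mem_efficient=(mode == "mem_efficient"),
	):
		return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# illustrative shapes only: (batch, heads, sequence length, head dim)
q = k = v = torch.randn(1, 16, 32, 64)
out = sdpa_with_mode(q, k, v, mode="math")
print(out.shape)  # torch.Size([1, 16, 32, 64])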