leverage between xformers and torch.backends.cuda.sdp_kernel for attention

mrq 2024-05-11 17:14:05 -05:00
parent d33c7bb7cf
commit 3337c69e5a
4 changed files with 62 additions and 23 deletions

View File

@@ -210,7 +210,7 @@ class Model:
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
attention: str = "eager" # or flash_attention_2
attention: str = "auto"
audio_embedding_sums: bool = True
dropout: float = 0.1 # adjustable dropout value
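With this change `attention` becomes a backend selector rather than a raw HuggingFace `attn_implementation` string. A short sketch of the recognized values, pieced together from the selection logic further down (the `cfg.model` spelling follows the example_usage hunk; treat this as a summary, not a guarantee):

cfg.model.attention = "auto"           # picks "flash", then "xformers", then "mem_efficient", whichever is available
cfg.model.attention = "xformers"       # xformers' memory_efficient_attention
cfg.model.attention = "flash"          # torch SDP flash kernel (listed as available only when flash-attn 2 is importable)
cfg.model.attention = "mem_efficient"  # torch SDP memory-efficient kernel
cfg.model.attention = "math"           # torch SDP math kernel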

View File

@@ -365,9 +365,11 @@ def example_usage():
'n_tokens': 1024,
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_layers': 8, # 32
'n_experts': 1,
'p_dropout': 0.0,
'l_padding': 8 if cfg.optimizations.fp8 else 0,
'config': cfg.model

View File

@@ -135,15 +135,39 @@ except Exception as e:
print("Error importing `mixtral` arch:", e)
LLAMA_ATTENTIONS = {}
AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
try:
from xformers.ops import LowerTriangularMask
from xformers.ops.fmha import memory_efficient_attention
AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e:
print("Error while importing `xformers`", e)
try:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
AVAILABLE_ATTENTIONS.append("flash")
except Exception as e:
raise e
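As an aside (my addition, not part of this commit; assumes PyTorch >= 2.0), torch itself can report which scaled-dot-product backends are currently enabled for use under torch.backends.cuda.sdp_kernel, which complements the import-based probing above:

import torch

def enabled_sdp_backends() -> list[str]:
    # names chosen to line up with AVAILABLE_ATTENTIONS above
    backends = []
    if torch.backends.cuda.flash_sdp_enabled():
        backends.append("flash")
    if torch.backends.cuda.mem_efficient_sdp_enabled():
        backends.append("mem_efficient")
    if torch.backends.cuda.math_sdp_enabled():
        backends.append("math")
    return backends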
try:
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from xformers.ops import LowerTriangularMask
from xformers.ops.fmha import memory_efficient_attention
class LLamaXformersAttention(LlamaAttention):
class Llama_Attention(LlamaAttention):
def __init__(self, *args, **kwargs):
if 'mode' in kwargs:
self.mode = kwargs['mode']
kwargs.pop("mode")
else:
self.mode = "math"
super().__init__(*args, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
@@ -181,11 +205,14 @@ try:
dropout_rate = self.attention_dropout if self.training else 0.0
# copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
if self.mode == "xformers":
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
else:
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
else:
attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
attn_weights = None
@@ -194,7 +221,7 @@ try:
return attn_output, attn_weights, past_key_value
except Exception as e:
print("Error creating `LLamaXformersAttention`:", e)
print("Error creating modified `LLamaAttention`:", e)
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
@@ -449,12 +476,21 @@ class Base(nn.Module):
self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way
attention = self.config.attention if self.config is not None else None
use_xformers = False
hf_attention = self.config.attention if self.config is not None else None
if self.config.attention == "auto":
if "flash" in AVAILABLE_ATTENTIONS:
self.config.attention = "flash"
elif "xformers" in AVAILABLE_ATTENTIONS:
self.config.attention = "xformers"
else:
self.config.attention = "mem_efficient"
if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
hf_attention = None
if self.config.attention not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
if attention == "xformers":
use_xformers = True
attention = None
if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
@@ -480,7 +516,7 @@ class Base(nn.Module):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=attention,
attn_implementation=hf_attention,
#gradient_checkpointing=self.activation_checkpointing,
))
else:
@@ -500,7 +536,7 @@ class Base(nn.Module):
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=attention,
attn_implementation=hf_attention,
#gradient_checkpointing=self.activation_checkpointing,
))
@@ -526,7 +562,7 @@ class Base(nn.Module):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=attention,
attn_implementation=hf_attention,
#gradient_checkpointing=self.activation_checkpointing,
))
else:
@@ -546,7 +582,7 @@ class Base(nn.Module):
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
attn_implementation=attention,
attn_implementation=hf_attention,
#gradient_checkpointing=self.activation_checkpointing,
))
@@ -621,8 +657,8 @@ class Base(nn.Module):
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
if use_xformers:
self.model = ml.replace_attention( self.model, klass=LLamaXformersAttention, target=LlamaAttention )
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
self.classifier = nn.Linear(d_model, n_resp_tokens)
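The last file's change (below) forwards the new mode kwarg into that constructor; since its hunk is truncated here, a self-contained sketch of the overall pattern may help (ModedAttention and the tiny LlamaConfig are illustrative stand-ins, not the repo's classes): build a stock HF Llama model with its default attention, then swap every LlamaAttention module for a subclass that records the selected backend.

from transformers import LlamaConfig, LlamaModel
from transformers.models.llama.modeling_llama import LlamaAttention

class ModedAttention(LlamaAttention):  # stand-in for Llama_Attention above
    def __init__(self, *args, mode="math", **kwargs):
        super().__init__(*args, **kwargs)
        self.mode = mode

config = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                     num_attention_heads=4, num_key_value_heads=4, vocab_size=256)
model = LlamaModel(config)

# collect first, then replace, mirroring replace_attention in the last file
attentions = [(name, module) for name, module in model.named_modules() if isinstance(module, LlamaAttention)]
for name, module in attentions:
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, ModedAttention(config=module.config, layer_idx=module.layer_idx, mode="xformers"))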

View File

@@ -183,7 +183,7 @@ def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbos
return model
# cannot feasibly do default arguments here sad
def replace_attention( model, klass, target, verbose=False ):
def replace_attention( model, klass, target, mode="math", verbose=False ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
@@ -199,6 +199,7 @@ def replace_attention( model, klass, target, verbose=False ):
kwargs = dict(
config = m.config,
layer_idx = m.layer_idx,
mode = mode,
)
# overwrite
setattr(