leverage between xformers and torch.backends.cuda.sdp_kernel for attention
parent d33c7bb7cf
commit 3337c69e5a

@@ -210,7 +210,7 @@ class Model:
     interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
     p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
     frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
-    attention: str = "eager" # or flash_attention_2
+    attention: str = "auto"
     audio_embedding_sums: bool = True
     dropout: float = 0.1 # adjustable dropout value

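Note (not part of the diff): the default for `attention` changes from "eager" to "auto", so a config that never sets the field now picks the best available backend instead of Hugging Face's eager attention. A minimal sketch of the selection order the rest of this commit implements; the helper name and standalone form are illustrative, the real resolution lives in Base.__init__ further down:

# Illustrative helper, not in the repository.
AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]  # torch SDPA backends, assumed always usable

def resolve_attention(requested: str) -> str:
    if requested != "auto":
        return requested
    for candidate in ("flash", "xformers", "mem_efficient"):
        if candidate in AVAILABLE_ATTENTIONS:
            return candidate
    return "math"

print(resolve_attention("auto"))  # "mem_efficient" here; "flash" or "xformers" if those imports succeed
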
@@ -365,9 +365,11 @@ def example_usage():
         'n_tokens': 1024,
         'd_model': 1024, # 256, # 1024, # 1536
         'n_heads': 16, # 4, # 16, # 24
-        'n_layers': 12, # 32
+        'n_layers': 8, # 32
         'n_experts': 1,

         'p_dropout': 0.0,

         'l_padding': 8 if cfg.optimizations.fp8 else 0,

         'config': cfg.model

@@ -135,15 +135,39 @@ except Exception as e:
     print("Error importing `mixtral` arch:", e)

-LLAMA_ATTENTIONS = {}
+AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
+
+try:
+    from xformers.ops import LowerTriangularMask
+    from xformers.ops.fmha import memory_efficient_attention
+
+    AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+    print("Error while importing `xformers`", e)
+
+try:
+    from transformers.utils import is_flash_attn_2_available
+
+    if is_flash_attn_2_available():
+        AVAILABLE_ATTENTIONS.append("flash")
+except Exception as e:
+    raise e

 try:
     from transformers.cache_utils import Cache
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from xformers.ops import LowerTriangularMask
-    from xformers.ops.fmha import memory_efficient_attention
-
-    class LLamaXformersAttention(LlamaAttention):
+    class Llama_Attention(LlamaAttention):
+        def __init__(self, *args, **kwargs):
+            if 'mode' in kwargs:
+                self.mode = kwargs['mode']
+                kwargs.pop("mode")
+            else:
+                self.mode = "math"
+
+            super().__init__(*args, **kwargs)
+
         def forward(
             self,
             hidden_states: torch.Tensor,

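Note (not part of the diff): the hunk above builds AVAILABLE_ATTENTIONS by probing optional imports; mirroring those probes in isolation is a quick way to see which modes a given environment can actually serve. A sketch, assuming `xformers` and `transformers` are the only optional dependencies involved:

# Standalone probe mirroring the detection logic above; sketch only, not the shipped module.
available = ["mem_efficient", "math"]  # torch's built-in SDPA backends

try:
    from xformers.ops.fmha import memory_efficient_attention  # noqa: F401
    available.append("xformers")
except ImportError:
    pass

try:
    from transformers.utils import is_flash_attn_2_available
    if is_flash_attn_2_available():
        available.append("flash")
except ImportError:
    pass

print("usable attention modes:", available)

Popping `mode` out of **kwargs before calling super().__init__ is what lets Llama_Attention reuse LlamaAttention's constructor unchanged while still carrying a per-instance backend choice.
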
@@ -181,11 +205,14 @@ try:
             dropout_rate = self.attention_dropout if self.training else 0.0

             # copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
-            if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-                attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
-            else:
-                attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
+            if self.mode == "xformers":
+                if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+                    attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+                else:
+                    attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
+            else:
+                with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
+                    attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)

             attn_weights = None

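Note (not part of the diff): for every mode other than "xformers", the forward pass now defers to PyTorch's scaled_dot_product_attention and simply constrains which kernel torch.backends.cuda.sdp_kernel may pick. A self-contained sketch of that pattern; the tensor shapes are illustrative, and the flash kernel additionally needs a CUDA device with fp16/bf16 inputs:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- illustrative shapes only
q = torch.randn(1, 16, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 16, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 16, 128, 64, device="cuda", dtype=torch.float16)

mode = "flash"  # or "math" / "mem_efficient", matching cfg.model.attention
with torch.backends.cuda.sdp_kernel(
    enable_flash=(mode == "flash"),
    enable_math=(mode == "math"),
    enable_mem_efficient=(mode == "mem_efficient"),
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

Newer PyTorch releases deprecate torch.backends.cuda.sdp_kernel in favor of torch.nn.attention.sdpa_kernel, so this context manager may emit a deprecation warning there.
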
@@ -194,7 +221,7 @@ try:
             return attn_output, attn_weights, past_key_value
 except Exception as e:
-    print("Error creating `LLamaXformersAttention`:", e)
+    print("Error creating modified `LLamaAttention`:", e)

 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""

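Note (not part of the diff): only the signature and docstring of _create_mask are visible in this hunk. In case the 1/0 convention is unclear, the body below is an assumption based on that docstring, under an illustrative name rather than the repository's actual implementation:

import torch

def create_length_mask(lengths, device="cpu"):
    """Return a (batch, max_len) mask where 1 marks valid positions and 0 marks padding."""
    max_len = max(lengths)
    positions = torch.arange(max_len, device=device).unsqueeze(0)  # (1, max_len)
    return (positions < torch.tensor(lengths, device=device).unsqueeze(1)).float()

# create_length_mask([3, 5]) ->
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.]])
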
@@ -449,12 +476,21 @@ class Base(nn.Module):
         self.sep = nn.Parameter(torch.randn(d_model))

         # ick, there has to be a better way
-        attention = self.config.attention if self.config is not None else None
-        use_xformers = False
+        hf_attention = self.config.attention if self.config is not None else None
+
+        if self.config.attention == "auto":
+            if "flash" in AVAILABLE_ATTENTIONS:
+                self.config.attention = "flash"
+            elif "xformers" in AVAILABLE_ATTENTIONS:
+                self.config.attention = "xformers"
+            else:
+                self.config.attention = "mem_efficient"

-        if attention == "xformers":
-            use_xformers = True
-            attention = None
+        if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
+            hf_attention = None
+            if self.config.attention not in AVAILABLE_ATTENTIONS:
+                raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")

         if self.arch_type == "transformer":
             self.sin_emb = SinusoidalEmbedding(d_model)

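Note (not part of the diff): the split between hf_attention and self.config.attention is the point of this hunk. When the configured mode is one of the four backends the patched class handles, the Hugging Face config receives attn_implementation=None and the stock attention modules get swapped out later; any other string (for example "flash_attention_2") is passed through to transformers untouched. Condensed, illustratively:

# Illustrative condensation of the branch above; not the repository's code.
CUSTOM_MODES = ["xformers", "mem_efficient", "math", "flash"]

requested = "xformers"  # cfg.model.attention after "auto" has been resolved
hf_attention = None if requested in CUSTOM_MODES else requested
needs_swap = requested in CUSTOM_MODES  # drives ml.replace_attention(...) at the end of __init__
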
@@ -480,7 +516,7 @@ class Base(nn.Module):
                 hidden_act="gelu",
                 is_encoder_decoder=False,
                 is_decoder=True,
-                attn_implementation=attention,
+                attn_implementation=hf_attention,
                 #gradient_checkpointing=self.activation_checkpointing,
             ))
         else:

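Note (not part of the diff): this hunk and the three below it only change which variable is forwarded as attn_implementation to the Hugging Face config constructors. A minimal sketch of what that forwarding amounts to, with illustrative hyperparameters; passing None lets transformers fall back to its own default, which is fine here because the attention modules are replaced afterwards anyway:

from transformers import LlamaConfig, LlamaModel

hf_attention = None  # or "flash_attention_2" / "sdpa" / "eager" to stay on HF's own implementations
config = LlamaConfig(
    vocab_size=1024,
    hidden_size=1024,
    num_hidden_layers=8,
    num_attention_heads=16,
    attn_implementation=hf_attention,
)
model = LlamaModel(config)
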
@@ -500,7 +536,7 @@ class Base(nn.Module):
                 is_decoder=True,
                 num_local_experts=n_experts,
                 num_experts_per_tok=min(2, n_experts),
-                attn_implementation=attention,
+                attn_implementation=hf_attention,
                 #gradient_checkpointing=self.activation_checkpointing,
             ))

@@ -526,7 +562,7 @@ class Base(nn.Module):
                 hidden_act="gelu",
                 is_encoder_decoder=False,
                 is_decoder=True,
-                attn_implementation=attention,
+                attn_implementation=hf_attention,
                 #gradient_checkpointing=self.activation_checkpointing,
             ))
         else:

@@ -546,7 +582,7 @@ class Base(nn.Module):
                 is_decoder=True,
                 num_local_experts=n_experts,
                 num_experts_per_tok=min(2, n_experts),
-                attn_implementation=attention,
+                attn_implementation=hf_attention,
                 #gradient_checkpointing=self.activation_checkpointing,
             ))

@@ -621,8 +657,8 @@ class Base(nn.Module):
         else:
             raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

-        if use_xformers:
-            self.model = ml.replace_attention( self.model, klass=LLamaXformersAttention, target=LlamaAttention )
+        if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
+            self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )

         self.classifier = nn.Linear(d_model, n_resp_tokens)

@@ -183,7 +183,7 @@ def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbos
     return model

 # cannot feasibly do default arguments here sad
-def replace_attention( model, klass, target, verbose=False ):
+def replace_attention( model, klass, target, mode="math", verbose=False ):
     device = next(model.parameters()).device
     dtype = next(model.parameters()).dtype
     modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]

@@ -199,6 +199,7 @@ def replace_attention( model, klass, target, verbose=False ):
         kwargs = dict(
             config = m.config,
             layer_idx = m.layer_idx,
+            mode = mode,
         )
         # overwrite
         setattr(

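Note (not part of the diff): the final hunk is cut off before the setattr(...) call, but replace_attention follows the usual walk-and-swap pattern: find every module of the target class, rebuild it as the patched class with the extra mode kwarg, copy the weights over, and reassign it on its parent. A sketch under those assumptions, with an illustrative name; the shipped helper also handles device/dtype placement and verbose logging:

def swap_attention_modules(model, klass, target, mode="math"):
    """Replace every `target` attention module in `model` with `klass`, forwarding `mode`."""
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    for name, module in list(model.named_modules()):
        if not isinstance(module, target):
            continue
        replacement = klass(config=module.config, layer_idx=module.layer_idx, mode=mode)
        replacement = replacement.to(device=device, dtype=dtype)
        replacement.load_state_dict(module.state_dict())  # keep the trained weights
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, replacement)
    return model

This is the shape of the call Base.__init__ makes above as ml.replace_attention(self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention).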