leverage between xformers and torch.backends.cuda.sdp_kernel for attention

commit 3337c69e5a
parent d33c7bb7cf
@@ -210,7 +210,7 @@ class Model:
 	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (experimental, worse performance and results)
 	p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
-	attention: str = "eager" # or flash_attention_2
+	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
 
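The `attention` field now defaults to `"auto"` instead of pinning one of HF's implementations (`"eager"`, `"flash_attention_2"`); the concrete backend is resolved when the model is built (see the `Base` hunk further down). A minimal sketch of how the field is meant to be used, with a hypothetical stripped-down dataclass rather than the project's actual `Model`:

from dataclasses import dataclass, field

@dataclass
class ModelSketch:
	attention: str = "auto"  # resolved later to "flash", "xformers", "mem_efficient", or "math"
	audio_embedding_sums: bool = True
	dropout: float = 0.1
	frozen_params: list[str] = field(default_factory=list)

cfg = ModelSketch()                         # let the runtime pick the best available backend
cfg_pinned = ModelSketch(attention="math")  # force the reference SDPA kernel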
@@ -365,9 +365,11 @@ def example_usage():
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
-		'n_layers': 12, # 32
+		'n_layers': 8, # 32
 		'n_experts': 1,
 
+		'p_dropout': 0.0,
+
 		'l_padding': 8 if cfg.optimizations.fp8 else 0,
 
 		'config': cfg.model
@@ -135,15 +135,39 @@ except Exception as e:
 	print("Error importing `mixtral` arch:", e)
 
 
-LLAMA_ATTENTIONS = {}
+AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
 
+try:
+	from xformers.ops import LowerTriangularMask
+	from xformers.ops.fmha import memory_efficient_attention
+
+	AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+	print("Error while importing `xformers`", e)
+
+try:
+	from transformers.utils import is_flash_attn_2_available
+
+	if is_flash_attn_2_available():
+		AVAILABLE_ATTENTIONS.append("flash")
+except Exception as e:
+	raise e
+
 try:
 	from transformers.cache_utils import Cache
 	from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
-	from xformers.ops import LowerTriangularMask
-	from xformers.ops.fmha import memory_efficient_attention
-
-	class LLamaXformersAttention(LlamaAttention):
+	class Llama_Attention(LlamaAttention):
+		def __init__(self, *args, **kwargs):
+			if 'mode' in kwargs:
+				self.mode = kwargs['mode']
+				kwargs.pop("mode")
+			else:
+				self.mode = "math"
+
+			super().__init__(*args, **kwargs)
+
 		def forward(
 			self,
 			hidden_states: torch.Tensor,
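The availability probe above can be exercised on its own; a rough standalone sketch of the same pattern, assuming `xformers` and `transformers` may or may not be installed:

# Start with the torch SDPA kernels that are always selectable, then append
# optional backends as their imports succeed.
AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]

try:
	from xformers.ops.fmha import memory_efficient_attention  # optional dependency
	AVAILABLE_ATTENTIONS.append("xformers")
except ImportError:
	pass

try:
	from transformers.utils import is_flash_attn_2_available
	if is_flash_attn_2_available():
		AVAILABLE_ATTENTIONS.append("flash")
except ImportError:
	pass

print("usable attention backends:", AVAILABLE_ATTENTIONS)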
@@ -181,11 +205,14 @@ try:
 
 			dropout_rate = self.attention_dropout if self.training else 0.0
 
-			# copied from https://github.com/oobabooga/text-generation-webui/pull/950/files
-			if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
-			else:
-				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
+			if self.mode == "xformers":
+				if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+					attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+				else:
+					attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
+			else:
+				with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
+					attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
 
 			attn_weights = None
 
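For reference, the non-xformers branch relies on PyTorch's own kernel selection: `torch.backends.cuda.sdp_kernel(...)` constrains which implementations `scaled_dot_product_attention` may dispatch to. A self-contained sketch of that mechanism (newer PyTorch releases deprecate this context manager in favour of `torch.nn.attention.sdpa_kernel`, so treat it as illustrative for the versions this commit targets):

import torch
import torch.nn.functional as F

mode = "math"  # mirrors self.mode above: "flash", "math", or "mem_efficient"
q = torch.randn(1, 16, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 16, 128, 64)
v = torch.randn(1, 16, 128, 64)

with torch.backends.cuda.sdp_kernel(
	enable_flash=mode == "flash",
	enable_math=mode == "math",
	enable_mem_efficient=mode == "mem_efficient",
):
	# is_causal stands in for the attention_mask used in the real forward()
	out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 16, 128, 64])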
@@ -194,7 +221,7 @@ try:
 
 			return attn_output, attn_weights, past_key_value
 except Exception as e:
-	print("Error creating `LLamaXformersAttention`:", e)
+	print("Error creating modified `LLamaAttention`:", e)
 
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
@@ -449,12 +476,21 @@ class Base(nn.Module):
 		self.sep = nn.Parameter(torch.randn(d_model))
 
 		# ick, there has to be a better way
-		attention = self.config.attention if self.config is not None else None
-		use_xformers = False
+		hf_attention = self.config.attention if self.config is not None else None
+
+		if self.config.attention == "auto":
+			if "flash" in AVAILABLE_ATTENTIONS:
+				self.config.attention = "flash"
+			elif "xformers" in AVAILABLE_ATTENTIONS:
+				self.config.attention = "xformers"
+			else:
+				self.config.attention = "mem_efficient"
 
-		if attention == "xformers":
-			use_xformers = True
-			attention = None
+		if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]:
+			hf_attention = None
+			if self.config.attention not in AVAILABLE_ATTENTIONS:
+				raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
 
 		if self.arch_type == "transformer":
 			self.sin_emb = SinusoidalEmbedding(d_model)
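The `"auto"` branch boils down to a fixed preference order (flash, then xformers, then the memory-efficient SDPA kernel), with an explicit request validated against what was actually importable. The same logic as a tiny standalone helper, using illustrative names rather than the project's API:

def resolve_attention(requested: str, available: list[str]) -> str:
	# Mirrors the selection above: prefer flash, then xformers, else mem_efficient.
	if requested == "auto":
		for candidate in ("flash", "xformers"):
			if candidate in available:
				return candidate
		return "mem_efficient"
	if requested not in available:
		raise ValueError(f"Requesting attention `{requested}` but is not available. Currently available: {available}")
	return requested

assert resolve_attention("auto", ["mem_efficient", "math", "xformers"]) == "xformers"
assert resolve_attention("math", ["mem_efficient", "math"]) == "math"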
@@ -480,7 +516,7 @@ class Base(nn.Module):
 				hidden_act="gelu",
 				is_encoder_decoder=False,
 				is_decoder=True,
-				attn_implementation=attention,
+				attn_implementation=hf_attention,
 				#gradient_checkpointing=self.activation_checkpointing,
 			))
 		else:
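Because `hf_attention` is cleared to `None` whenever one of the self-handled backends is chosen, the HF config is built without a pinned implementation and the attention modules are swapped out afterwards (see the `replace_attention` hunks below); `attn_implementation` only carries a value such as `"flash_attention_2"` when attention is delegated to HF entirely. A rough sketch of that interplay, assuming a recent `transformers`:

from transformers import LlamaConfig, LlamaModel

hf_attention = None  # a self-handled backend was selected
config = LlamaConfig(
	vocab_size=1024,
	hidden_size=1024,
	num_attention_heads=16,
	num_hidden_layers=2,  # small, just to instantiate
	attn_implementation=hf_attention,
)
model = LlamaModel(config)  # its LlamaAttention layers get replaced afterwards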
@@ -500,7 +536,7 @@ class Base(nn.Module):
 				is_decoder=True,
 				num_local_experts=n_experts,
 				num_experts_per_tok=min(2, n_experts),
-				attn_implementation=attention,
+				attn_implementation=hf_attention,
 				#gradient_checkpointing=self.activation_checkpointing,
 			))
 
@@ -526,7 +562,7 @@ class Base(nn.Module):
 				hidden_act="gelu",
 				is_encoder_decoder=False,
 				is_decoder=True,
-				attn_implementation=attention,
+				attn_implementation=hf_attention,
 				#gradient_checkpointing=self.activation_checkpointing,
 			))
 		else:
@@ -546,7 +582,7 @@ class Base(nn.Module):
 				is_decoder=True,
 				num_local_experts=n_experts,
 				num_experts_per_tok=min(2, n_experts),
-				attn_implementation=attention,
+				attn_implementation=hf_attention,
 				#gradient_checkpointing=self.activation_checkpointing,
 			))
 
@@ -621,8 +657,8 @@ class Base(nn.Module):
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
-		if use_xformers:
-			self.model = ml.replace_attention( self.model, klass=LLamaXformersAttention, target=LlamaAttention )
+		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
+			self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention )
 
 		self.classifier = nn.Linear(d_model, n_resp_tokens)
 
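Note that the gate now covers `"auto"` as well, even though by this point `"auto"` has already been rewritten into a concrete backend, so every self-handled path ends up with the patched attention class carrying a `mode` string. A hedged sketch of what one swapped module amounts to, re-declaring an equivalent `__init__` shim and assuming a `transformers` where `LlamaAttention(config, layer_idx)` is the constructor:

from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

class Llama_Attention(LlamaAttention):
	# Same idea as the shim in the hunk above: pull `mode` out before LlamaAttention sees it.
	def __init__(self, *args, **kwargs):
		self.mode = kwargs.pop("mode", "math")
		super().__init__(*args, **kwargs)

config = LlamaConfig(hidden_size=256, num_attention_heads=4)
attn = Llama_Attention(config=config, layer_idx=0, mode="mem_efficient")
print(attn.mode)  # forward() later uses this to pick the kernel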
@@ -183,7 +183,7 @@ def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbos
 	return model
 
 # cannot feasibly do default arguments here sad
-def replace_attention( model, klass, target, verbose=False ):
+def replace_attention( model, klass, target, mode="math", verbose=False ):
 	device = next(model.parameters()).device
 	dtype = next(model.parameters()).dtype
 	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
@@ -199,6 +199,7 @@ def replace_attention( model, klass, target, verbose=False ):
 		kwargs = dict(
 			config = m.config,
 			layer_idx = m.layer_idx,
+			mode = mode,
 		)
 		# overwrite
 		setattr(
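The kwargs threading above is the whole mechanism: each existing `LlamaAttention` is rebuilt as the wrapper class with its original `config`/`layer_idx` plus the new `mode`, then written back onto its parent module. A generic, self-contained sketch of that find-and-rebuild pattern (not the project's exact `replace_attention`, which also handles device/dtype and verbosity):

import torch

def swap_attention(model: torch.nn.Module, target: type, klass: type, mode: str = "math") -> torch.nn.Module:
	# Dotted paths of every module that is an instance of `target`.
	names = [name for name, module in model.named_modules() if isinstance(module, target)]
	for name in names:
		*parents, attr = name.split('.')
		owner = model
		for p in parents:
			owner = getattr(owner, p)
		old = getattr(owner, attr)
		# Rebuild with the wrapper, forwarding `mode` the way the kwargs dict above does.
		new = klass(config=old.config, layer_idx=old.layer_idx, mode=mode)
		new.load_state_dict(old.state_dict())
		setattr(owner, attr, new)
	return model

# Usage, mirroring the call in the `Base` hunk:
#   swap_attention(model, target=LlamaAttention, klass=Llama_Attention, mode=cfg.model.attention)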