do not include SDPA attention if there are no available SDPA backends
This commit is contained in:
parent 613024ec0d
commit 949339a3fa
@@ -9,7 +9,7 @@ from transformers.cache_utils import Cache
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv

-AVAILABLE_ATTENTIONS = ["sdpa"]
+AVAILABLE_ATTENTIONS = []

 if torch.backends.cuda.flash_sdp_enabled():
     AVAILABLE_ATTENTIONS.append("flash")
@@ -20,6 +20,12 @@ if torch.backends.cuda.mem_efficient_sdp_enabled():
 if torch.backends.cuda.math_sdp_enabled():
     AVAILABLE_ATTENTIONS.append("math")

+if torch.backends.cuda.cudnn_sdp_enabled():
+    AVAILABLE_ATTENTIONS.append("cudnn")
+
+if AVAILABLE_ATTENTIONS:
+    AVAILABLE_ATTENTIONS.append("sdpa")
+
 try:
     from xformers.ops import LowerTriangularMask
     from xformers.ops.fmha import memory_efficient_attention
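Taken together with the flash and mem_efficient probes visible in the context lines above, the module-level detection block after this patch reads roughly as the sketch below. It is assembled from the hunks here rather than copied from the file; torch.backends.cuda.cudnn_sdp_enabled() assumes a PyTorch build recent enough to expose that query, and the xformers import guard that follows it is elided.

import torch

AVAILABLE_ATTENTIONS = []

# Probe which PyTorch SDPA kernels are enabled for this build/device.
if torch.backends.cuda.flash_sdp_enabled():
    AVAILABLE_ATTENTIONS.append("flash")

if torch.backends.cuda.mem_efficient_sdp_enabled():
    AVAILABLE_ATTENTIONS.append("mem_efficient")

if torch.backends.cuda.math_sdp_enabled():
    AVAILABLE_ATTENTIONS.append("math")

if torch.backends.cuda.cudnn_sdp_enabled():
    AVAILABLE_ATTENTIONS.append("cudnn")

# Only advertise the generic "sdpa" backend when at least one concrete kernel is enabled.
if AVAILABLE_ATTENTIONS:
    AVAILABLE_ATTENTIONS.append("sdpa")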
@@ -522,8 +522,10 @@ class Base(nn.Module):
             attention_backend = "mem_efficient"
         elif "math" in AVAILABLE_ATTENTIONS:
             attention_backend = "math"
-        else:
+        elif "sdpa" in AVAILABLE_ATTENTIONS:
+            attention_backend = "sdpa"
+        else:
             attention_backend = "eager"

         if attention_backend == "xformers":
             attention_backend = "mem_efficient"
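For readability, here is a hypothetical standalone helper that mirrors only the selection tail visible in this hunk; the real logic runs inline in class Base, and the candidates checked above "mem_efficient" (e.g. flash or xformers) are elided, so the function name and signature are illustrative only.

def resolve_attention_backend(AVAILABLE_ATTENTIONS: list[str]) -> str:
    # Candidates above "mem_efficient" are elided; see the hunk context above.
    if "mem_efficient" in AVAILABLE_ATTENTIONS:
        attention_backend = "mem_efficient"
    elif "math" in AVAILABLE_ATTENTIONS:
        attention_backend = "math"
    elif "sdpa" in AVAILABLE_ATTENTIONS:
        attention_backend = "sdpa"
    else:
        # Nothing usable was detected, so fall back to the eager implementation.
        attention_backend = "eager"

    # "xformers" is treated as an alias for the memory-efficient kernel.
    if attention_backend == "xformers":
        attention_backend = "mem_efficient"
    return attention_backend

With AVAILABLE_ATTENTIONS no longer seeded with "sdpa" unconditionally, the new elif "sdpa" branch only fires when the detection block above actually found a usable SDPA kernel; otherwise the chain ends at "eager".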