do not include SDPA attention if there are no available SDPA backends

mrq 2024-08-06 20:42:39 -05:00
parent 613024ec0d
commit 949339a3fa
2 changed files with 10 additions and 2 deletions

@@ -9,7 +9,7 @@ from transformers.cache_utils import Cache
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
-AVAILABLE_ATTENTIONS = ["sdpa"]
+AVAILABLE_ATTENTIONS = []
 
 if torch.backends.cuda.flash_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("flash")
@@ -20,6 +20,12 @@ if torch.backends.cuda.mem_efficient_sdp_enabled():
 if torch.backends.cuda.math_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("math")
 
+if torch.backends.cuda.cudnn_sdp_enabled():
+	AVAILABLE_ATTENTIONS.append("cudnn")
+
+if AVAILABLE_ATTENTIONS:
+	AVAILABLE_ATTENTIONS.append("sdpa")
+
 try:
 	from xformers.ops import LowerTriangularMask
 	from xformers.ops.fmha import memory_efficient_attention

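Note: below is a standalone sketch of the availability probe as it reads after this change; it is illustrative, not part of the commit. It mirrors the hunks above, and the hasattr guard around cudnn_sdp_enabled() is an added assumption for older PyTorch builds that may not expose that function. The second changed file, further down, adjusts the backend selection chain to match.

import torch

AVAILABLE_ATTENTIONS = []

# probe each fused SDPA backend exposed by torch.backends.cuda
if torch.backends.cuda.flash_sdp_enabled():
	AVAILABLE_ATTENTIONS.append("flash")
if torch.backends.cuda.mem_efficient_sdp_enabled():
	AVAILABLE_ATTENTIONS.append("mem_efficient")
if torch.backends.cuda.math_sdp_enabled():
	AVAILABLE_ATTENTIONS.append("math")
# assumption: guard cudnn_sdp_enabled(), which only exists on newer PyTorch builds
if hasattr(torch.backends.cuda, "cudnn_sdp_enabled") and torch.backends.cuda.cudnn_sdp_enabled():
	AVAILABLE_ATTENTIONS.append("cudnn")

# only advertise "sdpa" when at least one concrete SDPA backend is enabled
if AVAILABLE_ATTENTIONS:
	AVAILABLE_ATTENTIONS.append("sdpa")

print(AVAILABLE_ATTENTIONS)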
@@ -522,8 +522,10 @@ class Base(nn.Module):
 			attention_backend = "mem_efficient"
 		elif "math" in AVAILABLE_ATTENTIONS:
 			attention_backend = "math"
-		else:
+		elif "sdpa" in AVAILABLE_ATTENTIONS:
 			attention_backend = "sdpa"
+		else:
+			attention_backend = "eager"
 
 		if attention_backend == "xformers":
 			attention_backend = "mem_efficient"