do not include SDPA attention if there are no available SDPA backends
This commit is contained in:
parent 613024ec0d
commit 949339a3fa
@@ -9,7 +9,7 @@ from transformers.cache_utils import Cache
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
-AVAILABLE_ATTENTIONS = ["sdpa"]
+AVAILABLE_ATTENTIONS = []
 
 if torch.backends.cuda.flash_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("flash")
@@ -20,6 +20,12 @@ if torch.backends.cuda.mem_efficient_sdp_enabled():
 if torch.backends.cuda.math_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("math")
 
+if torch.backends.cuda.cudnn_sdp_enabled():
+	AVAILABLE_ATTENTIONS.append("cudnn")
+
+if AVAILABLE_ATTENTIONS:
+	AVAILABLE_ATTENTIONS.append("sdpa")
+
 try:
 	from xformers.ops import LowerTriangularMask
 	from xformers.ops.fmha import memory_efficient_attention
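After this hunk, AVAILABLE_ATTENTIONS starts empty and the generic "sdpa" entry is only added once at least one concrete SDPA kernel (flash, mem_efficient, math, or the newly probed cudnn) is enabled. Below is a minimal, standalone sketch of that probe, not the repository's module: the "mem_efficient" append is inferred from the hunk header above, and the hasattr guard is an added assumption so the sketch also runs on PyTorch builds that predate torch.backends.cuda.cudnn_sdp_enabled().

# Standalone sketch of the availability probe built up in this commit.
import torch

AVAILABLE_ATTENTIONS = []

if torch.backends.cuda.flash_sdp_enabled():
    AVAILABLE_ATTENTIONS.append("flash")

if torch.backends.cuda.mem_efficient_sdp_enabled():
    AVAILABLE_ATTENTIONS.append("mem_efficient")  # inferred from the hunk header context

if torch.backends.cuda.math_sdp_enabled():
    AVAILABLE_ATTENTIONS.append("math")

# hasattr guard is an assumption: cudnn_sdp_enabled() only exists on newer PyTorch releases
if hasattr(torch.backends.cuda, "cudnn_sdp_enabled") and torch.backends.cuda.cudnn_sdp_enabled():
    AVAILABLE_ATTENTIONS.append("cudnn")

# the point of this commit: advertise "sdpa" only when some SDPA kernel is actually enabled
if AVAILABLE_ATTENTIONS:
    AVAILABLE_ATTENTIONS.append("sdpa")

print(AVAILABLE_ATTENTIONS)  # e.g. ['flash', 'mem_efficient', 'math', 'sdpa'], or [] if nothing is enabled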
@@ -522,8 +522,10 @@ class Base(nn.Module):
 	attention_backend = "mem_efficient"
 elif "math" in AVAILABLE_ATTENTIONS:
 	attention_backend = "math"
-else:
+elif "sdpa" in AVAILABLE_ATTENTIONS:
 	attention_backend = "sdpa"
+else:
+	attention_backend = "eager"
 
 if attention_backend == "xformers":
 	attention_backend = "mem_efficient"
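The selection chain now only picks "sdpa" when it was actually advertised, and otherwise falls back to "eager". The hypothetical helper below sketches that chain under stated assumptions: the function name is made up for illustration, and the branches upstream of "mem_efficient" in the real method are not visible in this hunk, so only the visible part is reproduced.

# Hypothetical helper mirroring the selection chain in the hunk above; only the
# branches visible in the diff are reproduced, everything upstream of
# "mem_efficient" in the real method is omitted here.
def pick_attention_backend(available: list[str]) -> str:
    if "mem_efficient" in available:
        return "mem_efficient"
    elif "math" in available:
        return "math"
    elif "sdpa" in available:
        return "sdpa"
    else:
        # new in this commit: with no SDPA kernel available, fall back to eager attention
        return "eager"

assert pick_attention_backend([]) == "eager"               # before this commit this would have been "sdpa"
assert pick_attention_backend(["math", "sdpa"]) == "math"  # specific kernels still win over the generic entry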