From 949339a3facfbe6d45ca1a1e4b0d9da20a1e5c31 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 6 Aug 2024 20:42:39 -0500
Subject: [PATCH] do not include SDPA attention if there are no available SDPA backends

---
 vall_e/models/arch/llama.py | 8 +++++++-
 vall_e/models/base.py       | 4 +++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 2dc72aa..b2ee4e0 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -9,7 +9,7 @@ from transformers.cache_utils import Cache
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
-AVAILABLE_ATTENTIONS = ["sdpa"]
+AVAILABLE_ATTENTIONS = []
 
 if torch.backends.cuda.flash_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("flash")
@@ -20,6 +20,12 @@ if torch.backends.cuda.mem_efficient_sdp_enabled():
 if torch.backends.cuda.math_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("math")
 
+if torch.backends.cuda.cudnn_sdp_enabled():
+	AVAILABLE_ATTENTIONS.append("cudnn")
+
+if AVAILABLE_ATTENTIONS:
+	AVAILABLE_ATTENTIONS.append("sdpa")
+
 try:
 	from xformers.ops import LowerTriangularMask
 	from xformers.ops.fmha import memory_efficient_attention
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 74b07b9..89404c3 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -522,8 +522,10 @@ class Base(nn.Module):
 			attention_backend = "mem_efficient"
 		elif "math" in AVAILABLE_ATTENTIONS:
 			attention_backend = "math"
-		else:
+		elif "sdpa" in AVAILABLE_ATTENTIONS:
 			attention_backend = "sdpa"
+		else:
+			attention_backend = "eager"
 
 		if attention_backend == "xformers":
 			attention_backend = "mem_efficient"
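
Note: for reference, below is a standalone sketch (not part of the patch) of the availability check this change introduces: the generic "sdpa" entry is only advertised when at least one concrete SDPA kernel is enabled, and the caller falls back to "eager" otherwise. The cudnn_sdp_enabled() probe only exists on newer PyTorch builds, so the sketch guards it with getattr; the backend preference order below is an approximation of the selection logic in vall_e/models/base.py, not a copy of it.

    import torch

    AVAILABLE_ATTENTIONS = []

    # Each *_sdp_enabled() flag reports whether that SDPA kernel is allowed.
    if torch.backends.cuda.flash_sdp_enabled():
        AVAILABLE_ATTENTIONS.append("flash")

    if torch.backends.cuda.mem_efficient_sdp_enabled():
        AVAILABLE_ATTENTIONS.append("mem_efficient")

    if torch.backends.cuda.math_sdp_enabled():
        AVAILABLE_ATTENTIONS.append("math")

    # cudnn_sdp_enabled() is only present on newer PyTorch releases.
    if getattr(torch.backends.cuda, "cudnn_sdp_enabled", lambda: False)():
        AVAILABLE_ATTENTIONS.append("cudnn")

    # Only advertise the generic "sdpa" backend if some kernel can serve it.
    if AVAILABLE_ATTENTIONS:
        AVAILABLE_ATTENTIONS.append("sdpa")

    # Approximation of the selection in base.py: prefer a concrete kernel,
    # accept generic "sdpa", otherwise fall back to eager attention.
    attention_backend = "eager"
    for candidate in ("flash", "mem_efficient", "math", "sdpa"):
        if candidate in AVAILABLE_ATTENTIONS:
            attention_backend = candidate
            break

    print(AVAILABLE_ATTENTIONS, attention_backend)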