diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 5dff62d..f8712d1 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -20,9 +20,41 @@ try:
 except Exception as e:
 	print("Error while querying for `flash_attention_2` support", e)
 
+is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
 is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
+
 try:
-	if not is_ampere_or_newer_gpu:
+	if is_rocm and False:
+		# try to use triton flash attention / fused attention
+		# currently only the forward pass works, the backward pass throws an assert
+		# even then, it's extremely slow on my 7900XTX, so the provided code is probably botched, since it's a benchmark sample
+		from einops import rearrange
+		from .triton_flash_attention import triton_attention, MetaData
+
+		def flash_attn_func(q, k, v, softmax_scale=None, causal=False, *args, **kwargs):
+			metadata = MetaData(sm_scale=softmax_scale)
+			batch_size, seqlen_q, seqlen_k = q.shape[0], q.shape[1], k.shape[1]
+
+			metadata.max_seqlens_q = seqlen_q
+			metadata.max_seqlens_k = seqlen_k
+
+			# varlen, but it doesn't seem necessary
+			if False:
+				q, k, v = [rearrange(x, 'b s ... -> (b s) ...').contiguous() for x in [q, k, v]]
+
+				cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
+				cu_seqlens_k = cu_seqlens_q
+
+				metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
+
+			if causal:
+				metadata.need_causal()
+
+			return triton_attention( q, k, v, None, metadata )[0]
+
+		AVAILABLE_ATTENTIONS.append("flash_attn")
+		AVAILABLE_ATTENTIONS.append("flash_attn_rocm")
+	elif not is_ampere_or_newer_gpu:
 		# Uses https://github.com/ZRayZzz/flash-attention-v100/
 		# Currently doesn't work because it's hard-coded to use a head dim of 128, will throw NaNs otherwise...
 		from flash_attn_v100 import flash_attn_func as flash_attn_v100_func
@@ -81,7 +113,7 @@
 	has_flash_attn = True
 	has_flash_attn_with_paged = True
 except Exception as e:
-	print("Error while querying for `flash_attn` | support", e)
+	print("Error while querying for `flash_attn` support", e)
 
 try:
 	from xformers.ops.fmha import memory_efficient_attention
@@ -91,8 +123,9 @@
 except Exception as e:
 	print("Error while importing `xformers`", e)
 
+# to-do: find a better way to query for available kernels, since these return True regardless of whether a kernel can actually run
 if torch.backends.cuda.flash_sdp_enabled():
-	AVAILABLE_ATTENTIONS.append("flash")
+	AVAILABLE_ATTENTIONS.append("flash")
 
 if torch.backends.cuda.mem_efficient_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("mem_efficient")
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 7eb70fc..ed96d54 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -514,8 +514,6 @@ class Base(nn.Module):
 		# experimental NAR-only mode
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
 
-		# there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
-
 		if attention_backend == "auto":
 			if AVAILABLE_ATTENTIONS:
 				attention_backend = AVAILABLE_ATTENTIONS[0]
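
Note on the `is_rocm` probe: matching "AMD" in the device name works, but depends on marketing names. A minimal alternative sketch, assuming a standard ROCm build of PyTorch, where `torch.version.hip` carries the HIP version string and is `None` on CUDA builds:

import torch

# on ROCm builds of PyTorch, torch.version.hip is a version string;
# on CUDA builds it is None, so this identifies the stack without
# inspecting get_device_properties().name at all
is_rocm = torch.version.hip is not None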
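
For reference, a hypothetical smoke test for the `flash_attn_func` shim added above, assuming the usual `flash_attn` calling convention of `(batch, seqlen, nheads, headdim)` half-precision tensors on the GPU; per the comments in the patch, only the forward pass is expected to work:

import torch

# shapes follow the shim's convention: seqlen is read from shape[1]
q = torch.randn(1, 128, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# forward only; the backward pass currently trips an assert in the triton kernel
with torch.no_grad():
	out = flash_attn_func(q, k, v, softmax_scale=q.shape[-1] ** -0.5, causal=True)

assert out.shape == q.shape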
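
On the `to-do` about the SDPA queries: `flash_sdp_enabled()` / `mem_efficient_sdp_enabled()` only report whether the backend toggle is on, not whether a kernel can actually execute on this device. One possible probe, sketched against the `torch.nn.attention` API (PyTorch 2.3+); `probe_sdpa_backend` is a hypothetical helper, not part of this patch:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def probe_sdpa_backend(backend) -> bool:
	# pin SDPA to a single backend and dispatch a tiny fp16 attention;
	# if no kernel is actually usable, PyTorch raises instead of
	# silently falling back to the math path
	q = torch.zeros(1, 1, 8, 64, device="cuda", dtype=torch.float16)
	try:
		with sdpa_kernel(backend):
			torch.nn.functional.scaled_dot_product_attention(q, q, q)
		return True
	except Exception:
		return False

if probe_sdpa_backend(SDPBackend.FLASH_ATTENTION):
	AVAILABLE_ATTENTIONS.append("flash")
if probe_sdpa_backend(SDPBackend.EFFICIENT_ATTENTION):
	AVAILABLE_ATTENTIONS.append("mem_efficient")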