From ca31da0a9589721627e44156af896418550a3a83 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 3 Dec 2024 15:14:57 -0600
Subject: [PATCH] sageattn (forgot to bother with testing this the other day, seems fine)

---
 docs/models.md              |  2 ++
 setup.py                    |  1 +
 vall_e/models/arch/llama.py | 24 +++++++++++++++++++-----
 3 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/docs/models.md b/docs/models.md
index 9d87cd0..093d565 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -334,6 +334,8 @@ A bulk of it pertains to modifying `LlamaAttention` and detecting available atte
 * `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
 * `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
 * `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
+* `sageattn`: uses [SageAttention](https://github.com/thu-ml/SageAttention).
+  * training under this is untested, but dropout is not applied (yet).
 * `default`: uses the naive path for hte internal implementation (used for attention-debugging purposed)
 * `transformers` Llama\*Attention implementations:
   * `eager`: default `LlamaAttention`
diff --git a/setup.py b/setup.py
index 4190389..21d8fb4 100755
--- a/setup.py
+++ b/setup.py
@@ -86,6 +86,7 @@ setup(
 		# attention helpers
 		"xformers",
+		"sageattention==1.0.6",
 		# "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it
 		# other audio backend that doesn't prove fruitful
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 208c1de..e35981d 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -22,6 +22,13 @@ AVAILABLE_ATTENTIONS = []
 
 LN_2 = 0.69314718056
 
+try:
+	from sageattention import sageattn
+
+	AVAILABLE_ATTENTIONS.append("sageattn")
+except Exception as e:
+	_logger.warning(f"Error while querying for `sageattn` support: {str(e)}")
+
 try:
 	from torch.nn.attention.flex_attention import flex_attention, create_block_mask
 
@@ -390,11 +397,15 @@ class LlamaAttention_Adapted(LlamaAttention):
 			key_states = key_states.contiguous()
 			value_states = value_states.contiguous()
-		# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-		# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-		# is_causal = True if x_mask is None and q_len > 1 else False
-
-		if mode in ["fused_attn"]:
+		if mode in ["sageattn"]:
+			attn_output = sageattn(
+				query_states,
+				key_states,
+				value_states,
+				tensor_layout="HND",
+				is_causal=is_causal
+			)
+		elif mode in ["fused_attn"]:
 			attn_output = fused_attn_func(
 				query_states,
 				key_states,
 				value_states,
@@ -418,6 +429,9 @@
 				f" {attn_output.size()}"
 			)
 		else:
+			# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+			# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+			# is_causal = True if x_mask is None and q_len > 1 else False
 			is_causal = True if x_mask is None and q_len > 1 else False
 			with torch.nn.attention.sdpa_kernel(self.mode):
 				attn_output = torch.nn.functional.scaled_dot_product_attention(
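Note (not part of the patch): a minimal sketch of how the new `sageattn` branch is meant to slot in beside the existing SDPA fallback. It assumes `sageattention` exposes `sageattn(q, k, v, tensor_layout=..., is_causal=...)` exactly as called in the diff above, that "HND" means the same (batch, heads, seq_len, head_dim) layout that `scaled_dot_product_attention` consumes here, and that the kernel takes no dropout argument (hence the "dropout is not applied (yet)" note in docs/models.md). The `attend` helper, the tensor shapes, and the comparison against SDPA are illustrative only.

import torch

# availability probe mirrors the try/except pattern added to vall_e/models/arch/llama.py
try:
	from sageattention import sageattn
	HAS_SAGEATTN = True
except Exception:
	HAS_SAGEATTN = False

def attend(q, k, v, is_causal=True):
	# q, k, v are (batch, heads, seq_len, head_dim); "HND" in the patch is assumed
	# to name this layout, which is also what scaled_dot_product_attention expects.
	if HAS_SAGEATTN and q.is_cuda:
		# no dropout_p here, matching the doc note that dropout is not applied (yet)
		return sageattn(q, k, v, tensor_layout="HND", is_causal=is_causal)
	return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

if __name__ == "__main__":
	device = "cuda" if torch.cuda.is_available() else "cpu"
	dtype = torch.float16 if device == "cuda" else torch.float32
	q, k, v = (torch.randn(1, 8, 128, 64, device=device, dtype=dtype) for _ in range(3))
	out = attend(q, k, v)
	ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
	# SageAttention quantizes Q/K internally, so expect a small numerical gap on CUDA
	print(out.shape, (out - ref).abs().max().item())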