sageattn (forgot to bother with testing this the other day, seems fine)

mrq 2024-12-03 15:14:57 -06:00
parent 31ab90d84a
commit ca31da0a95
3 changed files with 22 additions and 5 deletions

@ -334,6 +334,8 @@ A bulk of it pertains to modifying `LlamaAttention` and detecting available atte
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
* `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
* `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
+* `sageattn`: uses [SageAttention](https://github.com/thu-ml/SageAttention).
+* training under this is untested, and dropout is not applied (yet); a call sketch follows this file's diff.
* `default`: uses the naive path for the internal implementation (used for attention-debugging purposes)
* `transformers` Llama\*Attention implementations:
* `eager`: default `LlamaAttention`
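
For context, here is a minimal sketch of the call shape this commit adopts for the new backend. The (batch, heads, seq_len, head_dim) reading of `tensor_layout="HND"`, the fp16 dtype, and the CUDA device are assumptions inferred from the call site later in this commit, not something documented in this repo:

```python
import torch
from sageattention import sageattn

# dummy tensors; (batch, heads, seq_len, head_dim) is the assumed meaning of "HND",
# and fp16-on-CUDA is an assumed requirement rather than something stated here
q = torch.randn(1, 16, 1024, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 16, 1024, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 1024, 64, dtype=torch.float16, device="cuda")

# mirrors the call added to LlamaAttention_Adapted below
out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
print(out.shape)  # expected to match q's shape
```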

@ -86,6 +86,7 @@ setup(
# attention helpers
"xformers",
+"sageattention==1.0.6",
# "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it
# other audio backend that doesn't prove fruitful
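
As a quick, standard-library-only way to confirm the pinned wheel is what actually ended up installed (nothing here is specific to this repo):

```python
from importlib.metadata import PackageNotFoundError, version

try:
    # should print 1.0.6 if the pin above took effect
    print("sageattention", version("sageattention"))
except PackageNotFoundError:
    print("sageattention is not installed")
```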

@ -22,6 +22,13 @@ AVAILABLE_ATTENTIONS = []
LN_2 = 0.69314718056

+try:
+    from sageattention import sageattn
+    AVAILABLE_ATTENTIONS.append("sageattn")
+except Exception as e:
+    _logger.warning(f"Error while querying for `sageattn` support: {str(e)}")

try:
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
@ -390,11 +397,15 @@ class LlamaAttention_Adapted(LlamaAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()

-# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-# is_causal = True if x_mask is None and q_len > 1 else False
-if mode in ["fused_attn"]:
+if mode in ["sageattn"]:
+    attn_output = sageattn(
+        query_states,
+        key_states,
+        value_states,
+        tensor_layout="HND",
+        is_causal=is_causal
+    )
+elif mode in ["fused_attn"]:
    attn_output = fused_attn_func(
        query_states,
        key_states,
@ -418,6 +429,9 @@ class LlamaAttention_Adapted(LlamaAttention):
f" {attn_output.size()}" f" {attn_output.size()}"
) )
else: else:
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# is_causal = True if x_mask is None and q_len > 1 else False
is_causal = True if x_mask is None and q_len > 1 else False is_causal = True if x_mask is None and q_len > 1 else False
with torch.nn.attention.sdpa_kernel(self.mode): with torch.nn.attention.sdpa_kernel(self.mode):
attn_output = torch.nn.functional.scaled_dot_product_attention( attn_output = torch.nn.functional.scaled_dot_product_attention(
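
Taken together, the changes amount to "import it if you can, dispatch to it when selected." Below is a self-contained sketch of that pattern; the `attend` helper is hypothetical (not part of this codebase), and the fallback is plain SDPA like the `else` branch above:

```python
import torch
import torch.nn.functional as F

try:
    from sageattention import sageattn
    HAS_SAGEATTN = True
except Exception:
    HAS_SAGEATTN = False

def attend(q, k, v, is_causal=True):
    # hypothetical dispatcher mirroring the commit's if/elif/else chain
    if HAS_SAGEATTN:
        # assumed: q/k/v are half-precision CUDA tensors in (batch, heads, seq_len, head_dim) layout
        return sageattn(q, k, v, tensor_layout="HND", is_causal=is_causal)
    # naive fallback via PyTorch's built-in scaled dot-product attention
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
```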