sageattn (forgot to bother with testing this the other day, seems fine)
This commit is contained in:
parent 31ab90d84a
commit ca31da0a95
@@ -334,6 +334,8 @@ A bulk of it pertains to modifying `LlamaAttention` and detecting available atte
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
* `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
* `fused_attn`: uses a `triton`-based implementation (tested on my 7900XTX and V100s), but seems to introduce errors after training with it for a while
* `sageattn`: uses [SageAttention](https://github.com/thu-ml/SageAttention) (see the usage sketch after this list).
  * training under this is untested, but dropout is not applied (yet).
* `default`: uses the naive path for the internal implementation (used for attention-debugging purposes)
* `transformers` Llama\*Attention implementations:
  * `eager`: default `LlamaAttention`
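For the new `sageattn` backend, here is a minimal usage sketch of the call it wraps. The shapes, dtype, and device are assumptions for illustration; the `tensor_layout="HND"` and `is_causal` arguments mirror the dispatch code further down in this commit.

```python
import torch
from sageattention import sageattn

# "HND" layout: (batch, heads, seq_len, head_dim); SageAttention expects fp16/bf16 CUDA tensors.
q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")

out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
print(out.shape)  # same (batch, heads, seq_len, head_dim) shape back
```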
setup.py
@@ -86,6 +86,7 @@ setup(
    # attention helpers
    "xformers",
    "sageattention==1.0.6",
    # "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it

    # other audio backends that didn't prove fruitful
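Since the `flash-attn` dependency stays commented out until it can be queried for Volta users, a hedged sketch of such a query is below. The helper name and the sm_80 (Ampere) cutoff are assumptions, not something this commit implements.

```python
import torch

def flash_attn_supported() -> bool:
    """Hypothetical probe: flash-attn 2.x is assumed to need Ampere (sm_80) or newer."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 0)  # Volta (7, 0) falls back to other backends

print("flash-attn usable:", flash_attn_supported())
```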
@@ -22,6 +22,13 @@ AVAILABLE_ATTENTIONS = []
LN_2 = 0.69314718056 # natural log of 2

try:
    from sageattention import sageattn

    AVAILABLE_ATTENTIONS.append("sageattn")
except Exception as e:
    _logger.warning(f"Error while querying for `sageattn` support: {str(e)}")

try:
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
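Downstream code can consume the probe results roughly like this. The helper name, the preference order, and the fallback are illustrative assumptions, not code from this repo.

```python
# Hypothetical selector over the AVAILABLE_ATTENTIONS list populated by the probes above.
AVAILABLE_ATTENTIONS = ["sageattn", "sdpa"]  # e.g. what the probes might have found

def pick_attention(preference=("sageattn", "flash_attn", "fused_attn", "sdpa")) -> str:
    for name in preference:
        if name in AVAILABLE_ATTENTIONS:
            return name
    return "default"  # naive internal implementation as the last resort

print(pick_attention())  # -> "sageattn"
```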
@@ -390,11 +397,15 @@ class LlamaAttention_Adapted(LlamaAttention):
key_states = key_states.contiguous()
value_states = value_states.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# is_causal = True if x_mask is None and q_len > 1 else False

if mode in ["fused_attn"]:
if mode in ["sageattn"]:
    attn_output = sageattn(
        query_states,
        key_states,
        value_states,
        tensor_layout="HND",
        is_causal=is_causal
    )
elif mode in ["fused_attn"]:
    attn_output = fused_attn_func(
        query_states,
        key_states,
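The `sageattn` branch above is used as a drop-in for scaled-dot-product attention on the same `(batch, heads, seq_len, head_dim)` tensors. A rough comparison sketch follows; the shapes and dtype are assumptions, and since SageAttention uses quantized kernels internally, only approximate agreement with the reference is expected.

```python
import torch
import torch.nn.functional as F
from sageattention import sageattn

# Illustrative shapes only: (batch, heads, seq_len, head_dim), i.e. the "HND" layout.
q = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)

# Expect approximate agreement, not bit-exactness.
print("max abs diff:", (out - ref).abs().max().item())
```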
@@ -418,6 +429,9 @@ class LlamaAttention_Adapted(LlamaAttention):
        f" {attn_output.size()}"
    )
else:
    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # is_causal = True if x_mask is None and q_len > 1 else False
    is_causal = True if x_mask is None and q_len > 1 else False
    with torch.nn.attention.sdpa_kernel(self.mode):
        attn_output = torch.nn.functional.scaled_dot_product_attention(
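For that fallback path, a minimal sketch of pinning SDPA to specific kernels with `sdpa_kernel` is shown below. The backend list, shapes, and dtype are assumptions (the adapted class passes `self.mode` instead), and the list form of `sdpa_kernel` assumes a reasonably recent torch.

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")

# Restrict SDPA to the flash / memory-efficient kernels; is_causal=True stands in
# for the `x_mask is None and q_len > 1` condition in the code above.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)
```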