sageattn (forgot to bother with testing this the other day, seems fine)
This commit is contained in:
parent 31ab90d84a
commit ca31da0a95

@@ -334,6 +334,8 @@ A bulk of it pertains to modifying `LlamaAttention` and detecting available atte
 * `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
 * `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
 * `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
+* `sageattn`: uses [SageAttention](https://github.com/thu-ml/SageAttention).
+  * training under this is untested, but dropout is not applied (yet).
 * `default`: uses the naive path for the internal implementation (used for attention-debugging purposes)
 * `transformers` Llama\*Attention implementations:
   * `eager`: default `LlamaAttention`
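
For a concrete sense of how these mode names are consumed, below is a minimal, hypothetical sketch that picks the first usable backend from a preference order. The `AVAILABLE_ATTENTIONS` contents and the `prefer` ordering are illustrative assumptions; the real list is populated by the import probes added further down in this commit.

```python
# Hypothetical helper: choose the first preferred attention mode that was
# detected at import time. The example list is illustrative only; the real
# AVAILABLE_ATTENTIONS is filled in by try/except probes like the `sageattn`
# one added in this commit.
AVAILABLE_ATTENTIONS = ["sageattn", "eager"]

def pick_attention(prefer=("flash_attn", "sageattn", "fused_attn", "eager")):
	for name in prefer:
		if name in AVAILABLE_ATTENTIONS:
			return name
	return "default"  # naive internal path, mainly useful for attention debugging

print(pick_attention())  # -> "sageattn" with the example list above
```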

setup.py (+1)
@@ -86,6 +86,7 @@ setup(
 
 	# attention helpers
 	"xformers",
+	"sageattention==1.0.6",
 	# "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it
 
 	# other audio backends that didn't prove fruitful
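
Since the commit pins `sageattention==1.0.6` in `setup.py` while attention support is still probed at runtime, a small sanity check like the following can confirm the installed package matches the pin. This is a hedged sketch using only the standard library; the expected version string is simply the pin above.

```python
# Minimal sketch: check the installed sageattention version against the
# "sageattention==1.0.6" pin declared in setup.py. Illustrative only.
from importlib.metadata import PackageNotFoundError, version

try:
	installed = version("sageattention")
	if installed != "1.0.6":
		print(f"warning: sageattention {installed} installed, 1.0.6 expected")
except PackageNotFoundError:
	print("sageattention not installed; the `sageattn` attention mode will be unavailable")
```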

@@ -22,6 +22,13 @@ AVAILABLE_ATTENTIONS = []
 
 LN_2 = 0.69314718056
 
+try:
+	from sageattention import sageattn
+
+	AVAILABLE_ATTENTIONS.append("sageattn")
+except Exception as e:
+	_logger.warning(f"Error while querying for `sageattn` support: {str(e)}")
+
 try:
 	from torch.nn.attention.flex_attention import flex_attention, create_block_mask

@@ -390,11 +397,15 @@ class LlamaAttention_Adapted(LlamaAttention):
 		key_states = key_states.contiguous()
 		value_states = value_states.contiguous()
 
-		# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-		# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-		# is_causal = True if x_mask is None and q_len > 1 else False
-		if mode in ["fused_attn"]:
+		if mode in ["sageattn"]:
+			attn_output = sageattn(
+				query_states,
+				key_states,
+				value_states,
+				tensor_layout="HND",
+				is_causal=is_causal
+			)
+		elif mode in ["fused_attn"]:
 			attn_output = fused_attn_func(
 				query_states,
 				key_states,
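
As a standalone point of reference for the dispatch above, here is a toy call to `sageattn` outside the model. The shapes assume the `"HND"` layout means `(batch, heads, seq_len, head_dim)`, and running it requires a CUDA GPU supported by SageAttention; treat it as a sketch, not part of the commit.

```python
# Toy example mirroring the sageattn call in the dispatch above.
# Assumption: with tensor_layout="HND" the tensors are laid out as
# (batch, heads, seq_len, head_dim). Requires a supported CUDA device.
import torch
from sageattention import sageattn

q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")

out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
print(out.shape)  # expected: torch.Size([1, 8, 128, 64])
```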

@@ -418,6 +429,9 @@ class LlamaAttention_Adapted(LlamaAttention):
 				f" {attn_output.size()}"
 			)
 		else:
+			# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+			# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+			# is_causal = True if x_mask is None and q_len > 1 else False
 			is_causal = True if x_mask is None and q_len > 1 else False
 			with torch.nn.attention.sdpa_kernel(self.mode):
 				attn_output = torch.nn.functional.scaled_dot_product_attention(
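
For comparison with the SDPA fallback branch above, a self-contained sketch of the same call outside the class follows. The backend choice and tensor shapes here are illustrative assumptions; the repo passes its own `self.mode` to `sdpa_kernel`.

```python
# Minimal sketch of the SDPA fallback path. Backend and shapes are
# illustrative; the adapted attention class uses its own `self.mode`.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

with sdpa_kernel(SDPBackend.MATH):  # force the math backend for portability
	out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])
```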