pain (an attempt to get flash attention for ROCm (gfx1100) through Triton fused attention, but no good)
parent 40e1799adc
commit 6b0891448c
@@ -34,6 +34,7 @@ try:
	AVAILABLE_ARCHES.append("llama")
except Exception as e:
	ERROR_ARCHES["llama"] = e
	AVAILABLE_ARCHES = []
	pass

try:
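For context, the hunk above uses the usual import-guard idiom: an arch only lands in AVAILABLE_ARCHES if its imports succeed, and the captured exception is stashed in ERROR_ARCHES for later. A minimal sketch of how a caller might consume those two module-level objects; the load_arch helper name is hypothetical, not from the repo:

# Illustrative only: gating model construction on AVAILABLE_ARCHES / ERROR_ARCHES.
# `load_arch` is a hypothetical helper, not part of this commit.
def load_arch(name: str):
	if name not in AVAILABLE_ARCHES:
		if name in ERROR_ARCHES:
			# surface the import error captured at module load time
			raise RuntimeError(f"arch `{name}` failed to import") from ERROR_ARCHES[name]
		raise ValueError(f"unknown arch: {name}")
	return name  # the real code would return the arch's model class here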
@@ -20,9 +20,41 @@ try:
except Exception as e:
	print("Error while querying for `flash_attention_2` support", e)

is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))

try:
	if not is_ampere_or_newer_gpu:
	if is_rocm and False:
		# try to use triton flash attention / fused attention
		# currently only forward works, backwards throws an assert
		# even then it's extremely slow on my 7900XTX so the provided code is probably botched since it's a benchmark sample
		from einops import rearrange
		from .triton_flash_attention import triton_attention, MetaData

		def flash_attn_func(q, k, v, softmax_scale=None, causal=False, *args, **kwargs):
			metadata = MetaData(sm_scale=softmax_scale)
			batch_size, seqlen_q, seqlen_k = q.shape[0], q.shape[1], k.shape[1]

			metadata.max_seqlens_q = seqlen_q
			metadata.max_seqlens_k = seqlen_k

			# varlen but doesn't seem necessary
			if False:
				q, k, v = [rearrange(x, 'b s ... -> (b s) ...').contiguous() for x in [q, k, v]]

				cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
				cu_seqlens_k = cu_seqlens_q

				metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)

			if causal:
				metadata.need_causal()

			return triton_attention( q, k, v, None, metadata )[0]

		AVAILABLE_ATTENTIONS.append("flash_attn")
		AVAILABLE_ATTENTIONS.append("flash_attn_rocm")
	elif not is_ampere_or_newer_gpu:
		# Uses https://github.com/ZRayZzz/flash-attention-v100/
		# Currently doesn't work because it's hard-coded to use a head dim of 128, will throw NaNs otherwise...
		from flash_attn_v100 import flash_attn_func as flash_attn_v100_func
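For reference, a minimal sketch of how the Triton-backed flash_attn_func shim above would be exercised. It assumes the usual flash-attn tensor layout of (batch, seqlen, heads, head_dim) and half-precision tensors on the GPU; the shapes, dtype, and explicit softmax_scale are illustrative, not taken from the repo:

# Illustrative only: calling the flash_attn_func shim defined above on a ROCm/CUDA device.
# Assumes the (batch, seqlen, heads, head_dim) layout that flash-attn uses.
import torch

b, s, h, d = 2, 1024, 16, 64
q = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
k = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
v = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)

# flash-attn proper defaults softmax_scale to 1/sqrt(head_dim); the shim expects it passed in
out = flash_attn_func(q, k, v, softmax_scale=d ** -0.5, causal=True)
print(out.shape)  # expected: (b, s, h, d)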
@@ -81,7 +113,7 @@ try:
	has_flash_attn = True
	has_flash_attn_with_paged = True
except Exception as e:
	print("Error while querying for `flash_attn` | support", e)
	print("Error while querying for `flash_attn` support", e)

try:
	from xformers.ops.fmha import memory_efficient_attention
@@ -91,8 +123,9 @@ try:
except Exception as e:
	print("Error while importing `xformers`", e)

# to-do: find a better way to query for if there's available kernels since these return true regardless
if torch.backends.cuda.flash_sdp_enabled():
	AVAILABLE_ATTENTIONS.append("flash")
	AVAILABLE_ATTENTIONS.append("flash")

if torch.backends.cuda.mem_efficient_sdp_enabled():
	AVAILABLE_ATTENTIONS.append("mem_efficient")
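On the to-do above: flags like torch.backends.cuda.flash_sdp_enabled() only report whether the dispatcher is allowed to use that kernel, not whether one actually exists for the current device and dtype. One rough way to probe it at runtime (an assumption on my part, not code from this commit) is to restrict the dispatcher to a single backend and attempt a tiny scaled_dot_product_attention call:

# Illustrative sketch: probe whether a flash SDPA kernel actually runs on this device,
# instead of trusting the enable flags alone. Not part of the commit.
import torch
import torch.nn.functional as F

def can_run_flash_sdpa() -> bool:
	if not torch.cuda.is_available():
		return False
	try:
		# SDPA expects (batch, heads, seqlen, head_dim); fp16 with head_dim 64 suits the flash kernel
		q = torch.randn(1, 1, 8, 64, device="cuda", dtype=torch.float16)
		# allow only the flash kernel; math / mem-efficient fallbacks disabled
		with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
			F.scaled_dot_product_attention(q, q, q)
		return True
	except RuntimeError:
		return False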
@@ -514,8 +514,6 @@ class Base(nn.Module):
		# experimental NAR-only mode
		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None

		# there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks

		if attention_backend == "auto":
			if AVAILABLE_ATTENTIONS:
				attention_backend = AVAILABLE_ATTENTIONS[0]
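For context on the "auto" path above: AVAILABLE_ATTENTIONS is populated in preference order at import time, so "auto" simply takes the first backend that survived its import/probe checks. A hedged sketch of that selection with an explicit fallback for the empty case; the "default" name and the helper itself are placeholders, not the repo's actual identifiers:

# Illustrative only: resolving an "auto" attention backend with a fallback,
# mirroring the selection in the hunk above.
def resolve_attention_backend(requested: str, available: list[str]) -> str:
	if requested != "auto":
		return requested
	# first entry wins; callers fill `available` in preference order at import time
	return available[0] if available else "default"

# e.g. resolve_attention_backend("auto", AVAILABLE_ATTENTIONS)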