the sooner I accept there's no FA for V100s the sooner I'll go to bed
commit 29c35528e5 (parent d636edd3a2)
@@ -19,52 +19,66 @@ try:
 except Exception as e:
 	print("Error while querying for `flash_attention_2` support", e)
 
-# Borrowed from https://github.com/turboderp/exllamav2/blob/master/exllamav2/attn.py#L32
-# Adapted to provide flash_attn_v1 support
+is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
 try:
-	import flash_attn
-	flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
-	is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
 	if not is_ampere_or_newer_gpu:
 		# Uses https://github.com/ZRayZzz/flash-attention-v100/
+		# Currently doesn't work because it's hard-coded to use a head dim of 128, will throw NaNs otherwise...
 		from flash_attn_v100 import flash_attn_func as flash_attn_v100_func
 
-	if [1, 0, 9] == flash_attn_ver:
 		AVAILABLE_ATTENTIONS.append("flash_attn")
-		from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-		from einops import rearrange
-
-		# converts the flash_attn_2 calling convention to flash_attn_1's
-		def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False, deterministic=False, *args, **kwargs):
-			batch_size, seqlen_q = q.shape[0], q.shape[1]
-			seqlen_k = k.shape[1]
-			q, k, v = [rearrange(x, 'b s ... -> (b s) ...').contiguous() for x in [q, k, v]]
-
-			cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
-			cu_seqlens_k = cu_seqlens_q
-
-			return flash_attn_unpadded_func(
-				q, k, v,
-				cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-				dropout_p, softmax_scale, causal, return_attn_probs, deterministic
+		AVAILABLE_ATTENTIONS.append("flash_attn_v100") # needed to signal to use padding
+		def flash_attn_func(q, k, v, softmax_scale=None, causal=False, *args, **kwargs):
+			return flash_attn_v100_func(
+				q,
+				k,
+				v,
+				softmax_scale,
+				causal
 			)
+	else:
+		# Borrowed from https://github.com/turboderp/exllamav2/blob/master/exllamav2/attn.py#L32
+		# Adapted to provide flash_attn_v1 support
+		import flash_attn
+		flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
 
-		has_flash_attn = True
-	elif [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
-		AVAILABLE_ATTENTIONS.append("flash_attn")
-		from flash_attn import flash_attn_func
-		has_flash_attn = True
-	elif [2, 5, 7] <= flash_attn_ver:
-		AVAILABLE_ATTENTIONS.append("flash_attn")
-		from flash_attn import flash_attn_func, flash_attn_with_kvcache
+		if flash_attn_ver <= [1, 0, 9]:
+			AVAILABLE_ATTENTIONS.append("flash_attn")
+			from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+			from einops import rearrange
 
-		signature = list(inspect.signature(flash_attn_func).parameters)
-		has_flash_attn_with_window = "window_size" in signature
-		has_flash_attn_with_softcap = "softcap" in signature
+			# converts the flash_attn_2 calling convention to flash_attn_1's
+			def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False, deterministic=False, *args, **kwargs):
+				batch_size, seqlen_q = q.shape[0], q.shape[1]
+				seqlen_k = k.shape[1]
+				q, k, v = [rearrange(x, 'b s ... -> (b s) ...').contiguous() for x in [q, k, v]]
 
-		import flash_attn_2_cuda as flash_attn_cuda
+				cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
+				cu_seqlens_k = cu_seqlens_q
 
-		has_flash_attn = True
-		has_flash_attn_with_paged = True
+				return flash_attn_unpadded_func(
+					q, k, v,
+					cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+					dropout_p, softmax_scale, causal, return_attn_probs, deterministic
+				)
+
+			has_flash_attn = True
+		elif [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
+			AVAILABLE_ATTENTIONS.append("flash_attn")
+			from flash_attn import flash_attn_func
+			has_flash_attn = True
+		elif [2, 5, 7] <= flash_attn_ver:
+			AVAILABLE_ATTENTIONS.append("flash_attn")
+			from flash_attn import flash_attn_func, flash_attn_with_kvcache
+
+			signature = list(inspect.signature(flash_attn_func).parameters)
+			has_flash_attn_with_window = "window_size" in signature
+			has_flash_attn_with_softcap = "softcap" in signature
+
+			import flash_attn_2_cuda as flash_attn_cuda
+
+			has_flash_attn = True
+			has_flash_attn_with_paged = True
 except Exception as e:
 	print("Error while querying for `flash_attn` support", e)
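Note on what the retained flash_attn v1 adapter above is doing: flash_attn 1.x only exposes an "unpadded" kernel that wants the batch flattened into one packed token dimension plus cumulative sequence lengths, while flash_attn 2.x takes plain (batch, seqlen, heads, head_dim) tensors. The sketch below mirrors that packing step for fixed-length batches; `fake_unpadded_attention` and `flash_attn_v1_style` are illustrative stand-ins (plain softmax attention instead of the real `flash_attn_unpadded_func`, no causal masking), not part of this commit.

# Illustrative stand-ins only: fake_unpadded_attention replaces flash_attn_unpadded_func
# with plain softmax attention so the sketch runs without flash-attn installed.
import torch
from einops import rearrange

def fake_unpadded_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
	# q/k/v are packed as (total_tokens, heads, head_dim); cu_seqlens marks sequence boundaries
	outs = []
	for i in range(len(cu_seqlens_q) - 1):
		s, e = int(cu_seqlens_q[i]), int(cu_seqlens_q[i + 1])
		qi, ki, vi = q[s:e], k[s:e], v[s:e]
		scores = torch.einsum("qhd,khd->hqk", qi, ki) / qi.shape[-1] ** 0.5
		outs.append(torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), vi))
	return torch.cat(outs, dim=0)

def flash_attn_v1_style(q, k, v):
	# q/k/v: (batch, seqlen, heads, head_dim), the flash_attn_2-style calling convention
	batch_size, seqlen_q = q.shape[0], q.shape[1]
	seqlen_k = k.shape[1]
	q, k, v = [rearrange(x, 'b s ... -> (b s) ...').contiguous() for x in [q, k, v]]
	# cumulative sequence lengths: [0, s, 2s, ...] for a fixed-length batch
	cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
	out = fake_unpadded_attention(q, k, v, cu_seqlens, cu_seqlens, seqlen_q, seqlen_k)
	return rearrange(out, '(b s) ... -> b s ...', b=batch_size)

q = k = v = torch.randn(2, 8, 4, 64)  # batch=2, seqlen=8, heads=4, head_dim=64
print(flash_attn_v1_style(q, k, v).shape)  # torch.Size([2, 8, 4, 64])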
@@ -401,6 +401,9 @@ class Base(nn.Module):
 
 		self.l_padding = l_padding
 
+		if "flash_attn_v100" in AVAILABLE_ATTENTIONS:
+			self.l_padding = 32
+
 		self.ignore_index = -100
 
 		self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels
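The `l_padding = 32` override only sets a padding granularity for the V100 kernel; the actual padding happens elsewhere in the model. As a rough illustration (made-up names, not the repo's padding path), rounding a batch's sequence axis up to the next multiple of `l_padding` looks like this:

# Rough illustration only: right-pad the sequence axis with zeros
# up to the next multiple of `multiple`.
import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple):
	# x: (batch, seqlen, dim)
	seqlen = x.shape[1]
	target = (seqlen + multiple - 1) // multiple * multiple
	return F.pad(x, (0, 0, 0, target - seqlen))

x = torch.randn(2, 37, 1024)
print(pad_to_multiple(x, 32).shape)  # torch.Size([2, 64, 1024])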
@@ -623,7 +626,7 @@ class Base(nn.Module):
 				attn_implementation=hf_attention,
 				#gradient_checkpointing=self.gradient_checkpointing,
 			))
-			if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
+			if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]:
				self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
 
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
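The widened backend list routes `flash_attn` through the same `MixtralAttention_Adapted` replacement as the torch SDPA modes. As a hedged sketch of what the SDPA mode names refer to in plain PyTorch (this is not `ml.replace_attention`; it assumes torch >= 2.3, and a cudnn backend only exists on newer releases):

# Sketch only: map a backend-name string to one of torch's
# scaled_dot_product_attention kernels via sdpa_kernel.
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

SDPA_BACKENDS = {
	"math": SDPBackend.MATH,
	"mem_efficient": SDPBackend.EFFICIENT_ATTENTION,
	"flash": SDPBackend.FLASH_ATTENTION,
}

def attend(q, k, v, mode="math", causal=True):
	# q/k/v: (batch, heads, seqlen, head_dim)
	with sdpa_kernel(SDPA_BACKENDS[mode]):
		return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal)

q = k = v = torch.randn(2, 4, 8, 64)
print(attend(q, k, v, mode="math").shape)  # torch.Size([2, 4, 8, 64])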