the sooner I accept there's no FA for V100s the sooner I'll go to bed

mrq 2024-08-18 23:54:33 -05:00
parent d636edd3a2
commit 29c35528e5
2 changed files with 55 additions and 38 deletions


@@ -19,14 +19,30 @@ try:
 except Exception as e:
 	print("Error while querying for `flash_attention_2` support", e)
-# Borrowed from https://github.com/turboderp/exllamav2/blob/master/exllamav2/attn.py#L32
-# Adapted to provide flash_attn_v1 support
+is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
 try:
+	if not is_ampere_or_newer_gpu:
+		# Uses https://github.com/ZRayZzz/flash-attention-v100/
+		# Currently doesn't work because it's hard-coded to use a head dim of 128, will throw NaNs otherwise...
+		from flash_attn_v100 import flash_attn_func as flash_attn_v100_func
+		AVAILABLE_ATTENTIONS.append("flash_attn")
+		AVAILABLE_ATTENTIONS.append("flash_attn_v100") # needed to signal to use padding
+		def flash_attn_func(q, k, v, softmax_scale=None, causal=False, *args, **kwargs):
+			return flash_attn_v100_func(
+				q,
+				k,
+				v,
+				softmax_scale,
+				causal
+			)
+	else:
+		# Borrowed from https://github.com/turboderp/exllamav2/blob/master/exllamav2/attn.py#L32
+		# Adapted to provide flash_attn_v1 support
+		import flash_attn
+		flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
-	is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
-	if [1, 0, 9] == flash_attn_ver:
+		if flash_attn_ver <= [1, 0, 9]:
+			AVAILABLE_ATTENTIONS.append("flash_attn")
+			from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+			from einops import rearrange
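
The hunk ends right after the einops import; on flash_attn 1.x the unpadded interface still has to be wrapped to look like the newer flash_attn_func. A minimal sketch of such a shim, assuming fixed-length [batch, seqlen, heads, head_dim] inputs with synthesized cu_seqlens — this is a guess at the shape of the code that follows in the file, not a copy of it:

# Sketch only: adapt flash_attn v1's flash_attn_unpadded_func to the
# flash_attn_func(q, k, v, ...) calling convention used elsewhere.
import torch
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
	# q, k, v: [batch, seqlen, heads, head_dim]
	batch_size, seqlen_q = q.shape[0], q.shape[1]
	seqlen_k = k.shape[1]
	q, k, v = [rearrange(x, "b s h d -> (b s) h d") for x in (q, k, v)]
	# every sequence in the batch has the same length, so the cumulative
	# sequence-length tensors are just evenly spaced offsets
	cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device)
	cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k.device)
	out = flash_attn_unpadded_func(
		q, k, v,
		cu_seqlens_q, cu_seqlens_k,
		seqlen_q, seqlen_k,
		dropout_p, softmax_scale=softmax_scale, causal=causal,
	)
	return rearrange(out, "(b s) h d -> b s h d", b=batch_size)
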
@@ -63,8 +79,6 @@ try:
 		has_flash_attn = True
 		has_flash_attn_with_paged = True
 except Exception as e:
 	print("Error while querying for `flash_attn` support", e)


@@ -401,6 +401,9 @@ class Base(nn.Module):
 		self.l_padding = l_padding
+		if "flash_attn_v100" in AVAILABLE_ATTENTIONS:
+			self.l_padding = 32
 		self.ignore_index = -100
 		self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels
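
For context on the l_padding hint above: the flash-attention-v100 path wants sequence lengths aligned to a multiple of 32, so inputs presumably get right-padded and the extra positions masked or trimmed afterwards. A rough sketch of that padding step, with a hypothetical helper name that is not taken from the repo:

import torch
import torch.nn.functional as F

def pad_seq_to_multiple(x: torch.Tensor, multiple: int = 32, dim: int = 1):
	# x: e.g. [batch, seqlen, dim]; right-pad `dim` up to the next multiple
	seq_len = x.shape[dim]
	target = ((seq_len + multiple - 1) // multiple) * multiple
	if target == seq_len:
		return x, 0
	pad_amount = target - seq_len
	# F.pad takes (left, right) pairs starting from the last dimension
	pad = [0, 0] * (x.dim() - dim - 1) + [0, pad_amount]
	return F.pad(x, pad), pad_amount

# usage sketch: pad, run attention, then drop the padded positions again
# x, extra = pad_seq_to_multiple(x, multiple=self.l_padding)
# ... attention ...
# x = x[:, :x.shape[1] - extra]
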
@@ -623,7 +626,7 @@ class Base(nn.Module):
 				attn_implementation=hf_attention,
 				#gradient_checkpointing=self.gradient_checkpointing,
 			))
-			if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto"]:
+			if attention_backend in ["mem_efficient", "math", "flash", "cudnn", "auto", "flash_attn"]:
 				self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend )
 			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
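
The other mode strings in that gate ("math", "mem_efficient", "flash", "cudnn", "auto") correspond to PyTorch's scaled_dot_product_attention backends, which the adapted attention class presumably pins per call; the newly added "flash_attn" mode instead routes to the external flash-attn wrapper probed in the first file. A minimal sketch of the backend-pinning idea using torch.nn.attention.sdpa_kernel (recent PyTorch only; the cuDNN backend is newer still), independent of the actual MixtralAttention_Adapted / ml.replace_attention wiring:

import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

SDPA_BACKENDS = {
	"math": [SDPBackend.MATH],
	"mem_efficient": [SDPBackend.EFFICIENT_ATTENTION],
	"flash": [SDPBackend.FLASH_ATTENTION],
	"cudnn": [SDPBackend.CUDNN_ATTENTION],  # only exposed on newer PyTorch builds
	"auto": [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION],
}

def sdpa(q, k, v, mode="auto", causal=False):
	# q, k, v: [batch, heads, seqlen, head_dim]
	with sdpa_kernel(SDPA_BACKENDS[mode]):
		return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
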