mrq 2024-08-08 19:38:55 -05:00
parent 0aa59e6f3f
commit d04f6911b4
3 changed files with 21 additions and 29 deletions

View File

@@ -1149,7 +1149,7 @@ class Dataset(_Dataset):
             text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
 
         # pad the target with silence
-        if p_resp_pad_silence < random.random():
+        if random.random() < cfg.dataset.p_resp_pad_silence:
             resps = pad_codes_with_silence( resps )
 
         return dict(
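The old guard, `p_resp_pad_silence < random.random()`, fired with probability 1 - p; the replacement pads with probability p and reads the threshold from `cfg.dataset.p_resp_pad_silence`. A minimal standalone sketch of the corrected gate (the example probability and `pad_fn` are placeholders, not the repo's actual values):

import random

p_resp_pad_silence = 0.5  # placeholder value; the repo reads cfg.dataset.p_resp_pad_silence

def maybe_pad_with_silence( resps, pad_fn ):
    # pads with probability p; the old `p < random.random()` comparison
    # padded with probability 1 - p instead
    if random.random() < p_resp_pad_silence:
        return pad_fn( resps )
    return resps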

View File

@@ -11,6 +11,24 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 
 AVAILABLE_ATTENTIONS = []
 
+try:
+    from transformers.utils import is_flash_attn_2_available
+
+    if is_flash_attn_2_available():
+        AVAILABLE_ATTENTIONS.append("flash_attention_2")
+except Exception as e:
+    print("Error while querying for `flash_attn_2` support", e)
+
+"""
+try:
+    from xformers.ops import LowerTriangularMask
+    from xformers.ops.fmha import memory_efficient_attention
+
+    AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+    print("Error while importing `xformers`", e)
+"""
+
 if torch.backends.cuda.flash_sdp_enabled():
     AVAILABLE_ATTENTIONS.append("flash")
@@ -26,22 +44,6 @@ if torch.backends.cuda.cudnn_sdp_enabled():
 if AVAILABLE_ATTENTIONS:
     AVAILABLE_ATTENTIONS.append("sdpa")
 
-try:
-    from xformers.ops import LowerTriangularMask
-    from xformers.ops.fmha import memory_efficient_attention
-
-    AVAILABLE_ATTENTIONS.append("xformers")
-except Exception as e:
-    print("Error while importing `xformers`", e)
-
-try:
-    from transformers.utils import is_flash_attn_2_available
-
-    if is_flash_attn_2_available():
-        AVAILABLE_ATTENTIONS.append("flash_attention_2")
-except Exception as e:
-    print("Error while querying for `flash_attn_2` support", e)
-
 class LlamaAttention_Adapted(LlamaAttention):
     def __init__(self, *args, **kwargs):
         if 'mode' in kwargs:
@@ -137,8 +139,6 @@ class LlamaAttention_Adapted(LlamaAttention):
                 is_causal=is_causal,
             )
 
-            print("attention")
-
             attn_output = attn_output.transpose(1, 2).contiguous()
             attn_output = attn_output.view(bsz, q_len, -1)
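The availability probes now run at the top of the module, ahead of the torch SDPA queries, and the xformers probe is parked inside a string literal; since each probe appends on success, the order of the blocks is the priority order the list ends up encoding. A rough sketch of how the list comes out on a machine where the first two checks pass (an assumed environment, not a guaranteed result):

import torch
from transformers.utils import is_flash_attn_2_available

AVAILABLE_ATTENTIONS = []

if is_flash_attn_2_available():                 # probed first, so it lands first
    AVAILABLE_ATTENTIONS.append("flash_attention_2")
if torch.backends.cuda.flash_sdp_enabled():     # torch's SDPA flash kernel
    AVAILABLE_ATTENTIONS.append("flash")
if AVAILABLE_ATTENTIONS:                        # generic sdpa entry, as in the module
    AVAILABLE_ATTENTIONS.append("sdpa")

print(AVAILABLE_ATTENTIONS)                     # e.g. ['flash_attention_2', 'flash', 'sdpa']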

View File

@@ -514,16 +514,8 @@ class Base(nn.Module):
 
         # there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
         if attention_backend == "auto":
-            if "flash_attention_2" in AVAILABLE_ATTENTIONS:
-                attention_backend = "flash_attention_2"
-            elif "flash" in AVAILABLE_ATTENTIONS:
-                attention_backend = "flash"
-            elif "mem_efficient" in AVAILABLE_ATTENTIONS:
-                attention_backend = "mem_efficient"
-            elif "math" in AVAILABLE_ATTENTIONS:
-                attention_backend = "math"
-            elif "sdpa" in AVAILABLE_ATTENTIONS:
-                attention_backend = "sdpa"
+            if AVAILABLE_ATTENTIONS:
+                attention_backend = AVAILABLE_ATTENTIONS[0]
             else:
                 attention_backend = "eager"
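With the hard-coded preference chain gone, resolving "auto" reduces to taking the head of the list built at import time and falling back to "eager" when nothing was detected, so whichever backend was probed (and found) first wins. A minimal sketch of that resolution as a standalone helper (the function name is illustrative; the repo does this inline inside Base):

def resolve_attention_backend( attention_backend, available ):
    # "auto" defers entirely to detection order instead of the old fixed
    # flash_attention_2 > flash > mem_efficient > math > sdpa chain
    if attention_backend == "auto":
        return available[0] if available else "eager"
    return attention_backend

# usage: attention_backend = resolve_attention_backend("auto", AVAILABLE_ATTENTIONS)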