mrq 2024-08-08 19:38:55 -05:00
parent 0aa59e6f3f
commit d04f6911b4
3 changed files with 21 additions and 29 deletions

View File

@@ -1149,7 +1149,7 @@ class Dataset(_Dataset):
             text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
         # pad the target with silence
-        if p_resp_pad_silence < random.random():
+        if random.random() < cfg.dataset.p_resp_pad_silence:
             resps = pad_codes_with_silence( resps )
         return dict(
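The fix follows the standard probability-gate idiom: an event meant to fire with probability p is gated by random.random() < p. The removed comparison was inverted relative to that idiom (it would pad with probability 1 - p) and read p_resp_pad_silence without the cfg.dataset. qualifier. A minimal, self-contained sketch of the idiom; maybe_pad, p_pad and the "<silence>" sentinel are illustrative stand-ins, not names from the repo:

import random

def maybe_pad(resps: list, p_pad: float) -> list:
    # random.random() is uniform over [0.0, 1.0), so `random.random() < p_pad`
    # is true with probability p_pad; the old `p_pad < random.random()` form
    # would instead be true with probability 1 - p_pad.
    if random.random() < p_pad:
        return resps + ["<silence>"]
    return resps

# rough empirical check of the rate (illustrative values)
hits = sum(len(maybe_pad([], 0.25)) for _ in range(10_000))
print(hits / 10_000)  # ~0.25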

View File

@@ -11,6 +11,24 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 AVAILABLE_ATTENTIONS = []
+try:
+    from transformers.utils import is_flash_attn_2_available
+    if is_flash_attn_2_available():
+        AVAILABLE_ATTENTIONS.append("flash_attention_2")
+except Exception as e:
+    print("Error while querying for `flash_attn_2` support", e)
+"""
+try:
+    from xformers.ops import LowerTriangularMask
+    from xformers.ops.fmha import memory_efficient_attention
+    AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+    print("Error while importing `xformers`", e)
+"""
 if torch.backends.cuda.flash_sdp_enabled():
     AVAILABLE_ATTENTIONS.append("flash")
@@ -26,22 +44,6 @@ if torch.backends.cuda.cudnn_sdp_enabled():
 if AVAILABLE_ATTENTIONS:
     AVAILABLE_ATTENTIONS.append("sdpa")
-try:
-    from xformers.ops import LowerTriangularMask
-    from xformers.ops.fmha import memory_efficient_attention
-    AVAILABLE_ATTENTIONS.append("xformers")
-except Exception as e:
-    print("Error while importing `xformers`", e)
-try:
-    from transformers.utils import is_flash_attn_2_available
-    if is_flash_attn_2_available():
-        AVAILABLE_ATTENTIONS.append("flash_attention_2")
-except Exception as e:
-    print("Error while querying for `flash_attn_2` support", e)
 class LlamaAttention_Adapted(LlamaAttention):
     def __init__(self, *args, **kwargs):
         if 'mode' in kwargs:
@@ -137,8 +139,6 @@ class LlamaAttention_Adapted(LlamaAttention):
             is_causal=is_causal,
         )
-        print("attention")
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, -1)
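The surviving reshape follows the usual attention-output layout: scaled_dot_product_attention returns [batch, heads, seq, head_dim], which is transposed to [batch, seq, heads, head_dim] and flattened to [batch, seq, heads * head_dim]. A quick standalone check with made-up sizes (the tensor and dimensions here are illustrative, not from the model):

import torch

bsz, n_heads, q_len, head_dim = 2, 4, 8, 16  # illustrative sizes
attn_output = torch.randn(bsz, n_heads, q_len, head_dim)

# [bsz, heads, seq, head_dim] -> [bsz, seq, heads, head_dim], then merge the head dims
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)

print(attn_output.shape)  # torch.Size([2, 8, 64])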

View File

@@ -514,16 +514,8 @@ class Base(nn.Module):
         # there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
         if attention_backend == "auto":
-            if "flash_attention_2" in AVAILABLE_ATTENTIONS:
-                attention_backend = "flash_attention_2"
-            elif "flash" in AVAILABLE_ATTENTIONS:
-                attention_backend = "flash"
-            elif "mem_efficient" in AVAILABLE_ATTENTIONS:
-                attention_backend = "mem_efficient"
-            elif "math" in AVAILABLE_ATTENTIONS:
-                attention_backend = "math"
-            elif "sdpa" in AVAILABLE_ATTENTIONS:
-                attention_backend = "sdpa"
+            if AVAILABLE_ATTENTIONS:
+                attention_backend = AVAILABLE_ATTENTIONS[0]
             else:
                 attention_backend = "eager"