oops
parent 0aa59e6f3f
commit d04f6911b4
@@ -1149,7 +1149,7 @@ class Dataset(_Dataset):
 			text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
 
 		# pad the target with silence
-		if p_resp_pad_silence < random.random():
+		if random.random() < cfg.dataset.p_resp_pad_silence:
 			resps = pad_codes_with_silence( resps )
 
 		return dict(
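The dataset change flips the comparison so silence padding fires with probability p_resp_pad_silence (now read from cfg.dataset) rather than 1 - p. A minimal sketch of why the comparison direction matters; should_pad_with_silence and the sampling loop are illustrative only, not code from this repo:

import random

P_RESP_PAD_SILENCE = 0.25  # stand-in for cfg.dataset.p_resp_pad_silence

def should_pad_with_silence(p: float) -> bool:
	# random.random() is uniform on [0, 1), so this is True with probability p.
	# The old form, `p < random.random()`, was True with probability 1 - p.
	return random.random() < p

# rough sanity check: the hit rate should land near p
hits = sum(should_pad_with_silence(P_RESP_PAD_SILENCE) for _ in range(10_000))
print(hits / 10_000)  # ~0.25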
@@ -11,6 +11,24 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 
 AVAILABLE_ATTENTIONS = []
 
+try:
+	from transformers.utils import is_flash_attn_2_available
+
+	if is_flash_attn_2_available():
+		AVAILABLE_ATTENTIONS.append("flash_attention_2")
+except Exception as e:
+	print("Error while querying for `flash_attn_2` support", e)
+
+"""
+try:
+	from xformers.ops import LowerTriangularMask
+	from xformers.ops.fmha import memory_efficient_attention
+
+	AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+	print("Error while importing `xformers`", e)
+"""
+
 if torch.backends.cuda.flash_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("flash")
 
@@ -26,22 +44,6 @@ if torch.backends.cuda.cudnn_sdp_enabled():
 if AVAILABLE_ATTENTIONS:
 	AVAILABLE_ATTENTIONS.append("sdpa")
 
-try:
-	from xformers.ops import LowerTriangularMask
-	from xformers.ops.fmha import memory_efficient_attention
-
-	AVAILABLE_ATTENTIONS.append("xformers")
-except Exception as e:
-	print("Error while importing `xformers`", e)
-
-try:
-	from transformers.utils import is_flash_attn_2_available
-
-	if is_flash_attn_2_available():
-		AVAILABLE_ATTENTIONS.append("flash_attention_2")
-except Exception as e:
-	print("Error while querying for `flash_attn_2` support", e)
-
 class LlamaAttention_Adapted(LlamaAttention):
 	def __init__(self, *args, **kwargs):
 		if 'mode' in kwargs:
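Together these two hunks move the flash_attention_2 probe ahead of the SDPA checks and disable the xformers probe, so the order in which names are appended to AVAILABLE_ATTENTIONS doubles as a preference order. A small self-contained sketch of that probe-and-append pattern, assuming the same libraries; the helper name and messages are illustrative, not this module's API:

import torch

def probe_attention_backends() -> list[str]:
	# Probe each optional backend inside its own try/except so a missing
	# package never breaks module import; append in preference order.
	available = []
	try:
		from transformers.utils import is_flash_attn_2_available
		if is_flash_attn_2_available():
			available.append("flash_attention_2")
	except Exception as e:
		print("Error while querying for `flash_attn_2` support", e)
	# PyTorch's own SDPA kernels are queried directly, no import can fail here
	if torch.backends.cuda.flash_sdp_enabled():
		available.append("flash")
	if available:
		available.append("sdpa")
	return available

print(probe_attention_backends())  # e.g. ['flash_attention_2', 'flash', 'sdpa']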
@@ -137,8 +139,6 @@ class LlamaAttention_Adapted(LlamaAttention):
 			is_causal=is_causal,
 		)
 
-		print("attention")
-
 		attn_output = attn_output.transpose(1, 2).contiguous()
 		attn_output = attn_output.view(bsz, q_len, -1)
 
@@ -514,16 +514,8 @@ class Base(nn.Module):
 		# there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
 
 		if attention_backend == "auto":
-			if "flash_attention_2" in AVAILABLE_ATTENTIONS:
-				attention_backend = "flash_attention_2"
-			elif "flash" in AVAILABLE_ATTENTIONS:
-				attention_backend = "flash"
-			elif "mem_efficient" in AVAILABLE_ATTENTIONS:
-				attention_backend = "mem_efficient"
-			elif "math" in AVAILABLE_ATTENTIONS:
-				attention_backend = "math"
-			elif "sdpa" in AVAILABLE_ATTENTIONS:
-				attention_backend = "sdpa"
+			if AVAILABLE_ATTENTIONS:
+				attention_backend = AVAILABLE_ATTENTIONS[0]
 			else:
 				attention_backend = "eager"
 
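With the detection order above acting as a preference order, the "auto" branch no longer needs a cascade of elifs: it takes the head of the list and falls back to "eager". A hedged sketch of that selection logic in isolation; resolve_attention_backend is a made-up name for illustration, not a function in the repo:

def resolve_attention_backend(requested: str, available: list[str]) -> str:
	# "auto" resolves to the first (most preferred) detected backend,
	# or to the eager implementation when nothing was detected.
	if requested == "auto":
		return available[0] if available else "eager"
	return requested

assert resolve_attention_backend("auto", ["flash_attention_2", "sdpa"]) == "flash_attention_2"
assert resolve_attention_backend("auto", []) == "eager"
assert resolve_attention_backend("sdpa", ["flash_attention_2"]) == "sdpa"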