oops
commit d04f6911b4 (parent 0aa59e6f3f)
@@ -1149,7 +1149,7 @@ class Dataset(_Dataset):
 		text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
 
 		# pad the target with silence
-		if p_resp_pad_silence < random.random():
+		if random.random() < cfg.dataset.p_resp_pad_silence:
 			resps = pad_codes_with_silence( resps )
 
 		return dict(
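For reference, the one-line change above flips the direction of the probability check: `if p_resp_pad_silence < random.random():` triggers the padding with probability 1 - p, while the replacement `if random.random() < cfg.dataset.p_resp_pad_silence:` triggers it with probability p, matching what the config knob's name implies. A minimal standalone sketch of the corrected gate (the `pad_codes_with_silence` body and the tensor shapes here are hypothetical stand-ins, not the repo's implementation):

import random

import torch

def pad_codes_with_silence(resps: torch.Tensor, pad_frames: int = 8) -> torch.Tensor:
	# Hypothetical stand-in: append a constant "silence" code for pad_frames frames.
	silence = torch.zeros(pad_frames, resps.shape[1], dtype=resps.dtype)
	return torch.cat([resps, silence], dim=0)

p_resp_pad_silence = 0.25                  # assumed value of cfg.dataset.p_resp_pad_silence
resps = torch.randint(0, 1024, (100, 8))   # (frames, codebooks), made-up shape

# Corrected gate: pads with probability p_resp_pad_silence.
if random.random() < p_resp_pad_silence:
	resps = pad_codes_with_silence(resps)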
@@ -11,6 +11,24 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 
 AVAILABLE_ATTENTIONS = []
 
+try:
+	from transformers.utils import is_flash_attn_2_available
+
+	if is_flash_attn_2_available():
+		AVAILABLE_ATTENTIONS.append("flash_attention_2")
+except Exception as e:
+	print("Error while querying for `flash_attn_2` support", e)
+
+"""
+try:
+	from xformers.ops import LowerTriangularMask
+	from xformers.ops.fmha import memory_efficient_attention
+
+	AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+	print("Error while importing `xformers`", e)
+"""
+
 if torch.backends.cuda.flash_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("flash")
 
@@ -26,22 +44,6 @@ if torch.backends.cuda.cudnn_sdp_enabled():
 if AVAILABLE_ATTENTIONS:
 	AVAILABLE_ATTENTIONS.append("sdpa")
 
-try:
-	from xformers.ops import LowerTriangularMask
-	from xformers.ops.fmha import memory_efficient_attention
-
-	AVAILABLE_ATTENTIONS.append("xformers")
-except Exception as e:
-	print("Error while importing `xformers`", e)
-
-try:
-	from transformers.utils import is_flash_attn_2_available
-
-	if is_flash_attn_2_available():
-		AVAILABLE_ATTENTIONS.append("flash_attention_2")
-except Exception as e:
-	print("Error while querying for `flash_attn_2` support", e)
-
 class LlamaAttention_Adapted(LlamaAttention):
 	def __init__(self, *args, **kwargs):
 		if 'mode' in kwargs:
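The two hunks above relocate the `flash_attention_2` probe (and the commented-out `xformers` probe) from below the SDPA checks to the top of the module, so the availability list is assembled in one pass before anything consumes it. A rough standalone sketch of that probe-at-import pattern, assuming recent `torch` and `transformers`; the `mem_efficient`/`math` checks are inferred from the backend names used elsewhere in this diff, not copied from it:

import torch

AVAILABLE_ATTENTIONS = []

try:
	# Optional dependency: only advertise the backend if the import and the runtime check both succeed.
	from transformers.utils import is_flash_attn_2_available

	if is_flash_attn_2_available():
		AVAILABLE_ATTENTIONS.append("flash_attention_2")
except Exception as e:
	print("Error while querying for `flash_attn_2` support", e)

# Torch's built-in SDPA kernels can be queried directly.
if torch.backends.cuda.flash_sdp_enabled():
	AVAILABLE_ATTENTIONS.append("flash")
if torch.backends.cuda.mem_efficient_sdp_enabled():
	AVAILABLE_ATTENTIONS.append("mem_efficient")
if torch.backends.cuda.math_sdp_enabled():
	AVAILABLE_ATTENTIONS.append("math")

# Generic "sdpa" is usable whenever at least one concrete kernel is available.
if AVAILABLE_ATTENTIONS:
	AVAILABLE_ATTENTIONS.append("sdpa")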
@@ -137,8 +139,6 @@ class LlamaAttention_Adapted(LlamaAttention):
 				is_causal=is_causal,
 			)
 
-			print("attention")
-
 		attn_output = attn_output.transpose(1, 2).contiguous()
 		attn_output = attn_output.view(bsz, q_len, -1)
 
@@ -514,16 +514,8 @@ class Base(nn.Module):
 		# there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
 
 		if attention_backend == "auto":
-			if "flash_attention_2" in AVAILABLE_ATTENTIONS:
-				attention_backend = "flash_attention_2"
-			elif "flash" in AVAILABLE_ATTENTIONS:
-				attention_backend = "flash"
-			elif "mem_efficient" in AVAILABLE_ATTENTIONS:
-				attention_backend = "mem_efficient"
-			elif "math" in AVAILABLE_ATTENTIONS:
-				attention_backend = "math"
-			elif "sdpa" in AVAILABLE_ATTENTIONS:
-				attention_backend = "sdpa"
+			if AVAILABLE_ATTENTIONS:
+				attention_backend = AVAILABLE_ATTENTIONS[0]
 			else:
 				attention_backend = "eager"
 
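The final hunk drops the explicit preference ladder (`flash_attention_2`, then `flash`, `mem_efficient`, `math`, `sdpa`) in favour of taking the first entry of `AVAILABLE_ATTENTIONS`, which moves the priority decision into the order in which the probes above append their backends. A compact sketch of the new `auto` resolution (equivalent behaviour, not the committed lines verbatim):

# First successfully probed backend wins; fall back to "eager" when nothing was detected.
attention_backend = AVAILABLE_ATTENTIONS[0] if AVAILABLE_ATTENTIONS else "eager"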