From d04f6911b44c694cb91ab90f78ac7654aad44248 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 8 Aug 2024 19:38:55 -0500
Subject: [PATCH] oops

---
 vall_e/data.py              |  2 +-
 vall_e/models/arch/llama.py | 36 ++++++++++++++++++------------------
 vall_e/models/base.py       | 12 ++----------
 3 files changed, 21 insertions(+), 29 deletions(-)

diff --git a/vall_e/data.py b/vall_e/data.py
index 8d777d7..f3c4b4f 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1149,7 +1149,7 @@ class Dataset(_Dataset):
 			text = torch.tensor([bos_id, eos_id]).to(self.text_dtype)
 
 		# pad the target with silence
-		if p_resp_pad_silence < random.random():
+		if random.random() < cfg.dataset.p_resp_pad_silence:
 			resps = pad_codes_with_silence( resps )
 
 		return dict(
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index b2ee4e0..b92e145 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -11,6 +11,24 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
 
 AVAILABLE_ATTENTIONS = []
 
+try:
+	from transformers.utils import is_flash_attn_2_available
+
+	if is_flash_attn_2_available():
+		AVAILABLE_ATTENTIONS.append("flash_attention_2")
+except Exception as e:
+	print("Error while querying for `flash_attn_2` support", e)
+
+"""
+try:
+	from xformers.ops import LowerTriangularMask
+	from xformers.ops.fmha import memory_efficient_attention
+
+	AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+	print("Error while importing `xformers`", e)
+"""
+
 if torch.backends.cuda.flash_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("flash")
 
@@ -26,22 +44,6 @@ if torch.backends.cuda.cudnn_sdp_enabled():
 if AVAILABLE_ATTENTIONS:
 	AVAILABLE_ATTENTIONS.append("sdpa")
 
-try:
-	from xformers.ops import LowerTriangularMask
-	from xformers.ops.fmha import memory_efficient_attention
-
-	AVAILABLE_ATTENTIONS.append("xformers")
-except Exception as e:
-	print("Error while importing `xformers`", e)
-
-try:
-	from transformers.utils import is_flash_attn_2_available
-
-	if is_flash_attn_2_available():
-		AVAILABLE_ATTENTIONS.append("flash_attention_2")
-except Exception as e:
-	print("Error while querying for `flash_attn_2` support", e)
-
 class LlamaAttention_Adapted(LlamaAttention):
 	def __init__(self, *args, **kwargs):
 		if 'mode' in kwargs:
@@ -137,8 +139,6 @@ class LlamaAttention_Adapted(LlamaAttention):
 			is_causal=is_causal,
 		)
 
-		print("attention")
-
 		attn_output = attn_output.transpose(1, 2).contiguous()
 		attn_output = attn_output.view(bsz, q_len, -1)
 
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 89404c3..0e02116 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -514,16 +514,8 @@ class Base(nn.Module):
 
 		# there seems to have been a regression where anything touching the wrapped LlamaAttention class breaks
 		if attention_backend == "auto":
-			if "flash_attention_2" in AVAILABLE_ATTENTIONS:
-				attention_backend = "flash_attention_2"
-			elif "flash" in AVAILABLE_ATTENTIONS:
-				attention_backend = "flash"
-			elif "mem_efficient" in AVAILABLE_ATTENTIONS:
-				attention_backend = "mem_efficient"
-			elif "math" in AVAILABLE_ATTENTIONS:
-				attention_backend = "math"
-			elif "sdpa" in AVAILABLE_ATTENTIONS:
-				attention_backend = "sdpa"
+			if AVAILABLE_ATTENTIONS:
+				attention_backend = AVAILABLE_ATTENTIONS[0]
 			else:
 				attention_backend = "eager"
 
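
Note on the base.py hunk: "auto" now resolves to the first entry of AVAILABLE_ATTENTIONS instead of walking a hard-coded elif chain, so the order in which the import-time probes in llama.py append their backends becomes the selection priority (hence moving the `flash_attention_2` probe to the top of that file). A minimal sketch of that resolution logic, using a hypothetical `resolve_backend` helper and an `available` list standing in for AVAILABLE_ATTENTIONS (neither exists in the patch itself):

	# Sketch only: first probed backend wins, otherwise fall back to "eager",
	# mirroring the patched `attention_backend == "auto"` branch in base.py.
	def resolve_backend(requested: str, available: list[str]) -> str:
		if requested == "auto":
			return available[0] if available else "eager"
		return requested

	# e.g. if the probes appended these entries in priority order:
	assert resolve_backend("auto", ["flash_attention_2", "flash", "sdpa"]) == "flash_attention_2"
	assert resolve_backend("auto", []) == "eager"
	assert resolve_backend("sdpa", ["flash"]) == "sdpa"  # explicit choices pass through unchanged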