From 24d888c47c3ef730c50ece3de1f23a07a70e2ce1 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 22 Nov 2024 11:29:12 -0600
Subject: [PATCH] temporarily dropping support for xformers because it's
 breaking when using an attention mask (which I don't remember commenting out
 when it's being passed), default to not use wandb because it's being a pain
 when doing tests and not actual sessions

---
 vall_e/config.py            |  2 +-
 vall_e/data.py              |  5 ++++-
 vall_e/models/arch/llama.py |  2 ++
 vall_e/models/arch/mamba.py | 11 ++++++++---
 4 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 1fedcf9..8b0da14 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -697,7 +697,7 @@ class Trainer:
 	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
 	gc_mode: str | None = None # deprecated, but marks when to do GC
 
-	wandb: bool = True # use wandb, if available
+	wandb: bool = False # use wandb, if available
 
 	weight_dtype: str = "float16" # dtype to have the model under
 
diff --git a/vall_e/data.py b/vall_e/data.py
index 05c4bac..acb7849 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -805,7 +805,10 @@ class Dataset(_Dataset):
 		self.task_symmap = self._get_task_symmap()
 
 		# grab IDs for bos, space, and eos for easy input creation later
-		self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
+		try:
+			self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
+		except Exception as e:
+			self.empty_text = [None, None, None] # have it fetch at training time if any is invalid, because the tokenizer obj might not have it easily fetchable ahead of time
 
 		# encoding before parallelizing things causes things to whine
 
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index e8b28b8..607bbfb 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -109,6 +109,7 @@ try:
 except Exception as e:
 	_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
 
+"""
 try:
 	from xformers.ops.fmha import memory_efficient_attention
 	from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask, LowerTriangularMask
@@ -116,6 +117,7 @@ try:
 	AVAILABLE_ATTENTIONS.append("xformers")
 except Exception as e:
 	_logger.warning(f"Error while importing `xformers`: {str(e)}")
+"""
 
 # to-do: find a better way to query for if there's available kernels since these return true regardless
 if torch.backends.cuda.flash_sdp_enabled():
diff --git a/vall_e/models/arch/mamba.py b/vall_e/models/arch/mamba.py
index 78e20ce..15cd8c9 100644
--- a/vall_e/models/arch/mamba.py
+++ b/vall_e/models/arch/mamba.py
@@ -1,9 +1,14 @@
-from transformers.models.mamba.modeling_mamba import MambaModel
-from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
-
 from transformers.models.mamba.configuration_mamba import MambaConfig
+from transformers.models.mamba.modeling_mamba import MambaModel
+
+"""
+from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
 from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
+"""
+
+from mamba2_torch.modeling.configuration_mamba2 import Mamba2Config
+from mamba2_torch.modeling.modeling_mamba2 import Mamba2Model
 
 """
 # https://github.com/state-spaces/mamba
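
Note on the data.py hunk: the new comment says the bos/space/eos IDs get fetched at training time if any of them is left invalid. Below is a minimal sketch of what that deferred lookup could look like, assuming the same tokenizer attributes used above (_bos_token, _eos_token, get_vocab()) are available by then; resolve_empty_text is a hypothetical helper name, not code from the repo.

def resolve_empty_text(empty_text, tokenizer):
	# only fill in the entries the constructor fallback left as None
	if all(token is not None for token in empty_text):
		return empty_text
	vocab = tokenizer.get_vocab()
	bos = empty_text[0] if empty_text[0] is not None else tokenizer._bos_token
	space = empty_text[1] if empty_text[1] is not None else vocab.get(" ")
	eos = empty_text[2] if empty_text[2] is not None else tokenizer._eos_token
	return [ bos, space, eos ]

# e.g. called lazily wherever empty_text is first consumed:
# self.empty_text = resolve_empty_text(self.empty_text, cfg.tokenizer)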
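
Note on the llama.py hunk: the xformers block is only fenced off with a docstring, so it can be re-enabled by deleting the two added quote lines once the masking issue is resolved. For reference, a minimal sketch (not from the repo) of the call pattern those imports serve; the shapes, dtype, and device below are assumptions, only the imports mirror the commented-out ones.

import torch
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import LowerTriangularMask

# xformers expects [batch, seq_len, heads, head_dim] inputs
q = torch.randn(1, 32, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 32, 8, 64, device="cuda", dtype=torch.float16)

# the predefined causal bias works; per the commit message, the breakage
# shows up when a custom attention mask tensor is passed as attn_bias
# instead of one of these predefined bias classes
out = memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())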
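
Note on the mamba.py hunk: Mamba2Config/Mamba2Model now come from mamba2_torch instead of transformers. A minimal usage sketch, assuming the package mirrors the HuggingFace Mamba2 API (the matching class names suggest it does, but the patch itself does not confirm it); the constructor defaults and forward signature here are assumptions.

import torch
from mamba2_torch.modeling.configuration_mamba2 import Mamba2Config
from mamba2_torch.modeling.modeling_mamba2 import Mamba2Model

config = Mamba2Config()  # default hyperparameters, purely illustrative
model = Mamba2Model(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
outputs = model(input_ids)  # HF-style models expose last_hidden_state on the returned output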