temporarily dropping support for xformers because it's breaking when using an attention mask (which I don't remember commenting out when it's being passed); default to not using wandb because it's being a pain when doing tests rather than actual sessions

This commit is contained in:
mrq 2024-11-22 11:29:12 -06:00
parent 8aafae91fd
commit 24d888c47c
4 changed files with 15 additions and 5 deletions
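For context on the attention-mask breakage mentioned in the commit message above, a minimal sketch of the API mismatch; the shapes and the all-ones mask are illustrative and not taken from this repo. PyTorch's built-in SDPA (the path that stays enabled below) takes a boolean or additive attn_mask directly, while xformers' memory_efficient_attention wants an attn_bias with its own type, shape, and layout rules, so a padding mask can't simply be forwarded as-is.

import torch
import torch.nn.functional as F

# illustrative shapes only: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# boolean mask: True = attend, False = masked out
mask = torch.ones(1, 1, 128, 128, dtype=torch.bool)

# torch SDPA accepts the boolean (or an additive float) mask as-is
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# xformers' memory_efficient_attention instead expects `attn_bias` (an additive
# float tensor with stricter shape/alignment rules, or one of its AttentionBias
# helpers) and a (batch, seq, heads, head_dim) layout, so the mask has to be
# converted before being passed in; roughly the spot where the breakage sits
# (assumption about the general failure mode, not the exact traceback)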

@@ -697,7 +697,7 @@ class Trainer:
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
gc_mode: str | None = None # deprecated, but marks when to do GC
wandb: bool = True # use wandb, if available
wandb: bool = False # use wandb, if available
weight_dtype: str = "float16" # dtype to have the model under
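If wandb still needs to be turned on for real sessions, a flag like this is usually consumed along these lines; the cfg.trainer access path, helper name, and project name below are assumptions rather than this repo's actual wiring:

import importlib.util

def maybe_init_wandb(cfg):
    # hypothetical helper: only start a wandb run when the trainer flag opts in
    # and the package is actually installed
    if not cfg.trainer.wandb or importlib.util.find_spec("wandb") is None:
        return None
    import wandb
    # project name is a placeholder
    return wandb.init(project="vall-e", config=vars(cfg.trainer))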

@@ -805,7 +805,10 @@ class Dataset(_Dataset):
self.task_symmap = self._get_task_symmap()
# grab IDs for bos, space, and eos for easy input creation later
self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
try:
self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
except Exception as e:
self.empty_text = [None, None, None]
# have it fetch at training time if any is invalid, because the tokenizer obj might not have it easily fetchable ahead of time
# encoding before parallelizing things causes things to whine
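The deferred lookup that the comment above alludes to could look roughly like this; the helper name and call site are assumptions, and only the cfg.tokenizer accesses are taken from the hunk:

def _resolve_empty_text(self):
    # hypothetical helper: retry the lookup lazily (e.g. at the first training
    # step) if the tokenizer couldn't provide these IDs at construction time
    if any(token is None for token in self.empty_text):
        self.empty_text = [
            cfg.tokenizer._bos_token,
            cfg.tokenizer.get_vocab()[" "],
            cfg.tokenizer._eos_token,
        ]
    return self.empty_text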

@@ -109,6 +109,7 @@ try:
except Exception as e:
_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
"""
try:
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask, LowerTriangularMask
@@ -116,6 +117,7 @@ try:
AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e:
_logger.warning(f"Error while importing `xformers`: {str(e)}")
"""
# to-do: find a better way to query for if there's available kernels since these return true regardless
if torch.backends.cuda.flash_sdp_enabled():
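One way to tackle that to-do (an assumption, not something this commit does) is to probe the backend with a tiny call instead of trusting the enabled flags; torch.backends.cuda.sdp_kernel is the pre-2.3 context manager, and newer releases expose torch.nn.attention.sdpa_kernel instead:

import torch
import torch.nn.functional as F

def flash_sdpa_actually_works(dtype=torch.float16) -> bool:
    # run a minimal attention call with only the flash backend allowed;
    # flash_sdp_enabled() reports True even when no kernel can service the call
    if not torch.cuda.is_available():
        return False
    try:
        q = torch.randn(1, 1, 8, 64, device="cuda", dtype=dtype)
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            F.scaled_dot_product_attention(q, q, q)
        return True
    except RuntimeError:
        return False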

@@ -1,9 +1,14 @@
from transformers.models.mamba.modeling_mamba import MambaModel
from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaModel
"""
from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
"""
from mamba2_torch.modeling.configuration_mamba2 import Mamba2Config
from mamba2_torch.modeling.modeling_mamba2 import Mamba2Model
"""
# https://github.com/state-spaces/mamba
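If both Mamba2 backends ever need to coexist, a guarded import is one option; this is only a sketch reusing the two import paths shown in the hunk, not what the file actually does:

try:
    # prefer the standalone mamba2_torch port
    from mamba2_torch.modeling.configuration_mamba2 import Mamba2Config
    from mamba2_torch.modeling.modeling_mamba2 import Mamba2Model
except ImportError:
    # fall back to the transformers implementation if the port isn't installed
    from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
    from transformers.models.mamba2.modeling_mamba2 import Mamba2Model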