temporarily dropping support for xformers because it's breaking when using an attention mask (which I don't remember commenting out when it was being passed); default to not using wandb because it's being a pain when doing tests and not actual sessions
commit 24d888c47c (parent 8aafae91fd)
```diff
@@ -697,7 +697,7 @@ class Trainer:
 	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
 	gc_mode: str | None = None # deprecated, but marks when to do GC
 
-	wandb: bool = True # use wandb, if available
+	wandb: bool = False # use wandb, if available
 
 	weight_dtype: str = "float16" # dtype to have the model under
```
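With the default flipped to `False`, wandb logging now has to be opted into explicitly. A minimal sketch (not code from this commit) of how that gate might look, assuming the `Trainer` field above is exposed as `cfg.trainer.wandb`; the `project`/`name` values are hypothetical placeholders:

```python
# Hedged sketch: guard wandb usage behind the new opt-in default.
try:
    import wandb
except ImportError:
    wandb = None  # run fine without the package installed

def maybe_init_wandb(cfg):
    # only start a run when explicitly enabled and the package imports
    if wandb is None or not cfg.trainer.wandb:
        return None
    return wandb.init(project="example-project", name="example-run")
```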
```diff
@@ -805,7 +805,10 @@ class Dataset(_Dataset):
 		self.task_symmap = self._get_task_symmap()
 
 		# grab IDs for bos, space, and eos for easy input creation later
-		self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
+		try:
+			self.empty_text = [ cfg.tokenizer._bos_token, cfg.tokenizer.get_vocab()[" "], cfg.tokenizer._eos_token ]
+		except Exception as e:
+			self.empty_text = [None, None, None]
 
 		# have it fetch at training time if any is invalid, because the tokenizer obj might not have it easily fetchable ahead of time
 		# encoding before parallelizing things causes things to whine
```
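The `[None, None, None]` fallback defers the lookup to training time, as the comment notes. A hedged sketch of what that deferred fetch might look like; `resolve_empty_text` is a hypothetical helper, and it uses the public `bos_token_id`/`eos_token_id` accessors rather than the private `_bos_token`/`_eos_token` attributes above:

```python
def resolve_empty_text(dataset, tokenizer):
    # re-query the tokenizer only if the construction-time lookup failed
    if any(token is None for token in dataset.empty_text):
        dataset.empty_text = [
            tokenizer.bos_token_id,          # beginning-of-sequence id
            tokenizer.get_vocab()[" "],      # id of the literal space token
            tokenizer.eos_token_id,          # end-of-sequence id
        ]
    return dataset.empty_text
```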
```diff
@@ -109,6 +109,7 @@ try:
 except Exception as e:
 	_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
 
+"""
 try:
 	from xformers.ops.fmha import memory_efficient_attention
 	from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask, LowerTriangularMask
```
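For context on why masking is awkward on this path: `memory_efficient_attention` takes an `attn_bias` argument (an additive tensor or a mask object like the ones imported above) rather than a plain boolean attention mask, which is consistent with the breakage described in the commit message. A hedged sketch of the intended causal usage; shapes and dtypes are illustrative only:

```python
import torch
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import LowerTriangularMask

# (batch, seq_len, heads, head_dim), fp16 on GPU as most kernels require
q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

# causal masking is requested via an attn_bias object, not a boolean mask
out = memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())
```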
```diff
@@ -116,6 +117,7 @@ try:
 	AVAILABLE_ATTENTIONS.append("xformers")
 except Exception as e:
 	_logger.warning(f"Error while importing `xformers`: {str(e)}")
+"""
 
 # to-do: find a better way to query for if there's available kernels since these return true regardless
 if torch.backends.cuda.flash_sdp_enabled():
```
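With the xformers block commented out, attention falls through to PyTorch's SDPA, which does accept a boolean mask directly. On the to-do above: one hedged way to check for a genuinely working flash kernel, rather than trusting `flash_sdp_enabled()` (which reports the setting, not actual availability), is to force the backend on a tiny input and catch the failure. Assumes torch >= 2.0; the helper name is hypothetical:

```python
import torch
import torch.nn.functional as F

def flash_sdpa_actually_works() -> bool:
    # probe the flash backend directly instead of reading the enable flag
    if not torch.cuda.is_available():
        return False
    try:
        q = torch.randn(1, 1, 8, 64, device="cuda", dtype=torch.float16)
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False
        ):
            F.scaled_dot_product_attention(q, q, q)
        return True
    except RuntimeError:
        return False
```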
```diff
@@ -1,9 +1,14 @@
-
-from transformers.models.mamba.modeling_mamba import MambaModel
-from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
 
+from transformers.models.mamba.configuration_mamba import MambaConfig
+from transformers.models.mamba.modeling_mamba import MambaModel
 
+"""
+from transformers.models.mamba2.modeling_mamba2 import Mamba2Model
+from transformers.models.mamba2.configuration_mamba2 import Mamba2Config
+"""
 
+from mamba2_torch.modeling.configuration_mamba2 import Mamba2Config
+from mamba2_torch.modeling.modeling_mamba2 import Mamba2Model
 
 """
 # https://github.com/state-spaces/mamba
```
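A hedged sketch of exercising the imports this file settles on, using the Hugging Face Mamba1 classes; the config values are arbitrary placeholders, not from this repo:

```python
import torch
from transformers.models.mamba.configuration_mamba import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaModel

# tiny throwaway model just to show the config -> model flow
config = MambaConfig(vocab_size=256, hidden_size=64, num_hidden_layers=2)
model = MambaModel(config)

ids = torch.randint(0, 256, (1, 16))
hidden = model(input_ids=ids).last_hidden_state  # shape: (1, 16, 64)
```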