mamba2-hf using vasqu/mamba2-torch, since it lets me use mamba2 without triton ops (my 4xV100s are not happy training mamba2 because of triton)

This commit is contained in:
mrq 2024-06-14 19:42:17 -05:00
parent 31f71fa134
commit ccb14c06ef
4 changed files with 58 additions and 9 deletions
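As context for the change below, the pure-torch path can be sanity-checked outside of the model wiring. A minimal sketch, assuming `mamba2-torch` is installed and exposes the same `Mamba2Config` / `Mamba2Model` names and keyword arguments used in this commit; sizes are placeholders, not the project's real config:

    # Sketch: exercise mamba2-torch's non-triton path standalone.
    # Config/forward names mirror the integration below; sizes are illustrative.
    import torch
    from mamba2_torch import Mamba2Config, Mamba2Model

    config = Mamba2Config(
        vocab_size=1024,
        hidden_size=512,
        num_hidden_layers=4,
        expand=4,
        use_triton_kernels=False,  # pure torch kernels, no triton (the point of this commit)
    )
    model = Mamba2Model(config)

    x = torch.randn(1, 128, 512)  # (batch, seq_len, hidden_size) embeddings
    out = model(inputs_embeds=x, return_dict=True)
    print(out.last_hidden_state.shape)  # expect (1, 128, 512)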

View File

@@ -1,62 +1,71 @@
AVAILABLE_ARCHES = []
ERROR_ARCHES = {}
try:
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
AVAILABLE_ARCHES.append("transformer")
except Exception as e:
print("Error importing `transformer` arch:", e)
ERROR_ARCHES["transformer"] = e
pass
try:
from .retnet import RetNetDecoder, RetNetConfig
AVAILABLE_ARCHES.append("retnet")
except Exception as e:
print("Error importing `retnet` arch:", e)
ERROR_ARCHES["retnet"] = e
pass
try:
from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS
AVAILABLE_ARCHES.append("retnet-ts")
except Exception as e:
print("Error importing `retnet-ts` arch:", e)
ERROR_ARCHES["retnet-ts"] = e
pass
try:
from .retnet_syncdoth.retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF, RetNetForCausalLM
AVAILABLE_ARCHES.append("retnet-hf")
except Exception as e:
print("Error importing `retnet-hf` arch:", e)
ERROR_ARCHES["retnet-hf"] = e
pass
try:
from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Base, LlamaForCausalLM
AVAILABLE_ARCHES.append("llama")
except Exception as e:
print("Error importing `llama` arch:", e)
ERROR_ARCHES["llama"] = e
pass
try:
from .bitnet import BitNetTransformer
AVAILABLE_ARCHES.append("bitnet")
except Exception as e:
print("Error importing `bitnet` arch:", e)
ERROR_ARCHES["bitnet"] = e
pass
try:
from .mixtral import MixtralModel, MixtralConfig
AVAILABLE_ARCHES.append("mixtral")
except Exception as e:
print("Error importing `mixtral` arch:", e)
ERROR_ARCHES["mixtral"] = e
try:
from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
AVAILABLE_ARCHES.append("mamba")
AVAILABLE_ARCHES.append("mamba2")
except Exception as e:
print("Error importing `mamba` arch:", e)
ERROR_ARCHES["mamba"] = e
ERROR_ARCHES["mamba2"] = e
try:
from .mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF
AVAILABLE_ARCHES.append("mamba2-hf")
except Exception as e:
print("Error importing `mamba2-hf` arch:", e)
ERROR_ARCHES["mamba2-hf"] = e
# NOTE: should probably be removed; perf was very lacking in comparison to regular bitnet
try:
from .mmfreelm import *
AVAILABLE_ARCHES.append("mmfreelm")
except Exception as e:
print("Error importing `mmfreelm` arch:", e)
ERROR_ARCHES["mmfreelm"] = e

View File

@@ -0,0 +1 @@
from .mamba2_hf import *

View File

@@ -0,0 +1,4 @@
# https://github.com/vasqu/mamba2-torch
# NOTE: edit `src/mamba2_torch/__init__.py` to remove the reference to `.src.`, because of how pip treats the package layout
from mamba2_torch import Mamba2Model as Mamba2Model_HF, Mamba2Config as Mamba2Config_HF

View File

@@ -329,6 +329,10 @@ class Base(nn.Module):
n_prom_tokens = n_audio_tokens
# check if requested arch is unavailable
if self.arch_type in ERROR_ARCHES:
raise ERROR_ARCHES[self.arch_type]
if "len" not in self.capabilities:
# +1 to include the stop token
n_resp_tokens = n_audio_tokens + self.causal_size
@@ -592,6 +596,21 @@ class Base(nn.Module):
#initializer_cfg=initializer_cfg,
)
self.model.gradient_checkpointing = self.gradient_checkpointing
elif self.arch_type in ["mamba2-hf"]:
self.model = Mamba2Model_HF(Mamba2Config_HF(
vocab_size=n_resp_tokens,
hidden_size=d_model,
max_position_embeddings=75 * 60 * 5, # 75 tokens/sec * 60 sec, with 5x headroom over the ~60 second max length
expand=4,
num_hidden_layers=n_layers,
is_encoder_decoder=False,
is_decoder=True,
use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
))
if self.gradient_checkpointing and not self.model.gradient_checkpointing:
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
elif self.arch_type == "mmfreelm":
self.model = HGRNBitModel(HGRNBitConfig(
vocab_size=n_resp_tokens,
@@ -617,6 +636,9 @@ class Base(nn.Module):
else:
raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
if hasattr( self.model, "embeddings" ):
del self.model.embeddings
if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
@@ -713,6 +735,19 @@ class Base(nn.Module):
state = out.past_key_values
elif self.arch_type in ["mamba","mamba2"]:
x = self.model( hidden_states=x )
elif self.arch_type == "mamba2-hf":
first = state is None or len(state) == 0
kwargs = dict(
inputs_embeds=x,
cache_params=state,
return_dict=True,
)
out = self.model(**kwargs)
x = out.last_hidden_state
if state is not None:
state = out.cache_params
elif self.arch_type == "bitnet":
x = self.model(x)
elif self.arch_type == "mmfreelm":