mamba2-hf using vasqu/mamba2-torch
because it lets me use mamba2 without triton ops (training on my 4xV100s does not play nicely with mamba2 because of triton)
parent 31f71fa134
commit ccb14c06ef
vall_e/models/arch/__init__.py
@@ -1,62 +1,71 @@
AVAILABLE_ARCHES = []
ERROR_ARCHES = {}

try:
    from .transformer import SinusoidalEmbedding, Block as TransformerBlock
    AVAILABLE_ARCHES.append("transformer")
except Exception as e:
    print("Error importing `transformer` arch:", e)
    ERROR_ARCHES["transformer"] = e
    pass

try:
    from .retnet import RetNetDecoder, RetNetConfig
    AVAILABLE_ARCHES.append("retnet")
except Exception as e:
    print("Error importing `retnet` arch:", e)
    ERROR_ARCHES["retnet"] = e
    pass

try:
    from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS
    AVAILABLE_ARCHES.append("retnet-ts")
except Exception as e:
    print("Error importing `retnet-ts` arch:", e)
    ERROR_ARCHES["retnet-ts"] = e
    pass

try:
    from .retnet_syncdoth.retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF, RetNetForCausalLM
    AVAILABLE_ARCHES.append("retnet-hf")
except Exception as e:
    print("Error importing `retnet-hf` arch:", e)
    ERROR_ARCHES["retnet-hf"] = e
    pass

try:
    from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Base, LlamaForCausalLM
    AVAILABLE_ARCHES.append("llama")
except Exception as e:
    print("Error importing `llama` arch:", e)
    ERROR_ARCHES["llama"] = e
    pass

try:
    from .bitnet import BitNetTransformer
    AVAILABLE_ARCHES.append("bitnet")
except Exception as e:
    print("Error importing `bitnet` arch:", e)
    ERROR_ARCHES["bitnet"] = e
    pass

try:
    from .mixtral import MixtralModel, MixtralConfig
    AVAILABLE_ARCHES.append("mixtral")
except Exception as e:
    print("Error importing `mixtral` arch:", e)
    ERROR_ARCHES["mixtral"] = e

try:
    from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
    AVAILABLE_ARCHES.append("mamba")
    AVAILABLE_ARCHES.append("mamba2")
except Exception as e:
    print("Error importing `mamba` arch:", e)
    ERROR_ARCHES["mamba"] = e
    ERROR_ARCHES["mamba2"] = e

try:
    from .mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF
    AVAILABLE_ARCHES.append("mamba2-hf")
except Exception as e:
    ERROR_ARCHES["mamba2-hf"] = e

# desu should remove, perf was very lacking in comparison to regular bitnet
try:
    from .mmfreelm import *
    AVAILABLE_ARCHES.append("mmfreelm")
except Exception as e:
    print("Error importing `mmfreelm` arch:", e)
    ERROR_ARCHES["mmfreelm"] = e
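For context on how this registry is meant to be used: imports are allowed to fail here, the failure is recorded in ERROR_ARCHES, and it is only re-raised once a model actually asks for that arch (see the `Base` class changes further down). A minimal sketch of that consuming side; `resolve_arch` is a hypothetical helper for illustration, not something this commit adds:

from vall_e.models.arch import AVAILABLE_ARCHES, ERROR_ARCHES

def resolve_arch(arch_type: str) -> str:
    # surface the original import error only when the failed arch is actually requested
    if arch_type in ERROR_ARCHES:
        raise ERROR_ARCHES[arch_type]
    # mirror the "Unknown arch specified" guard used in the `Base` class below
    if arch_type not in AVAILABLE_ARCHES:
        raise RuntimeError(f'Unknown arch specified: {arch_type}')
    return arch_type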
vall_e/models/arch/mamba_vasqu/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .mamba2_hf import *
vall_e/models/arch/mamba_vasqu/mamba2_hf.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# https://github.com/vasqu/mamba2-torch
# NOTE: edit `src/mamba2_torch/__init__.py` to remove reference to .src. because of how pip treats packages

from mamba2_torch import Mamba2Model as Mamba2Model_HF, Mamba2Config as Mamba2Config_HF
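As a quick sanity check that the wrapper imports and runs without Triton, a minimal CPU smoke-test sketch. It assumes mamba2-torch has been installed from the repo as the pip note above implies; only the keyword names that appear in this commit are taken as given, and the sizes are arbitrary (whether such a small hidden_size satisfies the library's head-dimension defaults is an assumption):

import torch
from vall_e.models.arch.mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF

config = Mamba2Config_HF(
    vocab_size=256,            # dummy; the Base model feeds inputs_embeds directly and deletes its embeddings
    hidden_size=64,            # tiny, just for the smoke test; may need adjusting for head-dim divisibility
    expand=4,
    num_hidden_layers=2,
    is_encoder_decoder=False,
    is_decoder=True,
    use_triton_kernels=False,  # the pure-torch path this commit is after
)
model = Mamba2Model_HF(config).eval()

x = torch.randn(1, 16, 64)     # [batch, seq, hidden], passed as inputs_embeds
with torch.no_grad():
    out = model(inputs_embeds=x, return_dict=True)
print(out.last_hidden_state.shape)  # should be torch.Size([1, 16, 64])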
@@ -329,6 +329,10 @@ class Base(nn.Module):
        n_prom_tokens = n_audio_tokens

        # check if requested arch is unavailable
        if self.arch_type in ERROR_ARCHES:
            raise ERROR_ARCHES[self.arch_type]

        if "len" not in self.capabilities:
            # +1 to include the stop token
            n_resp_tokens = n_audio_tokens + self.causal_size
@@ -592,6 +596,21 @@ class Base(nn.Module):
                #initializer_cfg=initializer_cfg,
            )
            self.model.gradient_checkpointing = self.gradient_checkpointing
        elif self.arch_type in ["mamba2-hf"]:
            self.model = Mamba2Model_HF(Mamba2Config_HF(
                vocab_size=n_resp_tokens,
                hidden_size=d_model,
                max_position_embeddings=75 * 60 * 5, # max-length of 5 minutes (75 frames/sec * 60 sec * 5)
                expand=4,
                num_hidden_layers=n_layers,
                is_encoder_decoder=False,
                is_decoder=True,
                use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
            ))
            if self.gradient_checkpointing and not self.model.gradient_checkpointing:
                self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
                    use_reentrant=False
                ))
        elif self.arch_type == "mmfreelm":
            self.model = HGRNBitModel(HGRNBitConfig(
                vocab_size=n_resp_tokens,
@@ -617,6 +636,9 @@ class Base(nn.Module):
        else:
            raise RuntimeError(f'Unknown arch specified: {self.arch_type}')

        if hasattr( self.model, "embeddings" ):
            del self.model.embeddings

        if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
            self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
@@ -713,6 +735,19 @@ class Base(nn.Module):
            state = out.past_key_values
        elif self.arch_type in ["mamba","mamba2"]:
            x = self.model( hidden_states=x )
        elif self.arch_type == "mamba2-hf":
            first = state is None or len(state) == 0

            kwargs = dict(
                inputs_embeds=x,
                cache_params=state,
                return_dict=True,
            )

            out = self.model(**kwargs)
            x = out.last_hidden_state
            if state is not None:
                state = out.cache_params
        elif self.arch_type == "bitnet":
            x = self.model(x)
        elif self.arch_type == "mmfreelm":
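The `mamba2-hf` branch above threads the recurrent state through `cache_params` and reads it back off the output. In isolation the pattern looks roughly like the sketch below; only `inputs_embeds`, `cache_params`, `return_dict`, `last_hidden_state`, and the returned `cache_params` field come from the hunk above, while the shapes and whether a cache is actually returned when none was passed in depend on mamba2-torch's defaults and are assumptions here:

import torch
from vall_e.models.arch.mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF

# tiny decoder purely for illustration (sizes are arbitrary)
model = Mamba2Model_HF(Mamba2Config_HF(
    hidden_size=64, expand=4, num_hidden_layers=2,
    is_decoder=True, use_triton_kernels=False,
)).eval()

state = None
with torch.no_grad():
    # first call: a chunk of input embeddings, no cache yet
    out = model(inputs_embeds=torch.randn(1, 8, 64), cache_params=state, return_dict=True)
    state = out.cache_params   # may be None if the config leaves caching off

    # later call: a single new step, reusing whatever state was carried over
    out = model(inputs_embeds=torch.randn(1, 1, 64), cache_params=state, return_dict=True)
    x = out.last_hidden_state  # hidden states for the new step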