From ccb14c06effa36fe7094f9fb2c8a988dfc7769b1 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 14 Jun 2024 19:42:17 -0500
Subject: [PATCH] mamba2-hf using `vasqu/mamba2-torch` because it lets me use
 mamba2 without triton ops (my 4xV100s are not happy training mamba2 because
 of triton)

---
 vall_e/models/arch/__init__.py              | 27 ++++++++++------
 vall_e/models/arch/mamba_vasqu/__init__.py  |  1 +
 vall_e/models/arch/mamba_vasqu/mamba2_hf.py |  4 +++
 vall_e/models/base.py                       | 35 +++++++++++++++++++++
 4 files changed, 58 insertions(+), 9 deletions(-)
 create mode 100644 vall_e/models/arch/mamba_vasqu/__init__.py
 create mode 100644 vall_e/models/arch/mamba_vasqu/mamba2_hf.py

diff --git a/vall_e/models/arch/__init__.py b/vall_e/models/arch/__init__.py
index c0843f4..a8233d2 100755
--- a/vall_e/models/arch/__init__.py
+++ b/vall_e/models/arch/__init__.py
@@ -1,62 +1,71 @@
 AVAILABLE_ARCHES = []
+ERROR_ARCHES = {}
 
 try:
 	from .transformer import SinusoidalEmbedding, Block as TransformerBlock
 	AVAILABLE_ARCHES.append("transformer")
 except Exception as e:
-	print("Error importing `transformer` arch:", e)
+	ERROR_ARCHES["transformer"] = e
 	pass
 
 try:
 	from .retnet import RetNetDecoder, RetNetConfig
 	AVAILABLE_ARCHES.append("retnet")
 except Exception as e:
-	print("Error importing `retnet` arch:", e)
+	ERROR_ARCHES["retnet"] = e
 	pass
 
 try:
 	from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS
 	AVAILABLE_ARCHES.append("retnet-ts")
 except Exception as e:
-	print("Error importing `retnet-ts` arch:", e)
+	ERROR_ARCHES["retnet-ts"] = e
 	pass
 
 try:
 	from .retnet_syncdoth.retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF, RetNetForCausalLM
 	AVAILABLE_ARCHES.append("retnet-hf")
 except Exception as e:
-	print("Error importing `retnet-hf` arch:", e)
+	ERROR_ARCHES["retnet-hf"] = e
 	pass
 
 try:
 	from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Base, LlamaForCausalLM
 	AVAILABLE_ARCHES.append("llama")
 except Exception as e:
-	print("Error importing `llama` arch:", e)
+	ERROR_ARCHES["llama"] = e
 	pass
 
 try:
 	from .bitnet import BitNetTransformer
 	AVAILABLE_ARCHES.append("bitnet")
 except Exception as e:
-	print("Error importing `bitnet` arch:", e)
+	ERROR_ARCHES["bitnet"] = e
 	pass
 
 try:
 	from .mixtral import MixtralModel, MixtralConfig
 	AVAILABLE_ARCHES.append("mixtral")
 except Exception as e:
-	print("Error importing `mixtral` arch:", e)
+	ERROR_ARCHES["mixtral"] = e
 
 try:
 	from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
 	AVAILABLE_ARCHES.append("mamba")
 	AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
-	print("Error importing `mamba` arch:", e)
+	ERROR_ARCHES["mamba"] = e
+	ERROR_ARCHES["mamba2"] = e
 
+try:
+	from .mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF
+	AVAILABLE_ARCHES.append("mamba2-hf")
+except Exception as e:
+	ERROR_ARCHES["mamba2-hf"] = e
+
+# desu should remove, perf was very lacking in comparison to regular bitnet
 try:
 	from .mmfreelm import *
 	AVAILABLE_ARCHES.append("mmfreelm")
 except Exception as e:
-	print("Error importing `mmfreelm` arch:", e)
\ No newline at end of file
+	ERROR_ARCHES["mmfreelm"] = e
\ No newline at end of file
diff --git a/vall_e/models/arch/mamba_vasqu/__init__.py b/vall_e/models/arch/mamba_vasqu/__init__.py
new file mode 100644
index 0000000..0c20b1b
--- /dev/null
+++ b/vall_e/models/arch/mamba_vasqu/__init__.py
@@ -0,0 +1 @@
+from .mamba2_hf import *
\ No newline at end of file
diff --git a/vall_e/models/arch/mamba_vasqu/mamba2_hf.py b/vall_e/models/arch/mamba_vasqu/mamba2_hf.py
new file mode 100644
index 0000000..f285963
--- /dev/null
+++ b/vall_e/models/arch/mamba_vasqu/mamba2_hf.py
@@ -0,0 +1,4 @@
+# https://github.com/vasqu/mamba2-torch
+# NOTE: edit `src/mamba2_torch/__init__.py` to remove reference to .src. because of how pip treats packages
+
+from mamba2_torch import Mamba2Model as Mamba2Model_HF, Mamba2Config as Mamba2Config_HF
\ No newline at end of file
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b8e5b92..e3d85a8 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -329,6 +329,10 @@ class Base(nn.Module):
 
 		n_prom_tokens = n_audio_tokens
 
+		# check if requested arch is unavailable
+		if self.arch_type in ERROR_ARCHES:
+			raise ERROR_ARCHES[self.arch_type]
+
 		if "len" not in self.capabilities:
 			# +1 to include the stop token
 			n_resp_tokens = n_audio_tokens + self.causal_size
@@ -592,6 +596,21 @@ class Base(nn.Module):
 				#initializer_cfg=initializer_cfg,
 			)
 			self.model.gradient_checkpointing = self.gradient_checkpointing
+		elif self.arch_type in ["mamba2-hf"]:
+			self.model = Mamba2Model_HF(Mamba2Config_HF(
+				vocab_size=n_resp_tokens,
+				hidden_size=d_model,
+				max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
+				expand=4,
+				num_hidden_layers=n_layers,
+				is_encoder_decoder=False,
+				is_decoder=True,
+				use_triton_kernels=False, # the entire reason is to NOT use triton (because V100s hate it)
+			))
+			if self.gradient_checkpointing and not self.model.gradient_checkpointing:
+				self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+					use_reentrant=False
+				))
 		elif self.arch_type == "mmfreelm":
 			self.model = HGRNBitModel(HGRNBitConfig(
 				vocab_size=n_resp_tokens,
@@ -617,6 +636,9 @@ class Base(nn.Module):
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
 
+		if hasattr( self.model, "embeddings" ):
+			del self.model.embeddings
+
 		if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]:
 			self.model = ml.replace_attention( self.model, klass=LlamaAttention, target=LlamaAttention_Base, mode=self.config.attention )
 
@@ -713,6 +735,19 @@ class Base(nn.Module):
 				state = out.past_key_values
 		elif self.arch_type in ["mamba","mamba2"]:
 			x = self.model( hidden_states=x )
+		elif self.arch_type == "mamba2-hf":
+			first = state is None or len(state) == 0
+
+			kwargs = dict(
+				inputs_embeds=x,
+				cache_params=state,
+				return_dict=True,
+			)
+
+			out = self.model(**kwargs)
+			x = out.last_hidden_state
+			if state is not None:
+				state = out.cache_params
 		elif self.arch_type == "bitnet":
 			x = self.model(x)
 		elif self.arch_type == "mmfreelm":
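
A minimal standalone sketch of how the new `mamba2-hf` backend is driven, for anyone who wants to try `vasqu/mamba2-torch` outside of vall_e. The class names, config fields, and forward() arguments below are the ones used in the base.py hunks above; the sizes (hidden size, layer count, sequence length) are placeholder values for illustration, not the repo's actual config.

import torch
from mamba2_torch import Mamba2Model as Mamba2Model_HF, Mamba2Config as Mamba2Config_HF

d_model, n_layers, vocab = 1024, 12, 1025  # placeholder sizes, not vall_e's defaults

model = Mamba2Model_HF(Mamba2Config_HF(
	vocab_size=vocab,
	hidden_size=d_model,
	expand=4,
	num_hidden_layers=n_layers,
	is_encoder_decoder=False,
	is_decoder=True,
	use_triton_kernels=False,  # pure-torch path; the whole point is avoiding triton on V100s
))

# the patch feeds pre-computed embeddings rather than token ids, and threads
# `cache_params` through successive forward() calls as the recurrent state
x = torch.randn(1, 32, d_model)
state = None
out = model(inputs_embeds=x, cache_params=state, return_dict=True)
x = out.last_hidden_state
state = out.cache_params  # pass back in on the next step for incremental decoding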