re-added mamba as a possible non-experimental arch backend (the test trainer will set it as AR-only, since training it on any NAR task lobotomizes it)

mrq 2024-06-04 22:41:22 -05:00
parent 687c71e028
commit e0886c5a78
3 changed files with 65 additions and 8 deletions
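
The gist of the change: when the arch is mamba, the test trainer clamps everything to one codebook level so that only the AR task is exercised. A minimal sketch of that configuration, reusing the cfg names that appear in the hunks below (the import path is an assumption; the arch_type value is illustrative):

import torch
from vall_e.config import cfg            # assumed import path for the repo's global config object

cfg.model.arch_type = "mamba"             # or "mamba2"; illustrative choice
if "mamba" in cfg.model.arch_type:        # mirrors the hunk added to example_usage() below
	cfg.model.prom_levels = 1             # single prompt codebook level
	cfg.model.resp_levels = 1             # single response level => AR only, no NAR refinement pass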

View File

@@ -317,7 +317,7 @@ class AR_NAR(Base):
 def example_usage():
-	#cfg.trainer.backend = "local"
+	cfg.trainer.backend = "local"
 	cfg.hyperparameters.gradient_accumulation_steps = 1
 	if cfg.audio_backend == "dac":
 		cfg.sample_rate = 44_000
@@ -334,7 +334,11 @@ def example_usage():
 	import re
 	device = "cuda"
 	x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels)
+	# mamba seems to ONLY work as an AR (any NAR attempt lobotomizes it)
+	if "mamba" in cfg.model.arch_type:
+		cfg.model.prom_levels = 1
+		cfg.model.resp_levels = 1
 	def tokenize(content):
 		return torch.tensor( cfg.tokenizer.encode(content) )
@@ -368,7 +372,7 @@ def example_usage():
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
-		'n_layers': 8, # 32
+		'n_layers': 12, # 32
 		'n_experts': 1,
 		'p_dropout': 0.1,
@@ -386,7 +390,7 @@ def example_usage():
 	"""
 	model = AR_NAR(**kwargs).to(device)
-	steps = 50
+	steps = 250
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@@ -459,11 +463,12 @@ def example_usage():
 		engine.eval()
 		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
-		resps_list = [r.unsqueeze(-1) for r in resps_list]
-		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
+		if cfg.model.max_levels > 1:
+			resps_list = [r.unsqueeze(-1) for r in resps_list]
+			resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
 		for i, o in enumerate(resps_list):
-			_ = decode_to_file(o, f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
+			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
 		unload_model()

View File

@@ -155,6 +155,40 @@ except Exception as e:
 	print("Error importing `mixtral` arch:", e)
 
+try:
+	from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
+
+	def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
+		# accept pre-computed hidden states so the caller can feed its own embeddings instead of token IDs
+		if hidden_states is None:
+			hidden_states = self.embedding(input_ids)
+		residual = None
+		for layer in self.layers:
+			# route each mixer block through torch.utils.checkpoint when gradient checkpointing is enabled
+			if self.gradient_checkpointing and hidden_states.requires_grad:
+				hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
+			else:
+				hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
+		if not self.fused_add_norm:
+			residual = (hidden_states + residual) if residual is not None else hidden_states
+			hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
+		else:
+			# Set prenorm=False here since we don't need the residual
+			hidden_states = MambaLayerNormFn(
+				hidden_states,
+				self.norm_f.weight,
+				self.norm_f.bias,
+				eps=self.norm_f.eps,
+				residual=residual,
+				prenorm=False,
+				residual_in_fp32=self.residual_in_fp32,
+				is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
+			)
+		return hidden_states
+
+	MambaMixelModel.forward = MambaMixelModel_forward
+except Exception as e:
+	print("Error importing `mamba` arch:", e)
 
 AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
 
 try:
@@ -686,6 +720,21 @@ class Base(nn.Module):
 				ff_mult=4,
 				gradient_checkpointing=self.gradient_checkpointing,
 			)
+		elif self.arch_type in ["mamba","mamba2"]:
+			self.model = MambaMixelModel(
+				vocab_size=n_resp_tokens,
+				d_model=d_model,
+				n_layer=n_layers*2,
+				d_intermediate=0,
+				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if self.arch_type == "mamba2" else {},
+				rms_norm=True,
+				fused_add_norm=True,
+				residual_in_fp32=True,
+				#attn_layer_idx=attn_layer_idx,
+				#attn_cfg=attn_cfg,
+				#initializer_cfg=initializer_cfg,
+			)
+			self.model.gradient_checkpointing = self.gradient_checkpointing
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
@@ -804,7 +853,8 @@ class Base(nn.Module):
 			x = out.last_hidden_state
 			if state is not None:
 				state = out.past_key_values
+		elif self.arch_type in ["mamba","mamba2"]:
+			x = self.model( hidden_states=x )
 		elif self.arch_type == "bitnet":
 			x = self.model(x)
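
Taken together, the hunks above build the mamba backbone as a bare mixer stack and drive it from externally computed embeddings. A rough sketch of that construction in isolation, assuming mamba_ssm >= 2.0 (with its Triton/CUDA kernels) and a CUDA device; the sizes are illustrative, not the repo's defaults:

import torch
from mamba_ssm.models.mixer_seq_simple import MixerModel as MambaMixelModel

d_model, n_layers, n_resp_tokens = 1024, 12, 1025   # illustrative sizes

model = MambaMixelModel(
	vocab_size=n_resp_tokens,
	d_model=d_model,
	n_layer=n_layers*2,                 # twice the layer count, as in the constructor hunk above
	d_intermediate=0,                   # mixer-only blocks, no interleaved MLPs
	ssm_cfg={},                         # or {"layer": "Mamba2", "chunk_size": 64} for the mamba2 path
	rms_norm=True,
	fused_add_norm=True,
	residual_in_fp32=True,
).to("cuda")
model.gradient_checkpointing = False    # attribute read by the patched forward above

tokens = torch.randint(0, n_resp_tokens, (1, 128), device="cuda")
h = model(tokens)                       # stock forward embeds the token IDs itself
print(h.shape)                          # torch.Size([1, 128, 1024])
# with the forward override from this commit applied, the same model can instead be fed
# pre-summed embeddings directly: h = model( hidden_states=x ) with x of shape (B, T, d_model)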

View File

@@ -158,6 +158,8 @@ class Model(LlmArchClass):
 				d_model=d_model,
 				n_layer=n_layers*2,
 				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if SELECTED_ARCH == "mamba2" else {},
+				fused_add_norm=True,
+				residual_in_fp32=True,
 			))
 		self.backbone.gradient_checkpointing = gradient_checkpointing
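
The experimental LM-head variant gets the same two flags through its MambaConfig. A minimal sketch of that configuration on its own, again assuming mamba_ssm >= 2.0 on a CUDA device; the vocab and layer sizes are placeholders rather than whatever this class actually computes:

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig

config = MambaConfig(
	vocab_size=1024,                                 # placeholder; the wrapper derives its own
	d_model=1024,
	n_layer=12*2,                                    # n_layers*2, mirroring the hunk above
	ssm_cfg={"layer": "Mamba2", "chunk_size": 64},   # omit for the plain mamba path
	fused_add_norm=True,
	residual_in_fp32=True,
)
model = MambaLMHeadModel(config).to("cuda")

tokens = torch.randint(0, 1024, (1, 64), device="cuda")
logits = model(tokens).logits            # MambaLMHeadModel returns a namedtuple with .logits
print(logits.shape)                      # torch.Size([1, 64, 1024])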