re-added mamba as a possible non-experimental arch backend (the test trainer will set it as AR-only; doing any NAR tasks lobotomizes it)
parent 687c71e028
commit e0886c5a78
@@ -317,7 +317,7 @@ class AR_NAR(Base):
 
 def example_usage():
-	#cfg.trainer.backend = "local"
+	cfg.trainer.backend = "local"
 	cfg.hyperparameters.gradient_accumulation_steps = 1
 	if cfg.audio_backend == "dac":
 		cfg.sample_rate = 44_000
@@ -334,7 +334,11 @@ def example_usage():
 	import re
 
 	device = "cuda"
 	x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels)
 
+	# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
+	if "mamba" in cfg.model.arch_type:
+		cfg.model.prom_levels = 1
+		cfg.model.resp_levels = 1
 	def tokenize(content):
 		return torch.tensor( cfg.tokenizer.encode(content) )
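Clamping both level counts to 1 is what makes the mamba test run AR-only: with a single response level there is no higher RVQ level left for a NAR pass to refine. A minimal sketch (hypothetical variables, not the repo's code) of why the NAR loop then never fires:

# minimal sketch, assuming prom/resp levels count the RVQ codebooks the model handles
arch_type   = "mamba"
prom_levels = 8
resp_levels = 8

if "mamba" in arch_type:
	# exercise the model as an AR only: keep just the first codebook level
	prom_levels = 1
	resp_levels = 1

# the AR pass predicts level 0; NAR passes would fill levels 1..resp_levels-1
nar_levels = list(range(1, resp_levels))
print(nar_levels)  # [] when resp_levels == 1, so no NAR step ever runs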
@@ -368,7 +372,7 @@ def example_usage():
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
-		'n_layers': 8, # 32
+		'n_layers': 12, # 32
 		'n_experts': 1,
 
 		'p_dropout': 0.1,
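For scale, a rough parameter estimate for the bumped test config, assuming a vanilla transformer block (~4·d² for attention plus ~8·d² for a 4x MLP); the repo's actual blocks may differ, so treat this as a ballpark only:

d_model  = 1024
n_layers = 12

per_layer = 4 * d_model**2 + 8 * d_model**2   # ~12.6M parameters per block
total     = n_layers * per_layer              # ~151M at 12 layers vs ~101M at 8
print(f"{per_layer/1e6:.1f}M per layer, {total/1e6:.0f}M total (excluding embeddings)")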
@@ -386,7 +390,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 50
+	steps = 250
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
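When no config file is supplied, the optimizer name falls back to "prodigy". A sketch of what resolving that default might look like, assuming the prodigyopt package provides the optimizer (the actual engine wiring is not part of this diff):

import torch

def build_optimizer(name, params):
	if name == "prodigy":
		from prodigyopt import Prodigy   # assumed dependency, not shown in this diff
		return Prodigy(params, lr=1.0)   # Prodigy is usually run with lr=1.0
	if name == "adamw":
		return torch.optim.AdamW(params, lr=1e-4)
	raise ValueError(f"unknown optimizer: {name}")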
@@ -459,11 +463,12 @@ def example_usage():
 
 		engine.eval()
 		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
-		resps_list = [r.unsqueeze(-1) for r in resps_list]
-		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
+		if cfg.model.max_levels > 1:
+			resps_list = [r.unsqueeze(-1) for r in resps_list]
+			resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
 
 		for i, o in enumerate(resps_list):
-			_ = decode_to_file(o, f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
+			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
 
 		unload_model()
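The guarded second pass only runs when the model actually has more than one codebook level, and the int32 cast keeps decode_to_file from handing the codec anything but integer codes. A shape-only sketch with stand-in tensors (not the repo's engine):

import torch

t, levels = 50, 8
ar_out = torch.randint(0, 1024, (t,))              # AR pass: one code per timestep, [t]
if levels > 1:
	nar_in  = ar_out.unsqueeze(-1)                 # [t, 1] seed for the NAR pass
	nar_out = torch.randint(0, 1024, (t, levels))  # stand-in for the NAR result, [t, levels]
else:
	nar_out = ar_out                               # mamba / AR-only path: decode level 0 directly

codes = nar_out.to(dtype=torch.int32)              # cast before handing codes to the decoder
print(codes.shape, codes.dtype)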

@@ -155,6 +155,40 @@ except Exception as e:
 	print("Error importing `mixtral` arch:", e)
 
+try:
+	from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
+
+	def MambaMixelModel_forward(self, input_ids=None, hidden_states=None, inference_params=None, **mixer_kwargs):
+		if hidden_states is None:
+			hidden_states = self.embedding(input_ids)
+		residual = None
+		for layer in self.layers:
+			if self.gradient_checkpointing and hidden_states.requires_grad:
+				hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
+			else:
+				hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
+		if not self.fused_add_norm:
+			residual = (hidden_states + residual) if residual is not None else hidden_states
+			hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
+		else:
+			# Set prenorm=False here since we don't need the residual
+			hidden_states = MambaLayerNormFn(
+				hidden_states,
+				self.norm_f.weight,
+				self.norm_f.bias,
+				eps=self.norm_f.eps,
+				residual=residual,
+				prenorm=False,
+				residual_in_fp32=self.residual_in_fp32,
+				is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
+			)
+		return hidden_states
+
+	MambaMixelModel.forward = MambaMixelModel_forward
+except Exception as e:
+	print("Error importing `mamba` arch:", e)
 
 AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
 
 try:
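The monkey-patched forward does two things: it accepts precomputed hidden_states (so the caller can feed its own embeddings) and it wraps each mixer block in torch.utils.checkpoint during training. A standalone sketch of the same checkpointing pattern on plain torch modules (the real patch wraps mamba_ssm blocks, which also thread a residual through each layer):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class TinyStack(nn.Module):
	def __init__(self, d=64, n=4, gradient_checkpointing=True):
		super().__init__()
		self.layers = nn.ModuleList(nn.Linear(d, d) for _ in range(n))
		self.gradient_checkpointing = gradient_checkpointing

	def forward(self, hidden_states):
		for layer in self.layers:
			if self.gradient_checkpointing and hidden_states.requires_grad:
				# recompute this layer's activations in backward instead of storing them
				hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
			else:
				hidden_states = layer(hidden_states)
		return hidden_states

x = torch.randn(2, 16, 64, requires_grad=True)
TinyStack()(x).sum().backward()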
@@ -686,6 +720,21 @@ class Base(nn.Module):
 				ff_mult=4,
 				gradient_checkpointing=self.gradient_checkpointing,
 			)
+		elif self.arch_type in ["mamba","mamba2"]:
+			self.model = MambaMixelModel(
+				vocab_size=n_resp_tokens,
+				d_model=d_model,
+				n_layer=n_layers*2,
+				d_intermediate=0,
+				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if self.arch_type == "mamba2" else {},
+				rms_norm=True,
+				fused_add_norm=True,
+				residual_in_fp32=True,
+				#attn_layer_idx=attn_layer_idx,
+				#attn_cfg=attn_cfg,
+				#initializer_cfg=initializer_cfg,
+			)
+			self.model.gradient_checkpointing = self.gradient_checkpointing
 		else:
 			raise RuntimeError(f'Unknown arch specified: {self.arch_type}')
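The backend instantiates mamba_ssm's MixerModel directly (aliased MambaMixelModel) rather than MambaLMHeadModel, so Base can feed its own embeddings and keep its own output head. A guarded sketch of the same construction outside the class, with kwargs mirroring the hunk above (mamba_ssm needs a CUDA build, hence the try/except):

try:
	from mamba_ssm.models.mixer_seq_simple import MixerModel

	backbone = MixerModel(
		vocab_size=1025,      # stands in for n_resp_tokens
		d_model=1024,
		n_layer=12 * 2,       # the mamba path doubles n_layers
		d_intermediate=0,     # no extra MLP between mixer blocks
		ssm_cfg={},           # or {"layer": "Mamba2", "chunk_size": 64} for mamba2
		rms_norm=True,
		fused_add_norm=True,
		residual_in_fp32=True,
	)
	backbone.gradient_checkpointing = False   # attribute read by the patched forward
except Exception as e:
	print("mamba_ssm unavailable:", e)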
@@ -804,7 +853,8 @@ class Base(nn.Module):
 			x = out.last_hidden_state
 			if state is not None:
 				state = out.past_key_values
 
+		elif self.arch_type in ["mamba","mamba2"]:
+			x = self.model( hidden_states=x )
 		elif self.arch_type == "bitnet":
 			x = self.model(x)
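Because the patched forward takes hidden_states, the mamba branch bypasses the last_hidden_state / past_key_values bookkeeping the other backends need. A small sketch reusing the backbone built above (mamba_ssm kernels want CUDA):

import torch

if torch.cuda.is_available():
	backbone = backbone.cuda().eval()
	x = torch.randn(1, 32, 1024, device="cuda")   # [batch, seq, d_model] summed embeddings
	with torch.no_grad():
		h = backbone(hidden_states=x)             # returns hidden states directly, no KV cache
	print(h.shape)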

@@ -158,6 +158,8 @@ class Model(LlmArchClass):
 				d_model=d_model,
 				n_layer=n_layers*2,
+				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if SELECTED_ARCH == "mamba2" else {},
+				fused_add_norm=True,
 				residual_in_fp32=True,
 			))
 
 			self.backbone.gradient_checkpointing = gradient_checkpointing
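The experimental model class goes through MambaLMHeadModel with a MambaConfig instead of the bare mixer. A hedged sketch of that construction; the fields shown follow this hunk, everything else relies on MambaConfig defaults:

try:
	from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig

	config = MambaConfig(
		vocab_size=1024,
		d_model=1024,
		n_layer=12 * 2,
		ssm_cfg={},           # {"layer": "Mamba2", "chunk_size": 64} for mamba2
		fused_add_norm=True,
		residual_in_fp32=True,
	)
	backbone = MambaLMHeadModel(config)
except Exception as e:
	print("mamba_ssm unavailable:", e)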