actually, going for the suggested "2x layers, no intermediate scaling" is wrong for VALL-E; directly copying the normal transformer structure fixes mamba2 performance in the test trainer

This commit is contained in:
mrq 2024-06-13 20:08:22 -05:00
parent ff97e7480d
commit 83eab4fa59
4 changed files with 10 additions and 7 deletions
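In short, the change reverts the Mamba2 backbone to the usual transformer shape: the same layer count as the attention baseline with a 4x-wide intermediate (MLP) projection, instead of doubled depth with no intermediate block at all. A minimal before/after sketch, assuming mamba_ssm's MambaConfig / MambaLMHeadModel API that the wrappers in the diffs below import (module paths and sizes here are illustrative, not taken from the repo's config):

# Sketch only: assumes mamba_ssm's MambaConfig / MambaLMHeadModel;
# d_model, n_layers, and n_resp_tokens are illustrative values.
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

d_model, n_layers, n_resp_tokens = 1024, 12, 1025

# before: "2x layers, no intermediate scaling"
old_config = MambaConfig(
    vocab_size=n_resp_tokens,
    d_model=d_model,
    n_layer=n_layers*2,         # doubled depth
    d_intermediate=0,           # no MLP between mixer layers
    ssm_cfg={"layer": "Mamba2"},
    rms_norm=True,
    fused_add_norm=True,
    residual_in_fp32=True,
)

# after: copy the normal transformer structure
new_config = MambaConfig(
    vocab_size=n_resp_tokens,
    d_model=d_model,
    n_layer=n_layers,           # same depth as the attention baseline
    d_intermediate=d_model*4,   # standard 4x MLP width
    ssm_cfg={"layer": "Mamba2"},
    rms_norm=True,
    fused_add_norm=True,
    residual_in_fp32=True,
)

model = MambaLMHeadModel(new_config)

The diffs below apply this same shape change to both the main Base model and the experimental test trainer.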

View File

@@ -430,7 +430,7 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
-steps = 50 if cfg.model.arch_type in ["mamba","mamba2"] else 200
+steps = 200 if cfg.model.arch_type in ["mamba","mamba2"] else 200
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

View File

@@ -49,7 +49,7 @@ except Exception as e:
print("Error importing `mixtral` arch:", e)
try:
-from .mamba import MambaMixelModel, MambaLMHeadModel
+from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
AVAILABLE_ARCHES.append("mamba")
AVAILABLE_ARCHES.append("mamba2")
except Exception as e:

View File

@@ -581,8 +581,8 @@ class Base(nn.Module):
self.model = MambaMixelModel(
vocab_size=n_resp_tokens,
d_model=d_model,
-n_layer=n_layers*2,
-d_intermediate=0,
+n_layer=n_layers,
+d_intermediate=d_model*4,
ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
rms_norm=True,
fused_add_norm=True,
@@ -1092,7 +1092,8 @@ class Base(nn.Module):
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
# perform repetition penalizing
-logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+if "len" not in self.capabilities:
+	logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
# argmax instead
if temperature <= 0.0:

View File

@@ -106,8 +106,10 @@ class Model(LlmArchClass):
super().__init__(config=MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
-n_layer=n_layers*2,
-ssm_cfg={"layer": "Mamba2", "chunk_size":64} if cfg.model.arch_type == "mamba2" else {},
+n_layer=n_layers,
+d_intermediate=d_model*4,
+ssm_cfg={"layer": "Mamba2"} if cfg.model.arch_type == "mamba2" else {},
rms_norm=True,
fused_add_norm=True,
residual_in_fp32=True,
))