actually going for the suggested "2x layers, no intermediate scaling" is wrong for VALL-E, directly copying the normal transformer structure fixes mamba2 performance in the test trainer
This commit is contained in:
parent
ff97e7480d
commit
83eab4fa59
|
@ -430,7 +430,7 @@ def example_usage():
|
|||
"""
|
||||
|
||||
model = AR_NAR(**kwargs).to(device)
|
||||
steps = 50 if cfg.model.arch_type in ["mamba","mamba2"] else 200
|
||||
steps = 200 if cfg.model.arch_type in ["mamba","mamba2"] else 200
|
||||
|
||||
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
|
||||
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
|
||||
|
|
|
@ -49,7 +49,7 @@ except Exception as e:
|
|||
print("Error importing `mixtral` arch:", e)
|
||||
|
||||
try:
|
||||
from .mamba import MambaMixelModel, MambaLMHeadModel
|
||||
from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
|
||||
AVAILABLE_ARCHES.append("mamba")
|
||||
AVAILABLE_ARCHES.append("mamba2")
|
||||
except Exception as e:
|
||||
|
|
|
@ -581,8 +581,8 @@ class Base(nn.Module):
|
|||
self.model = MambaMixelModel(
|
||||
vocab_size=n_resp_tokens,
|
||||
d_model=d_model,
|
||||
n_layer=n_layers*2,
|
||||
d_intermediate=0,
|
||||
n_layer=n_layers,
|
||||
d_intermediate=d_model*4,
|
||||
ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
|
||||
rms_norm=True,
|
||||
fused_add_norm=True,
|
||||
|
@ -1092,7 +1092,8 @@ class Base(nn.Module):
|
|||
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
|
||||
|
||||
# perform repetition penalizing
|
||||
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
|
||||
if "len" not in self.capabilities:
|
||||
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
|
||||
|
||||
# argmax instead
|
||||
if temperature <= 0.0:
|
||||
|
|
|
@ -106,8 +106,10 @@ class Model(LlmArchClass):
|
|||
super().__init__(config=MambaConfig(
|
||||
vocab_size=vocab_size,
|
||||
d_model=d_model,
|
||||
n_layer=n_layers*2,
|
||||
ssm_cfg={"layer": "Mamba2", "chunk_size":64} if cfg.model.arch_type == "mamba2" else {},
|
||||
n_layer=n_layers,
|
||||
d_intermediate=d_model*4,
|
||||
ssm_cfg={"layer": "Mamba2"} if cfg.model.arch_type == "mamba2" else {},
|
||||
rms_norm=True,
|
||||
fused_add_norm=True,
|
||||
residual_in_fp32=True,
|
||||
))
|
||||
|
|
Loading…
Reference in New Issue
Block a user