actually, going with the suggested "2x layers, no intermediate scaling" is wrong for VALL-E; directly copying the normal transformer structure (same layer count, 4x intermediate size) fixes mamba2 performance in the test trainer
parent ff97e7480d
commit 83eab4fa59
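For context, a minimal sketch of the two Mamba backbone shapes this commit swaps, assuming the MambaConfig / MambaLMHeadModel interface from mamba_ssm (a version recent enough to expose d_intermediate and ssm_cfg={"layer": "Mamba2"}, as the diff below relies on); the concrete d_model / n_layers / vocab_size values are placeholders, not the test trainer's real settings:

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

d_model, n_layers, vocab_size = 1024, 12, 1024  # placeholder sizes, not the trainer's

# previous attempt: "2x layers, no intermediate scaling"
old_config = MambaConfig(
	vocab_size=vocab_size,
	d_model=d_model,
	n_layer=n_layers*2,
	d_intermediate=0,          # no MLP-style projection between mixer layers
	ssm_cfg={"layer": "Mamba2"},
	rms_norm=True,
	fused_add_norm=True,
	residual_in_fp32=True,
)

# this commit: mirror the normal transformer structure instead, i.e. the same
# layer count as the attention baseline with a 4x-wide intermediate projection
new_config = MambaConfig(
	vocab_size=vocab_size,
	d_model=d_model,
	n_layer=n_layers,
	d_intermediate=d_model*4,
	ssm_cfg={"layer": "Mamba2"},
	rms_norm=True,
	fused_add_norm=True,
	residual_in_fp32=True,
)

model = MambaLMHeadModel(new_config)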
@@ -430,7 +430,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 50 if cfg.model.arch_type in ["mamba","mamba2"] else 200
+	steps = 200 if cfg.model.arch_type in ["mamba","mamba2"] else 200
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -49,7 +49,7 @@ except Exception as e:
 	print("Error importing `mixtral` arch:", e)
 
 try:
-	from .mamba import MambaMixelModel, MambaLMHeadModel
+	from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
 	AVAILABLE_ARCHES.append("mamba")
 	AVAILABLE_ARCHES.append("mamba2")
 except Exception as e:
@@ -581,8 +581,8 @@ class Base(nn.Module):
 			self.model = MambaMixelModel(
 				vocab_size=n_resp_tokens,
 				d_model=d_model,
-				n_layer=n_layers*2,
-				d_intermediate=0,
+				n_layer=n_layers,
+				d_intermediate=d_model*4,
 				ssm_cfg={"layer": "Mamba2"} if self.arch_type == "mamba2" else {},
 				rms_norm=True,
 				fused_add_norm=True,
@@ -1092,7 +1092,8 @@ class Base(nn.Module):
 		logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
 
 		# perform repetition penalizing
-		logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+		if "len" not in self.capabilities:
+			logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
 
 		# argmax instead
 		if temperature <= 0.0:
@@ -106,8 +106,10 @@ class Model(LlmArchClass):
 			super().__init__(config=MambaConfig(
 				vocab_size=vocab_size,
 				d_model=d_model,
-				n_layer=n_layers*2,
-				ssm_cfg={"layer": "Mamba2", "chunk_size":64} if cfg.model.arch_type == "mamba2" else {},
+				n_layer=n_layers,
+				d_intermediate=d_model*4,
+				ssm_cfg={"layer": "Mamba2"} if cfg.model.arch_type == "mamba2" else {},
+				rms_norm=True,
 				fused_add_norm=True,
 				residual_in_fp32=True,
 			))