This commit is contained in:
mrq 2025-02-26 10:49:06 -06:00
parent 95da4e9405
commit 7d2e64630c

View File

@@ -146,7 +146,7 @@ def load_engines(training=True, **model_kwargs):
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
optimizer_class = ml.Adagrad
elif cfg.hyperparameters.optimizer.lower() == "muon":
-optimizer = ml.Muon
+optimizer_class = ml.Muon
muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
@@ -221,9 +221,6 @@ def load_engines(training=True, **model_kwargs):
continue
state[k] = ml.resize_weight( state[k], tokens )
for k, v in last_embedding_keys.items():
state[k][-1] = v
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# load lora weights if exists