lol
This commit is contained in:
parent
95da4e9405
commit
7d2e64630c
|
@ -146,7 +146,7 @@ def load_engines(training=True, **model_kwargs):
|
||||||
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
|
elif cfg.hyperparameters.optimizer.lower() == "adagrad":
|
||||||
optimizer_class = ml.Adagrad
|
optimizer_class = ml.Adagrad
|
||||||
elif cfg.hyperparameters.optimizer.lower() == "muon":
|
elif cfg.hyperparameters.optimizer.lower() == "muon":
|
||||||
optimizer = ml.Muon
|
optimizer_class = ml.Muon
|
||||||
|
|
||||||
muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
|
muon_params = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
|
||||||
adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
|
adamw_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ]
|
||||||
|
@ -221,9 +221,6 @@ def load_engines(training=True, **model_kwargs):
|
||||||
continue
|
continue
|
||||||
state[k] = ml.resize_weight( state[k], tokens )
|
state[k] = ml.resize_weight( state[k], tokens )
|
||||||
|
|
||||||
for k, v in last_embedding_keys.items():
|
|
||||||
state[k][-1] = v
|
|
||||||
|
|
||||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||||
|
|
||||||
# load lora weights if exists
|
# load lora weights if exists
|
||||||
|
|
Loading…
Reference in New Issue
Block a user