actually save the optimizer for the local engine backend because safetensors doesn't save it

mrq 2024-12-12 17:12:59 -06:00
parent f41251f648
commit 0fbfb8bbe8

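Background for the change: safetensors serializes only a flat mapping of names to tensors, so an optimizer's state dict (nested dicts, step counters, param groups) cannot pass through it and has to be pickled via torch.save instead. A minimal sketch of the limitation, independent of this repo (filenames are illustrative):

    # Minimal sketch (not repo code): why safetensors can't hold optimizer state.
    import torch
    from safetensors.torch import save_file

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters())

    # run one step so the optimizer actually has state worth saving
    model(torch.randn(2, 4)).sum().backward()
    optimizer.step()

    save_file(model.state_dict(), "state.safetensors")  # OK: flat {name: tensor}

    try:
        # optimizer.state_dict() is {"state": {...}, "param_groups": [...]},
        # full of non-tensor values, so safetensors rejects it
        save_file(optimizer.state_dict(), "optimizer.safetensors")
    except Exception as e:
        print("safetensors refused optimizer state:", e)

    torch.save(optimizer.state_dict(), "optimizer.pth")  # works: pickle handles nesting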

@@ -142,22 +142,15 @@ class Engine():
         if is_global_leader():
             module = self.module.state_dict()

-            # if training lora
-            # this is a separate path to override saving the weights
-            lora = None
             if cfg.lora is not None:
-                lora, module = lora_get_state_dict( module, split = True )
                 save_dir = cfg.ckpt_dir / cfg.lora.full_name

             save_path = save_dir / tag / f"state.{cfg.weights_format}"
+            save_path_optimizer = save_dir / tag / f"optimizer.pth"
             save_path.parent.mkdir(parents=True, exist_ok=True)

             torch_save({
                 "module": module,
-                "lora": lora,
-                "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
-                "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
                 "stats": {
                     "global_step": self.global_step,
                     "micro_step": self.micro_step,
@@ -166,6 +159,11 @@ class Engine():
                 }
             }, save_path)

+            torch_save({
+                "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
+                "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
+            }, save_path_optimizer )
+
             open(save_dir / "latest", 'w').write( tag )

         torch.distributed.barrier()
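The resulting on-disk layout for the local backend: module weights and stats go to state.{weights_format} via torch_save (which presumably dispatches on the file suffix), while optimizer and scheduler state are pickled into a sibling optimizer.pth. A simplified standalone sketch of the same split (save_checkpoint is a hypothetical helper, and it omits the stats bundling the real code does):

    # Hypothetical standalone version of the split save above.
    from pathlib import Path
    import torch
    from safetensors.torch import save_file

    def save_checkpoint(module, optimizer, lr_scheduler, save_dir: Path, tag: str):
        save_path = save_dir / tag / "state.safetensors"
        save_path_optimizer = save_dir / tag / "optimizer.pth"
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # weights are a flat {name: tensor} dict, safe for safetensors
        save_file(module.state_dict(), str(save_path))

        # optimizer/scheduler state is nested Python data, so pickle it
        torch.save({
            "optimizer": optimizer.state_dict() if optimizer is not None else None,
            "lr_scheduler": lr_scheduler.state_dict() if lr_scheduler is not None else None,
        }, save_path_optimizer)

        (save_dir / "latest").write_text(tag)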
@@ -184,6 +182,7 @@ class Engine():
         tag = open(tag_path).read()

         load_path = load_dir / tag / f"state.{cfg.weights_format}"
+        load_path_optimizer = load_dir / tag / f"optimizer.pth"

         if not load_path.exists():
             return
@@ -196,6 +195,11 @@ class Engine():
         self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
         self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)

+        if "optimizer" not in state and load_path_optimizer.exists():
+            optimizer_state = torch_load(load_path_optimizer, device=cfg.device)
+            state["optimizer"] = optimizer_state["optimizer"] if "optimizer" in optimizer_state else None
+            state["lr_scheduler"] = optimizer_state["lr_scheduler"] if "lr_scheduler" in optimizer_state else None
+
         load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
         load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
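The load-side fallback above keeps both generations of checkpoints working: old ones still carry "optimizer" inside the main state file, while new ones get it merged back in from optimizer.pth, so the rest of the method reads state["optimizer"] unchanged. A standalone sketch under the same assumptions (load_checkpoint is a hypothetical helper):

    # Hypothetical load-side counterpart of the fallback above.
    from pathlib import Path
    import torch
    from safetensors.torch import load_file

    def load_checkpoint(load_dir: Path, tag: str, device: str = "cpu") -> dict:
        load_path = load_dir / tag / "state.safetensors"
        load_path_optimizer = load_dir / tag / "optimizer.pth"

        state = {"module": load_file(str(load_path), device=device)}

        # new-style checkpoints keep optimizer state in a sibling pickle
        if load_path_optimizer.exists():
            extra = torch.load(load_path_optimizer, map_location=device)
            state["optimizer"] = extra.get("optimizer")
            state["lr_scheduler"] = extra.get("lr_scheduler")
        return state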