actually save the optimizer for the local engine backend because safetensors doesn't save it
parent f41251f648
commit 0fbfb8bbe8
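For context on the commit message: safetensors serializes only a flat mapping of tensors, so optimizer and LR-scheduler state (nested dicts of steps, hyperparameters, and per-parameter buffers) cannot ride along in the weights file and has to be written separately. A minimal sketch of that constraint, using the public `safetensors.torch` and `torch` APIs rather than this repo's `torch_save` wrapper:

```python
# Minimal sketch (not this repo's code): safetensors accepts only a flat
# dict of tensors, so optimizer state needs its own torch.save sidecar.
import torch
from safetensors.torch import save_file

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters())

# Model weights are a flat {name: tensor} mapping and round-trip fine:
save_file(model.state_dict(), "state.safetensors")

# optimizer.state_dict() contains nested dicts, ints, and floats, which
# safetensors rejects; hence the separate "optimizer.pth" file below:
torch.save({"optimizer": optimizer.state_dict()}, "optimizer.pth")
```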
@@ -142,22 +142,15 @@ class Engine():
         if is_global_leader():
             module = self.module.state_dict()

-            # if training lora
-            # this is a separate path to override saving the weights
-            lora = None
             if cfg.lora is not None:
-                lora, module = lora_get_state_dict( module, split = True )
                 save_dir = cfg.ckpt_dir / cfg.lora.full_name

             save_path = save_dir / tag / f"state.{cfg.weights_format}"
+            save_path_optimizer = save_dir / tag / f"optimizer.pth"
             save_path.parent.mkdir(parents=True, exist_ok=True)

             torch_save({
                 "module": module,
-                "lora": lora,
-                "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
-                "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,

                 "stats": {
                     "global_step": self.global_step,
                     "micro_step": self.micro_step,
@@ -166,6 +159,11 @@ class Engine():
                 }
             }, save_path)

+            torch_save({
+                "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
+                "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
+            }, save_path_optimizer )
+
             open(save_dir / "latest", 'w').write( tag )

         torch.distributed.barrier()
@@ -184,6 +182,7 @@ class Engine():
             tag = open(tag_path).read()

         load_path = load_dir / tag / f"state.{cfg.weights_format}"
+        load_path_optimizer = load_dir / tag / f"optimizer.pth"

         if not load_path.exists():
             return
@@ -196,6 +195,11 @@ class Engine():
         self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
         self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)

+        if "optimizer" not in state and load_path_optimizer.exists():
+            optimizer_state = torch_load(load_path_optimizer, device=cfg.device)
+            state["optimizer"] = optimizer_state["optimizer"] if "optimizer" in optimizer_state else None
+            state["lr_scheduler"] = optimizer_state["lr_scheduler"] if "lr_scheduler" in optimizer_state else None
+
         load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
         load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
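Taken together, the diff leaves one weights file per tag plus a torch-format sidecar for everything safetensors cannot hold. Below is a self-contained sketch of the same save/load pattern, independent of this repo's `torch_save`/`torch_load` helpers and `cfg` object; the function names, paths, and the `.safetensors` extension are illustrative assumptions, not this repo's actual API:

```python
# Illustrative sketch of the sidecar pattern from the diff above; names and
# paths here are assumptions, not this repo's code.
from pathlib import Path
import torch
from safetensors.torch import save_file, load_file

def save_checkpoint(model, optimizer, save_dir: Path, tag: str) -> None:
    path = save_dir / tag
    path.mkdir(parents=True, exist_ok=True)
    # Weights go through safetensors (tensors only)...
    save_file(model.state_dict(), str(path / "state.safetensors"))
    # ...while optimizer state goes into a torch.save sidecar, mirroring
    # the new optimizer.pth written in save_checkpoint above.
    torch.save({"optimizer": optimizer.state_dict()}, path / "optimizer.pth")
    (save_dir / "latest").write_text(tag)

def load_checkpoint(model, optimizer, save_dir: Path) -> None:
    tag = (save_dir / "latest").read_text().strip()
    model.load_state_dict(load_file(str(save_dir / tag / "state.safetensors")))
    # Fall back to the sidecar when the weights file carries no optimizer,
    # mirroring the new branch in load_checkpoint above.
    sidecar_path = save_dir / tag / "optimizer.pth"
    if sidecar_path.exists():
        sidecar = torch.load(sidecar_path, map_location="cpu")
        if sidecar.get("optimizer") is not None:
            optimizer.load_state_dict(sidecar["optimizer"])
```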