From 0fbfb8bbe8d2da1248ce9f1ffa6992e0d46033a6 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 12 Dec 2024 17:12:59 -0600
Subject: [PATCH] actually save the optimizer for the local engine backend
 because safetensors doesn't save it

---
 vall_e/engines/base.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 8609004..61fb4e1 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -142,22 +142,15 @@ class Engine():
 		if is_global_leader():
 			module = self.module.state_dict()
 
-			# if training lora
-			# this is a separate path to override saving the weights
-			lora = None
 			if cfg.lora is not None:
-				lora, module = lora_get_state_dict( module, split = True )
 				save_dir = cfg.ckpt_dir / cfg.lora.full_name
 
 			save_path = save_dir / tag / f"state.{cfg.weights_format}"
+			save_path_optimizer = save_dir / tag / f"optimizer.pth"
 			save_path.parent.mkdir(parents=True, exist_ok=True)
 
 			torch_save({
 				"module": module,
-				"lora": lora,
-				"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
-				"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
-
 				"stats": {
 					"global_step": self.global_step,
 					"micro_step": self.micro_step,
@@ -166,6 +159,11 @@ class Engine():
 				}
 			}, save_path)
 
+			torch_save({
+				"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
+				"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
+			}, save_path_optimizer )
+
 			open(save_dir / "latest", 'w').write( tag )
 
 		torch.distributed.barrier()
@@ -184,6 +182,7 @@ class Engine():
 			tag = open(tag_path).read()
 
 		load_path = load_dir / tag / f"state.{cfg.weights_format}"
+		load_path_optimizer = load_dir / tag / f"optimizer.pth"
 
 		if not load_path.exists():
 			return
@@ -196,6 +195,11 @@ class Engine():
 		self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
 		self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
 
+		if "optimizer" not in state and load_path_optimizer.exists():
+			optimizer_state = torch_load(load_path_optimizer, device=cfg.device)
+			state["optimizer"] = optimizer_state["optimizer"] if "optimizer" in optimizer_state else None
+			state["lr_scheduler"] = optimizer_state["lr_scheduler"] if "lr_scheduler" in optimizer_state else None
+
 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
 		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
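
Context for the subject line, as a minimal standalone sketch rather than vall_e code: safetensors only serializes a flat {name: tensor} mapping, so when the main checkpoint (state.{cfg.weights_format}) is written in that format, the optimizer and lr_scheduler state_dicts (nested dicts holding plain ints, floats, and lists) have nowhere to go, which is why the patch writes them separately with torch_save into optimizer.pth. The filenames and model below are illustrative only.

    import torch
    from safetensors.torch import save_file

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters())

    # Model weights form a flat {name: tensor} dict, which safetensors accepts.
    save_file(model.state_dict(), "state.safetensors")

    # An optimizer state_dict contains "param_groups" (lists of dicts with
    # plain floats/ints) and per-parameter bookkeeping that safetensors cannot
    # represent, so it is written with torch.save into a regular .pth file.
    torch.save(
        {
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": None,  # a scheduler's state_dict would go here too
        },
        "optimizer.pth",
    )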