From b105f6211ea1c642777720df5a85b6d3c997e8c2 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 20 Aug 2023 13:39:58 -0500
Subject: [PATCH] added ability to export weights mid-training to avoid CBT to yank the weights while the training script is running

---
 vall_e/engines/base.py  | 22 +++++++++++++++++++---
 vall_e/export.py        | 11 +----------
 vall_e/utils/trainer.py | 21 +++++++++++++++++++--
 3 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 5f11951..f820657 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -200,6 +200,9 @@ class Engines(dict[str, Engine]):
 		self._global_step = 0
 		self._micro_step = 0
 
+		for name, engine in self.items():
+			engine.name = name
+
 	@property
 	def global_step(self):
 		return self._global_step
@@ -218,6 +221,18 @@ class Engines(dict[str, Engine]):
 		for engine in self.values():
 			engine.dispatch_attribute(*args, **kwargs)
 
+	def export(self, userdata={}):
+		for name, engine in self.items():
+			outpath = cfg.ckpt_dir / name / "fp32.pth"
+			state_dict = {
+				"global_step": engine.global_step,
+				"micro_step": engine.micro_step,
+				'module': engine.module.state_dict(),
+			}
+			state_dict.update(userdata)
+			torch.save(state_dict, outpath)
+			print(f"Exported {name} to {outpath}")
+
 	def save_checkpoint(self, tag=None):
 		if not tag:
 			tag = cfg.trainer.save_tag
@@ -246,7 +261,7 @@
 				p.unlink()
 			d.rmdir()
 
-	def load_checkpoint(self, tag=None):
+	def load_checkpoint(self, tag=None, module_only=False):
 		if not tag:
 			tag = cfg.trainer.load_tag
 
@@ -256,8 +271,9 @@
 				tag=tag,
 				load_dir=load_dir,
 				load_module_strict=cfg.trainer.strict_loading,
-				load_optimizer_states=cfg.trainer.load_states,
-				load_lr_scheduler_states=cfg.trainer.load_states,
+				load_optimizer_states=False if module_only else cfg.trainer.load_states,
+				load_lr_scheduler_states=False if module_only else cfg.trainer.load_states,
+				load_module_only=module_only,
 			)
 			if cfg.trainer.restart_step_count:
 				engine.global_steps = 0
diff --git a/vall_e/export.py b/vall_e/export.py
index 8f51871..464803b 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -12,16 +12,7 @@ def main():
 	args = parser.parse_args()
 
 	engines = load_engines()
-	for name, engine in engines.items():
-		outpath = cfg.ckpt_dir / name / "fp32.pth"
-		torch.save({
-			"global_step": engine.global_step,
-			"micro_step": engine.micro_step,
-			'module': engine.module.to('cpu', dtype=torch.float32).state_dict(),
-			#'optimizer': engine.optimizer.state_dict(),
-			'symmap': get_phone_symmap(),
-		}, outpath)
-		print(f"Exported {name} to {outpath}")
+	engines.export(userdata={"symmap": get_phone_symmap()})
 
 if __name__ == "__main__":
 	main()
\ No newline at end of file
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 4292c61..4b795eb 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -33,6 +33,7 @@
 from ..models import get_models
 from .utils import to_device, do_gc
 from ..utils import wrapper as ml
+from ..data import get_phone_symmap # should decouple from this trainer script
 
 _logger = logging.getLogger(__name__)
 _engines: Engines
@@ -69,6 +70,9 @@ def load_engines():
 		optimizer = None
 		lr_scheduler = None
 
+		# yuck, should instead check be optimier == "adamw" and backend != "deepspeed"
+		# and then have ds_cfg pass in the config flag to use pytorch adamw
+		# I genuinely cannot validate if this ever actually gets used in DeepSpeed
 		if cfg.hyperparameters.optimizer.lower() == "adamw-torch":
 			optimizer = ml.AdamW(
 				model.parameters(),
@@ -86,6 +90,9 @@ def load_engines():
 			if "module" in state:
 				state = state["module"]
 
+			# should decouple the following from this trainer script
+			# probably with passing a fun that defaults to a lambda x: x deal
+
 			# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
 			if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
 				o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
@@ -301,8 +308,18 @@ def train(
 			if "lr" in command:
 				rate = float(command.split(" ")[-1])
-				engines.set_lr(rate)
-				print("Updating LR to:", rate)
+				try:
+					engines.set_lr(rate)
+					print("Updating LR to:", rate)
+				except Exception as e:
+					print("Failed to set LR rate to:", rate, str(e))
+
+			if "export" in command:
+				engines.save_checkpoint()
+				last_save_step = engines.global_step
+
+				if is_global_leader():
+					engines.export(userdata={"symmap": get_phone_symmap()})
 
 		save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
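
Note: for reference, a minimal sketch (not part of the patch) of how a checkpoint written by the new Engines.export() could be inspected. The path and the "ar" engine name are assumptions; only the saved keys (global_step, micro_step, module, plus any userdata such as symmap) are known from the diff above.

	import torch

	# Hypothetical location: Engines.export() writes to cfg.ckpt_dir / <engine name> / "fp32.pth";
	# "ckpt/ar/fp32.pth" is an assumed path and engine name.
	state = torch.load("ckpt/ar/fp32.pth", map_location="cpu")

	print(state["global_step"], state["micro_step"])  # bookkeeping saved by Engines.export()
	print(len(state["symmap"]))                       # userdata merged in by export.py / the runtime "export" command

	weights = state["module"]  # plain state_dict; load with model.load_state_dict(weights) on a matching model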