Patch summary (vall_e/engines):
  * __init__.py: also imports world_size. In load_engines, the wandb project
    name is cfg.lora.full_name when cfg.lora is set (otherwise the engine
    name), and group="DDP" is passed to wandb.init when world_size() > 1.
  * base.py: Engines.quit() calls finish() on every engine whose wandb handle
    is not None. Engines.step() keys the flattened stats by the part of
    cfg.lora.full_name before the first "-" when cfg.lora is set (otherwise
    by the same prefix of the engine name).

NOTE(review): the leading whitespace of the hunk lines was lost in the copy
this was restored from. Indentation below is reconstructed from the Python
nesting (4 spaces per level; absolute depth is a best guess). Blank context
and added lines were re-inserted so each hunk matches its @@ line counts.
Apply with whitespace-insensitive matching (e.g. git apply --ignore-whitespace)
and re-indent the added lines to the target files' own style - TODO confirm.

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 2d5c9ea..2dde826 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -1,6 +1,6 @@
 from ..config import cfg
 
-from ..utils.distributed import fix_unset_envs, ddp_model
+from ..utils.distributed import fix_unset_envs, ddp_model, world_size
 fix_unset_envs()
 
 if cfg.trainer.backend == "deepspeed":
@@ -286,7 +286,15 @@ def load_engines(training=True, **model_kwargs):
 
         # setup wandb
         if engine._training and cfg.trainer.wandb and wandb is not None:
-            engine.wandb = wandb.init(project=name)
+            key_name = name
+            kwargs = {}
+            if cfg.lora is not None:
+                key_name = cfg.lora.full_name
+
+            if world_size() > 1:
+                kwargs["group"] = "DDP"
+
+            engine.wandb = wandb.init(project=key_name, **kwargs)
             engine.wandb.watch(engine.module)
         else:
             engine.wandb = None
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 1078740..a0df548 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -466,6 +466,10 @@ class Engines(dict[str, Engine]):
     def quit(self):
         cleanup_distributed()
 
+        for name, engine in self.items():
+            if engine.wandb is not None:
+                engine.wandb.finish()
+
     def step(self, batch, feeder: TrainFeeder = default_feeder):
         total_elapsed_time = 0
 
@@ -568,7 +572,11 @@ class Engines(dict[str, Engine]):
                 if engine.wandb is not None:
                     engine.wandb.log(model_stats)
 
-                stats.update(flatten_dict({name.split("-")[0]: model_stats}))
+                key_name = name
+                if cfg.lora is not None:
+                    key_name = cfg.lora.full_name
+
+                stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
 
         self._update()
 