This commit is contained in:
mrq 2024-11-20 16:27:51 -06:00
parent 1a73ac6a20
commit cd6e9ba2f2
2 changed files with 19 additions and 3 deletions

View File

@ -1,6 +1,6 @@
from ..config import cfg from ..config import cfg
from ..utils.distributed import fix_unset_envs, ddp_model from ..utils.distributed import fix_unset_envs, ddp_model, world_size
fix_unset_envs() fix_unset_envs()
if cfg.trainer.backend == "deepspeed": if cfg.trainer.backend == "deepspeed":
@ -286,7 +286,15 @@ def load_engines(training=True, **model_kwargs):
# setup wandb # setup wandb
if engine._training and cfg.trainer.wandb and wandb is not None: if engine._training and cfg.trainer.wandb and wandb is not None:
engine.wandb = wandb.init(project=name) key_name = name
kwargs = {}
if cfg.lora is not None:
key_name = cfg.lora.full_name
if world_size() > 1:
kwargs["group"] = "DDP"
engine.wandb = wandb.init(project=key_name, **kwargs)
engine.wandb.watch(engine.module) engine.wandb.watch(engine.module)
else: else:
engine.wandb = None engine.wandb = None

View File

@ -466,6 +466,10 @@ class Engines(dict[str, Engine]):
def quit(self): def quit(self):
cleanup_distributed() cleanup_distributed()
for name, engine in self.items():
if engine.wandb is not None:
engine.wandb.finish()
def step(self, batch, feeder: TrainFeeder = default_feeder): def step(self, batch, feeder: TrainFeeder = default_feeder):
total_elapsed_time = 0 total_elapsed_time = 0
@ -568,7 +572,11 @@ class Engines(dict[str, Engine]):
if engine.wandb is not None: if engine.wandb is not None:
engine.wandb.log(model_stats) engine.wandb.log(model_stats)
stats.update(flatten_dict({name.split("-")[0]: model_stats})) key_name = name
if cfg.lora is not None:
key_name = cfg.lora.full_name
stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
self._update() self._update()