This commit is contained in:
mrq 2024-11-20 16:27:51 -06:00
parent 1a73ac6a20
commit cd6e9ba2f2
2 changed files with 19 additions and 3 deletions

View File

@@ -1,6 +1,6 @@
from ..config import cfg
from ..utils.distributed import fix_unset_envs, ddp_model
from ..utils.distributed import fix_unset_envs, ddp_model, world_size
fix_unset_envs()
if cfg.trainer.backend == "deepspeed":
@@ -286,7 +286,15 @@ def load_engines(training=True, **model_kwargs):
# setup wandb
if engine._training and cfg.trainer.wandb and wandb is not None:
engine.wandb = wandb.init(project=name)
key_name = name
kwargs = {}
if cfg.lora is not None:
key_name = cfg.lora.full_name
if world_size() > 1:
kwargs["group"] = "DDP"
engine.wandb = wandb.init(project=key_name, **kwargs)
engine.wandb.watch(engine.module)
else:
engine.wandb = None

View File

@@ -466,6 +466,10 @@ class Engines(dict[str, Engine]):
def quit(self):
    """Shut down distributed state, then close any active wandb runs.

    Calls ``cleanup_distributed()`` first. Afterwards, every engine held in
    this mapping whose ``wandb`` attribute is not ``None`` has that run
    closed via its ``finish()`` method.
    """
    cleanup_distributed()

    # The engine names (keys) are not needed here, only the engines.
    for engine in self.values():
        run = engine.wandb
        if run is None:
            # wandb logging was not set up for this engine.
            continue
        run.finish()
def step(self, batch, feeder: TrainFeeder = default_feeder):
total_elapsed_time = 0
@@ -568,7 +572,11 @@ class Engines(dict[str, Engine]):
if engine.wandb is not None:
engine.wandb.log(model_stats)
stats.update(flatten_dict({name.split("-")[0]: model_stats}))
key_name = name
if cfg.lora is not None:
key_name = cfg.lora.full_name
stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
self._update()