oops
This commit is contained in:
parent
1a73ac6a20
commit
cd6e9ba2f2
|
@ -1,6 +1,6 @@
|
|||
from ..config import cfg
|
||||
|
||||
from ..utils.distributed import fix_unset_envs, ddp_model
|
||||
from ..utils.distributed import fix_unset_envs, ddp_model, world_size
|
||||
fix_unset_envs()
|
||||
|
||||
if cfg.trainer.backend == "deepspeed":
|
||||
|
@ -286,7 +286,15 @@ def load_engines(training=True, **model_kwargs):
|
|||
|
||||
# setup wandb
|
||||
if engine._training and cfg.trainer.wandb and wandb is not None:
|
||||
engine.wandb = wandb.init(project=name)
|
||||
key_name = name
|
||||
kwargs = {}
|
||||
if cfg.lora is not None:
|
||||
key_name = cfg.lora.full_name
|
||||
|
||||
if world_size() > 1:
|
||||
kwargs["group"] = "DDP"
|
||||
|
||||
engine.wandb = wandb.init(project=key_name, **kwargs)
|
||||
engine.wandb.watch(engine.module)
|
||||
else:
|
||||
engine.wandb = None
|
||||
|
|
|
@ -466,6 +466,10 @@ class Engines(dict[str, Engine]):
|
|||
def quit(self):
|
||||
cleanup_distributed()
|
||||
|
||||
for name, engine in self.items():
|
||||
if engine.wandb is not None:
|
||||
engine.wandb.finish()
|
||||
|
||||
def step(self, batch, feeder: TrainFeeder = default_feeder):
|
||||
total_elapsed_time = 0
|
||||
|
||||
|
@ -568,7 +572,11 @@ class Engines(dict[str, Engine]):
|
|||
if engine.wandb is not None:
|
||||
engine.wandb.log(model_stats)
|
||||
|
||||
stats.update(flatten_dict({name.split("-")[0]: model_stats}))
|
||||
key_name = name
|
||||
if cfg.lora is not None:
|
||||
key_name = cfg.lora.full_name
|
||||
|
||||
stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
|
||||
|
||||
self._update()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user