oops
This commit is contained in:
parent
1a73ac6a20
commit
cd6e9ba2f2
|
@ -1,6 +1,6 @@
|
||||||
from ..config import cfg
|
from ..config import cfg
|
||||||
|
|
||||||
from ..utils.distributed import fix_unset_envs, ddp_model
|
from ..utils.distributed import fix_unset_envs, ddp_model, world_size
|
||||||
fix_unset_envs()
|
fix_unset_envs()
|
||||||
|
|
||||||
if cfg.trainer.backend == "deepspeed":
|
if cfg.trainer.backend == "deepspeed":
|
||||||
|
@ -286,7 +286,15 @@ def load_engines(training=True, **model_kwargs):
|
||||||
|
|
||||||
# setup wandb
|
# setup wandb
|
||||||
if engine._training and cfg.trainer.wandb and wandb is not None:
|
if engine._training and cfg.trainer.wandb and wandb is not None:
|
||||||
engine.wandb = wandb.init(project=name)
|
key_name = name
|
||||||
|
kwargs = {}
|
||||||
|
if cfg.lora is not None:
|
||||||
|
key_name = cfg.lora.full_name
|
||||||
|
|
||||||
|
if world_size() > 1:
|
||||||
|
kwargs["group"] = "DDP"
|
||||||
|
|
||||||
|
engine.wandb = wandb.init(project=key_name, **kwargs)
|
||||||
engine.wandb.watch(engine.module)
|
engine.wandb.watch(engine.module)
|
||||||
else:
|
else:
|
||||||
engine.wandb = None
|
engine.wandb = None
|
||||||
|
|
|
@ -466,6 +466,10 @@ class Engines(dict[str, Engine]):
|
||||||
def quit(self):
|
def quit(self):
|
||||||
cleanup_distributed()
|
cleanup_distributed()
|
||||||
|
|
||||||
|
for name, engine in self.items():
|
||||||
|
if engine.wandb is not None:
|
||||||
|
engine.wandb.finish()
|
||||||
|
|
||||||
def step(self, batch, feeder: TrainFeeder = default_feeder):
|
def step(self, batch, feeder: TrainFeeder = default_feeder):
|
||||||
total_elapsed_time = 0
|
total_elapsed_time = 0
|
||||||
|
|
||||||
|
@ -568,7 +572,11 @@ class Engines(dict[str, Engine]):
|
||||||
if engine.wandb is not None:
|
if engine.wandb is not None:
|
||||||
engine.wandb.log(model_stats)
|
engine.wandb.log(model_stats)
|
||||||
|
|
||||||
stats.update(flatten_dict({name.split("-")[0]: model_stats}))
|
key_name = name
|
||||||
|
if cfg.lora is not None:
|
||||||
|
key_name = cfg.lora.full_name
|
||||||
|
|
||||||
|
stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
|
||||||
|
|
||||||
self._update()
|
self._update()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user