mrq 2024-11-20 19:21:03 -06:00
parent cd6e9ba2f2
commit dfdba3f190


@@ -464,11 +464,11 @@ class Engines(dict[str, Engine]):
 		return stats

 	def quit(self):
-		cleanup_distributed()
 		for name, engine in self.items():
 			if engine.wandb is not None:
 				engine.wandb.finish()
+
+		cleanup_distributed()

 	def step(self, batch, feeder: TrainFeeder = default_feeder):
 		total_elapsed_time = 0
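The hunk above moves cleanup_distributed() from before the loop to after it, so each engine's wandb run is finished while the distributed backend is still up. A minimal sketch of the resulting method, assuming cleanup_distributed() and the per-engine wandb handle behave as elsewhere in this file:

def quit(self):
	# finish any live wandb runs first, while torch.distributed is still initialized
	for name, engine in self.items():
		if engine.wandb is not None:
			engine.wandb.finish()

	# tear down the distributed process group last
	cleanup_distributed()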
@@ -560,17 +560,20 @@ class Engines(dict[str, Engine]):
 				model_stats = dict(
 					**engine_stats,
-					lr=engine.get_lr()[0],
 					grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
 					loss_scale=loss_scale if loss_scale != 1 else None,
+				)
+
+				if engine.wandb is not None:
+					engine.wandb.log(model_stats, step=engine.global_step)
+
+				model_stats = model_stats | dict(
+					lr=engine.get_lr()[0],
 					elapsed_time=elapsed_time,
 					engine_step=engine.global_step,
 					samples_processed=engine.global_samples,
 					tokens_processed=engine.tokens_processed,
 				)

-				if engine.wandb is not None:
-					engine.wandb.log(model_stats)
-
 				key_name = name
 				if cfg.lora is not None:
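With this hunk, only the per-model metrics are sent to wandb, and the call now pins the entry to the engine's own step counter instead of wandb's implicit one; lr, timing, and sample/token counters are merged in afterwards so they still end up in the local stats without being re-logged. A rough sketch of the resulting flow, assuming the surrounding names (engine_stats, grad_norm, loss_scale, elapsed_time, engine) are defined as in the rest of step():

# per-model metrics destined for wandb
model_stats = dict(
	**engine_stats,
	grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
	loss_scale=loss_scale if loss_scale != 1 else None,
)

if engine.wandb is not None:
	# explicit step= keeps the wandb history aligned with engine.global_step
	engine.wandb.log(model_stats, step=engine.global_step)

# bookkeeping fields are merged in afterwards, so they appear in the local
# stats dict but are not sent to wandb
model_stats = model_stats | dict(
	lr=engine.get_lr()[0],
	elapsed_time=elapsed_time,
	engine_step=engine.global_step,
	samples_processed=engine.global_samples,
	tokens_processed=engine.tokens_processed,
)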