wandb additions (to-do eventually, upload samples as artifacts)

This commit is contained in:
mrq 2025-03-06 15:44:40 -06:00
parent ec87308d75
commit a30dffcca7
2 changed files with 16 additions and 1 deletions
vall_e

View File

@ -318,6 +318,11 @@ def load_engines(training=True, **model_kwargs):
kwargs["group"] = "DDP"
kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
kwargs['config'] = dict(
config = engine.hyper_config.__dict__,
hyperparameters = cfg.hyperparameters.__dict__,
)
try:
engine.wandb = wandb.init(project=key_name, **kwargs)
engine.wandb.watch(engine.module)

View File

@ -246,9 +246,19 @@ def run_eval(engines, eval_name, dl, args=None):
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
engines_stats = {
f'{name}.{eval_name}': stats,
eval_name: stats,
"it": engines.global_step,
}
try:
for engine in engines:
if engine.wandb is not None:
engine.wandb.log({
f'{eval_name}.loss.mstft': stats['loss'],
}, step=engine.global_step)
except Exception as e:
print(e)
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")