wandb additions (to-do eventually, upload samples as artifacts)
This commit is contained in:
parent
ec87308d75
commit
a30dffcca7
vall_e
|
@ -318,6 +318,11 @@ def load_engines(training=True, **model_kwargs):
|
|||
kwargs["group"] = "DDP"
|
||||
kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
|
||||
|
||||
kwargs['config'] = dict(
|
||||
config = engine.hyper_config.__dict__,
|
||||
hyperparameters = cfg.hyperparameters.__dict__,
|
||||
)
|
||||
|
||||
try:
|
||||
engine.wandb = wandb.init(project=key_name, **kwargs)
|
||||
engine.wandb.watch(engine.module)
|
||||
|
|
|
@ -246,9 +246,19 @@ def run_eval(engines, eval_name, dl, args=None):
|
|||
|
||||
stats = {k: sum(v) / len(v) for k, v in stats.items() if v}
|
||||
engines_stats = {
|
||||
f'{name}.{eval_name}': stats,
|
||||
eval_name: stats,
|
||||
"it": engines.global_step,
|
||||
}
|
||||
|
||||
try:
|
||||
for engine in engines:
|
||||
if engine.wandb is not None:
|
||||
engine.wandb.log({
|
||||
f'{eval_name}.loss.mstft': stats['loss'],
|
||||
}, step=engine.global_step)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
|
||||
|
||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||
|
|
Loading…
Reference in New Issue
Block a user