From a30dffcca797f319ca6a5fd3058e3c17369b52ed Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 6 Mar 2025 15:44:40 -0600 Subject: [PATCH] wandb additions (to-do eventually, upload samples as artifacts) --- vall_e/engines/__init__.py | 5 +++++ vall_e/train.py | 12 +++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 5ec6569..7d83028 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -318,6 +318,11 @@ def load_engines(training=True, **model_kwargs): kwargs["group"] = "DDP" kwargs['id'] = f'{key_name}-{salt}-{global_rank()}' + kwargs['config'] = dict( + config = engine.hyper_config.__dict__, + hyperparameters = cfg.hyperparameters.__dict__, + ) + try: engine.wandb = wandb.init(project=key_name, **kwargs) engine.wandb.watch(engine.module) diff --git a/vall_e/train.py b/vall_e/train.py index ded1ec4..972ae45 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -246,9 +246,19 @@ def run_eval(engines, eval_name, dl, args=None): stats = {k: sum(v) / len(v) for k, v in stats.items() if v} engines_stats = { - f'{name}.{eval_name}': stats, + eval_name: stats, "it": engines.global_step, } + + try: + for engine in engines: + if engine.wandb is not None: + engine.wandb.log({ + f'{eval_name}.loss.mstft': stats['loss'], + }, step=engine.global_step) + except Exception as e: + print(e) + #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl) _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")