diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 5ec6569..7d83028 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -318,6 +318,11 @@ def load_engines(training=True, **model_kwargs): kwargs["group"] = "DDP" kwargs['id'] = f'{key_name}-{salt}-{global_rank()}' + kwargs['config'] = dict( + config = engine.hyper_config.__dict__, + hyperparameters = cfg.hyperparameters.__dict__, + ) + try: engine.wandb = wandb.init(project=key_name, **kwargs) engine.wandb.watch(engine.module) diff --git a/vall_e/train.py b/vall_e/train.py index ded1ec4..972ae45 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -246,9 +246,19 @@ def run_eval(engines, eval_name, dl, args=None): stats = {k: sum(v) / len(v) for k, v in stats.items() if v} engines_stats = { - f'{name}.{eval_name}': stats, + eval_name: stats, "it": engines.global_step, } + + try: + for engine in engines: + if engine.wandb is not None: + engine.wandb.log({ + f'{eval_name}.loss.mstft': stats['loss'], + }, step=engine.global_step) + except Exception as e: + print(e) + #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl) _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")