diff --git a/.gitignore b/.gitignore
index d806633..1439eba 100755
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ __pycache__
 /vall_e/version.py
 /.cache
 /voices
+/wandb
\ No newline at end of file
diff --git a/vall_e/config.py b/vall_e/config.py
index 51a1baf..9d3d91a 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -695,6 +695,8 @@ class Trainer:
 	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
 	gc_mode: str | None = None # deprecated, but marks when to do GC
 
+	wandb: bool = True # use wandb, if available
+
 	weight_dtype: str = "float16" # dtype to have the model under
 	amp: bool = False # automatic mixed precision
 
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index a3d6e29..2d5c9ea 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -28,6 +28,12 @@ try:
 except Exception as e:
 	pass
 
+try:
+	import wandb
+except Exception as e:
+	_logger.warning(f'Failed to import wandb: {str(e)}')
+	wandb = None
+
 from functools import cache
 
 @cache
@@ -278,4 +284,11 @@ def load_engines(training=True, **model_kwargs):
 	if cfg.optimizations.model_offloading:
 		engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
 
+	# setup wandb
+	if engine._training and cfg.trainer.wandb and wandb is not None:
+		engine.wandb = wandb.init(project=name)
+		engine.wandb.watch(engine.module)
+	else:
+		engine.wandb = None
+
 	return engines
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 0302346..1078740 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -554,22 +554,21 @@ class Engines(dict[str, Engine]):
 			if grad_norm is not None:
 				grad_norm /= loss_scale
 
-			stats.update(
-				flatten_dict(
-					{
-						name.split("-")[0]: dict(
-							**engine_stats,
-							lr=engine.get_lr()[0],
-							grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
-							loss_scale=loss_scale if loss_scale != 1 else None,
-							elapsed_time=elapsed_time,
-							engine_step=engine.global_step,
-							samples_processed=engine.global_samples,
-							tokens_processed=engine.tokens_processed,
-						)
-					}
-				),
+			model_stats = dict(
+				**engine_stats,
+				lr=engine.get_lr()[0],
+				grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
+				loss_scale=loss_scale if loss_scale != 1 else None,
+				elapsed_time=elapsed_time,
+				engine_step=engine.global_step,
+				samples_processed=engine.global_samples,
+				tokens_processed=engine.tokens_processed,
 			)
+
+			if engine.wandb is not None:
+				engine.wandb.log(model_stats)
+
+			stats.update(flatten_dict({name.split("-")[0]: model_stats}))
 
 		self._update()
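
For reference, the pattern these hunks introduce — a guarded `import wandb`, a per-engine run handle that may be `None`, and logging gated on that handle — can be sketched standalone as below. This is an illustrative toy, not the repo's actual API: `init_run` and `log_stats` and their parameters are hypothetical names, while `wandb.init`, `run.watch`, and `run.log` are the real wandb calls the patch relies on.

import logging

_logger = logging.getLogger(__name__)

# Guarded import: training should still run when wandb is not installed.
try:
	import wandb
except Exception as e:
	_logger.warning(f"Failed to import wandb: {e}")
	wandb = None

def init_run( enabled, project, module=None ):
	"""Return a wandb run handle, or None when disabled or unavailable."""
	if not enabled or wandb is None:
		return None
	run = wandb.init( project=project )
	if module is not None:
		run.watch( module ) # track the module's gradients/parameters
	return run

def log_stats( run, stats ):
	"""Log a flat stats dict if a run exists; otherwise do nothing."""
	if run is not None:
		run.log( stats )

With this shape, a missing wandb package or a disabled `trainer.wandb` flag degrades to no-ops instead of raising, which is the behaviour the patch gives `load_engines` and `Engines.step`.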