I cannot believe it's not actually called Wand DB (added wandb logging support since I think it would have been a much better way to look at my metrics)

mrq 2024-11-20 16:10:47 -06:00
parent 67f7bad168
commit 1a73ac6a20
4 changed files with 30 additions and 15 deletions
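
For context, the gist of the change as shown in the hunks below: wandb is imported behind a try/except so it stays optional, a run is created per engine when training and the new trainer flag is on, and the per-step stats dict is mirrored to wandb.log(). A minimal, self-contained sketch of that pattern follows; the toy model, project name, and loop are placeholders, not code from this repo.

# Minimal sketch of the pattern this commit wires in; everything here is a
# placeholder (toy model, project name, loop), not code from the repo.
import torch

try:
	import wandb
except Exception:
	wandb = None  # keep wandb strictly optional

model = torch.nn.Linear(8, 8)  # stand-in for an engine's module
run = wandb.init(project="vall-e-sketch") if wandb is not None else None
if run is not None:
	run.watch(model)  # track gradients/parameters

for step in range(3):
	loss = model(torch.randn(4, 8)).pow(2).mean()
	loss.backward()
	if run is not None:
		run.log({"loss": loss.item(), "engine_step": step})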

.gitignore

@@ -6,3 +6,4 @@ __pycache__
 /vall_e/version.py
 /.cache
 /voices
+/wandb

View File

@@ -695,6 +695,8 @@ class Trainer:
 	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
 	gc_mode: str | None = None # deprecated, but marks when to do GC
+	wandb: bool = True # use wandb, if available
 	weight_dtype: str = "float16" # dtype to have the model under
 	amp: bool = False # automatic mixed precision
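
The flag defaults to on, so wandb logging is opt-out rather than opt-in, and it only takes effect when the package actually imports. A rough sketch of how that knob combines with the guarded import; the TrainerStub below is illustrative only, not the repo's full config dataclass.

# Illustrative stub only; the real Trainer config has many more fields.
from dataclasses import dataclass

try:
	import wandb
except Exception:
	wandb = None

@dataclass
class TrainerStub:
	wandb: bool = True  # use wandb, if available

trainer_cfg = TrainerStub(wandb=False)  # flip to False (presumably via the YAML config) to opt out
use_wandb = trainer_cfg.wandb and wandb is not None
print("wandb enabled:", use_wandb)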

View File

@@ -28,6 +28,12 @@ try:
 except Exception as e:
 	pass
+try:
+	import wandb
+except Exception as e:
+	_logger.warning(f'Failed to import wandb: {str(e)}')
+	wandb = None
 from functools import cache
 @cache
@@ -278,4 +284,11 @@ def load_engines(training=True, **model_kwargs):
 		if cfg.optimizations.model_offloading:
 			engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
+		# setup wandb
+		if engine._training and cfg.trainer.wandb and wandb is not None:
+			engine.wandb = wandb.init(project=name)
+			engine.wandb.watch(engine.module)
+		else:
+			engine.wandb = None
 	return engines
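
Per engine, a run is only created when training, the trainer flag is on, and the import succeeded; otherwise engine.wandb is left as None so callers can simply check for it. A rough standalone sketch of that guard; name and module below are placeholders for an engine's model name and its module.

# Standalone sketch of the guarded setup above; `name`/`module` are placeholders.
import torch

try:
	import wandb
except Exception:
	wandb = None

training = True
wandb_enabled = True            # stand-in for cfg.trainer.wandb
name = "example-model"          # hypothetical model name
module = torch.nn.Linear(4, 4)

if training and wandb_enabled and wandb is not None:
	run = wandb.init(project=name)  # one run, named after the model
	run.watch(module)               # log gradients/parameter histograms
else:
	run = None                      # downstream code checks `run is not None`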

View File

@@ -554,22 +554,21 @@ class Engines(dict[str, Engine]):
 			if grad_norm is not None:
 				grad_norm /= loss_scale
-			stats.update(
-				flatten_dict(
-					{
-						name.split("-")[0]: dict(
-							**engine_stats,
-							lr=engine.get_lr()[0],
-							grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
-							loss_scale=loss_scale if loss_scale != 1 else None,
-							elapsed_time=elapsed_time,
-							engine_step=engine.global_step,
-							samples_processed=engine.global_samples,
-							tokens_processed=engine.tokens_processed,
-						)
-					}
-				),
-			)
+			model_stats = dict(
+				**engine_stats,
+				lr=engine.get_lr()[0],
+				grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
+				loss_scale=loss_scale if loss_scale != 1 else None,
+				elapsed_time=elapsed_time,
+				engine_step=engine.global_step,
+				samples_processed=engine.global_samples,
+				tokens_processed=engine.tokens_processed,
+			)
+			if engine.wandb is not None:
+				engine.wandb.log(model_stats)
+			stats.update(flatten_dict({name.split("-")[0]: model_stats}))
 		self._update()
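
The step stats are now assembled once as a flat per-model dict, pushed to wandb.log() as-is, and only then prefixed/flattened for the existing local stats path. A toy illustration of that split; this flatten_dict is a stand-in for the repo's helper and may not match its exact separator, and the metric values are made up.

# Toy illustration; flatten_dict here is a stand-in and the numbers are made up.
try:
	import wandb
except Exception:
	wandb = None

def flatten_dict(d, parent_key="", sep="."):
	out = {}
	for k, v in d.items():
		key = f"{parent_key}{sep}{k}" if parent_key else k
		if isinstance(v, dict):
			out.update(flatten_dict(v, key, sep=sep))
		else:
			out[key] = v
	return out

name = "ar+nar"  # hypothetical model name
model_stats = dict(loss=0.42, lr=1.0e-4, grad_norm=0.9, engine_step=100)

run = wandb.init(project=name) if wandb is not None else None
if run is not None:
	run.log(model_stats)  # wandb receives the flat, unprefixed metrics

stats = {}
stats.update(flatten_dict({name.split("-")[0]: model_stats}))  # local stats keep the per-model prefix
print(stats)  # {'ar+nar.loss': 0.42, 'ar+nar.lr': 0.0001, ...}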