I cannot believe it's not actually called Wand DB (added wandb logging support, since I think it's a much better way to view my metrics)

This commit is contained in:
mrq 2024-11-20 16:10:47 -06:00
parent 67f7bad168
commit 1a73ac6a20
4 changed files with 30 additions and 15 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ __pycache__
/vall_e/version.py
/.cache
/voices
/wandb

View File

@@ -695,6 +695,8 @@ class Trainer:
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
gc_mode: str | None = None # deprecated, but marks when to do GC
wandb: bool = True # use wandb, if available
weight_dtype: str = "float16" # dtype to have the model under
amp: bool = False # automatic mixed precision

View File

@@ -28,6 +28,12 @@ try:
except Exception as e:
pass
# Optional dependency: wandb is only used for experiment tracking, so a
# failed import downgrades gracefully to wandb-disabled mode instead of
# crashing. Broad `except Exception` is deliberate here — any import-time
# failure (missing package, broken install, env issues) should be survivable.
try:
    import wandb
except Exception as e:
    # Lazy %-style args: the message is only formatted if the warning is
    # actually emitted (and never double-formats a message containing '%').
    _logger.warning("Failed to import wandb: %s", e)
    # Sentinel checked at use sites (e.g. `if wandb is not None`).
    wandb = None
from functools import cache
@cache
@@ -278,4 +284,11 @@ def load_engines(training=True, **model_kwargs):
if cfg.optimizations.model_offloading:
engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
# setup wandb
if engine._training and cfg.trainer.wandb and wandb is not None:
engine.wandb = wandb.init(project=name)
engine.wandb.watch(engine.module)
else:
engine.wandb = None
return engines

View File

@@ -554,23 +554,22 @@ class Engines(dict[str, Engine]):
if grad_norm is not None:
grad_norm /= loss_scale
stats.update(
flatten_dict(
{
name.split("-")[0]: dict(
**engine_stats,
lr=engine.get_lr()[0],
grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
loss_scale=loss_scale if loss_scale != 1 else None,
elapsed_time=elapsed_time,
engine_step=engine.global_step,
samples_processed=engine.global_samples,
tokens_processed=engine.tokens_processed,
)
}
),
model_stats = dict(
**engine_stats,
lr=engine.get_lr()[0],
grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
loss_scale=loss_scale if loss_scale != 1 else None,
elapsed_time=elapsed_time,
engine_step=engine.global_step,
samples_processed=engine.global_samples,
tokens_processed=engine.tokens_processed,
)
if engine.wandb is not None:
engine.wandb.log(model_stats)
stats.update(flatten_dict({name.split("-")[0]: model_stats}))
self._update()
if len(self.keys()) > 1: