I cannot believe it's not actually called Wand DB (added wandb logging support, since I think it would have been a much better way to look at my metrics)
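For reference, the flow this commit wires in is the standard wandb loop: wandb.init() creates a run, run.watch() hooks the module so gradients and parameters get tracked, and run.log() pushes a metrics dict each step. A minimal, self-contained sketch of that loop follows; the project name, toy model, and offline mode are illustrative placeholders, not from this repo.

    import torch
    import wandb

    model = torch.nn.Linear(16, 4)                    # stand-in for the engine's module
    run = wandb.init(project="demo", mode="offline")  # one run per engine, as load_engines does below
    run.watch(model)                                  # mirrors engine.wandb.watch(engine.module)

    for step in range(3):
        loss = model(torch.randn(8, 16)).mean()
        loss.backward()
        run.log({"loss": loss.item(), "engine_step": step})  # mirrors logging model_stats per step

    run.finish()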
parent 67f7bad168
commit 1a73ac6a20

.gitignore (vendored)
@@ -6,3 +6,4 @@ __pycache__
 /vall_e/version.py
 /.cache
 /voices
+/wandb

@@ -695,6 +695,8 @@ class Trainer:
 	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
 	gc_mode: str | None = None # deprecated, but marks when to do GC
 
+	wandb: bool = True # use wandb, if available
+
 	weight_dtype: str = "float16" # dtype to have the model under
 
 	amp: bool = False # automatic mixed precision

@@ -28,6 +28,12 @@ try:
 except Exception as e:
 	pass
 
+try:
+	import wandb
+except Exception as e:
+	_logger.warning(f'Failed to import wandb: {str(e)}')
+	wandb = None
+
 from functools import cache
 
 @cache

@@ -278,4 +284,11 @@ def load_engines(training=True, **model_kwargs):
 		if cfg.optimizations.model_offloading:
 			engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
 
+		# setup wandb
+		if engine._training and cfg.trainer.wandb and wandb is not None:
+			engine.wandb = wandb.init(project=name)
+			engine.wandb.watch(engine.module)
+		else:
+			engine.wandb = None
+
 	return engines

@@ -554,23 +554,22 @@ class Engines(dict[str, Engine]):
 			if grad_norm is not None:
 				grad_norm /= loss_scale
 
-			stats.update(
-				flatten_dict(
-					{
-						name.split("-")[0]: dict(
-							**engine_stats,
-							lr=engine.get_lr()[0],
-							grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
-							loss_scale=loss_scale if loss_scale != 1 else None,
-							elapsed_time=elapsed_time,
-							engine_step=engine.global_step,
-							samples_processed=engine.global_samples,
-							tokens_processed=engine.tokens_processed,
-						)
-					}
-				),
-			)
+			model_stats = dict(
+				**engine_stats,
+				lr=engine.get_lr()[0],
+				grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
+				loss_scale=loss_scale if loss_scale != 1 else None,
+				elapsed_time=elapsed_time,
+				engine_step=engine.global_step,
+				samples_processed=engine.global_samples,
+				tokens_processed=engine.tokens_processed,
+			)
+
+			if engine.wandb is not None:
+				engine.wandb.log(model_stats)
+
+			stats.update(flatten_dict({name.split("-")[0]: model_stats}))
 
 			self._update()
 
 			if len(self.keys()) > 1:
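The last hunk is as much a refactor as a feature: the per-engine metrics dict is now built once as model_stats, logged raw to wandb, and only then flattened under the engine's name prefix for the local stats dict. A toy equivalent, assuming flatten_dict is the usual recursive key-joining helper (the real one lives elsewhere in this repo) and using an illustrative engine name:

    def flatten_dict(d, sep="."):
        # assumed behavior: {"ar": {"loss": 1}} -> {"ar.loss": 1}
        out = {}
        for k, v in d.items():
            if isinstance(v, dict):
                for kk, vv in flatten_dict(v, sep).items():
                    out[f"{k}{sep}{kk}"] = vv
            else:
                out[k] = v
        return out

    model_stats = dict(loss=0.5, lr=1.0e-4)    # built once per engine
    # wandb gets the raw dict: engine.wandb.log(model_stats)
    stats = flatten_dict({"ar": model_stats})  # local stats get the name-prefixed, flattened copy
    print(stats)                               # {'ar.loss': 0.5, 'ar.lr': 0.0001}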