I cannot believe it's not actually called Wand DB (added wandb logging support since I think it would have been a much better way to look at my metrics)
This commit is contained in:
parent
67f7bad168
commit
1a73ac6a20
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -6,3 +6,4 @@ __pycache__
|
||||||
/vall_e/version.py
|
/vall_e/version.py
|
||||||
/.cache
|
/.cache
|
||||||
/voices
|
/voices
|
||||||
|
/wandb
|
|
@ -695,6 +695,8 @@ class Trainer:
|
||||||
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
|
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
|
||||||
gc_mode: str | None = None # deprecated, but marks when to do GC
|
gc_mode: str | None = None # deprecated, but marks when to do GC
|
||||||
|
|
||||||
|
wandb: bool = True # use wandb, if available
|
||||||
|
|
||||||
weight_dtype: str = "float16" # dtype to have the model under
|
weight_dtype: str = "float16" # dtype to have the model under
|
||||||
|
|
||||||
amp: bool = False # automatic mixed precision
|
amp: bool = False # automatic mixed precision
|
||||||
|
|
|
@ -28,6 +28,12 @@ try:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except Exception as e:
|
||||||
|
_logger.warning(f'Failed to import wandb: {str(e)}')
|
||||||
|
wandb = None
|
||||||
|
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
|
@ -278,4 +284,11 @@ def load_engines(training=True, **model_kwargs):
|
||||||
if cfg.optimizations.model_offloading:
|
if cfg.optimizations.model_offloading:
|
||||||
engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
|
engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
|
||||||
|
|
||||||
|
# setup wandb
|
||||||
|
if engine._training and cfg.trainer.wandb and wandb is not None:
|
||||||
|
engine.wandb = wandb.init(project=name)
|
||||||
|
engine.wandb.watch(engine.module)
|
||||||
|
else:
|
||||||
|
engine.wandb = None
|
||||||
|
|
||||||
return engines
|
return engines
|
||||||
|
|
|
@ -554,22 +554,21 @@ class Engines(dict[str, Engine]):
|
||||||
if grad_norm is not None:
|
if grad_norm is not None:
|
||||||
grad_norm /= loss_scale
|
grad_norm /= loss_scale
|
||||||
|
|
||||||
stats.update(
|
model_stats = dict(
|
||||||
flatten_dict(
|
**engine_stats,
|
||||||
{
|
lr=engine.get_lr()[0],
|
||||||
name.split("-")[0]: dict(
|
grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
|
||||||
**engine_stats,
|
loss_scale=loss_scale if loss_scale != 1 else None,
|
||||||
lr=engine.get_lr()[0],
|
elapsed_time=elapsed_time,
|
||||||
grad_norm=grad_norm.item() if isinstance( grad_norm, torch.Tensor ) else grad_norm,
|
engine_step=engine.global_step,
|
||||||
loss_scale=loss_scale if loss_scale != 1 else None,
|
samples_processed=engine.global_samples,
|
||||||
elapsed_time=elapsed_time,
|
tokens_processed=engine.tokens_processed,
|
||||||
engine_step=engine.global_step,
|
|
||||||
samples_processed=engine.global_samples,
|
|
||||||
tokens_processed=engine.tokens_processed,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if engine.wandb is not None:
|
||||||
|
engine.wandb.log(model_stats)
|
||||||
|
|
||||||
|
stats.update(flatten_dict({name.split("-")[0]: model_stats}))
|
||||||
|
|
||||||
self._update()
|
self._update()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user