From 5e9d1a5302241a7738b831770527be5b6a3d794b Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 7 Mar 2025 19:32:42 -0600
Subject: [PATCH] one more time

one more time (this normalization isn't a spook)
---
 vall_e/engines/base.py   |  3 ++-
 vall_e/models/base_v2.py | 10 +++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 7d9388e..ae430e3 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -610,7 +610,8 @@ class Engines(dict[str, Engine]):
 				_logger.warning(f'Loss scale ({loss_scale}) exceeds max_loss_scale ({cfg.trainer.deepspeed.max_loss_scale}), capping...')
 				engine.set_loss_scale(cfg.trainer.deepspeed.max_loss_scale)
 
-			if grad_norm is not None:
+			# scale the grad norm to normal, if not using ZeRO because ZeRO does this already
+			if grad_norm is not None and not cfg.trainer.deepspeed.zero_optimization_level:
 				grad_norm /= loss_scale
 
 			model_stats = dict(
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index ec111a1..aa71a57 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -1300,7 +1300,10 @@ class Base_V2(nn.Module):
 				nll = 0
 				nlls = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
 
+				# not my best code
 				it = 0
+				weights = 0
+				bsz = len( loss_targets )
 				for seq, level in zip( loss_targets, loss_levels ):
 					seq_len = seq.shape[0]
 					start = it
@@ -1308,8 +1311,13 @@
 					end = it
 
 					nll += nlls[start:end].mean() * level_loss_factors[level]
+					weights += level_loss_factors[level]
 
-				nll /= len( loss_targets )
+				# normalize by batch
+				nll /= bsz
+				# re-scale by summed weights
+				nll /= (weights / bsz)
+				# no this isn't redundant I swear, it'll propagate properly
 
 			if compute_acc:
 				n_vocab = loss_logit.shape[-1]
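
Note on the vall_e/engines/base.py hunk: with DeepSpeed's fp16 dynamic loss scaling, the backward pass produces gradients multiplied by the loss scale, so a gradient norm computed from them must be divided by the loss scale before logging; per the patch comment, ZeRO already does that unscaling itself, hence the extra condition. A minimal sketch of the intended reporting behaviour, using stand-in arguments rather than the real Engine/cfg objects:

# Sketch only: stand-in for the grad-norm logging path, not the actual Engines code.
def report_grad_norm(raw_grad_norm, loss_scale, zero_optimization_level):
    """Return the gradient norm in unscaled units for logging."""
    if raw_grad_norm is None:
        return None
    # Plain fp16 engine (ZeRO stage 0 / None): the reported norm still carries the loss scale.
    if not zero_optimization_level:
        return raw_grad_norm / loss_scale
    # ZeRO stages are assumed to report an already-unscaled norm, so pass it through.
    return raw_grad_norm

# With a loss scale of 2**16, a raw norm of 65536.0 corresponds to a true norm of 1.0.
print(report_grad_norm(65536.0, 2 ** 16, 0))  # 1.0
print(report_grad_norm(65536.0, 2 ** 16, 2))  # 65536.0, assumed already unscaled by ZeRO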
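
Note on the vall_e/models/base_v2.py hunk: each sequence in the flattened batch contributes the mean NLL over its own tokens, weighted by its level's loss factor; the running sum is then divided by the batch size and re-scaled by the summed weights over the batch size. A self-contained sketch of that reduction with made-up shapes and factors (the real code takes loss_logit, loss_targets, loss_levels and level_loss_factors from the model):

# Sketch only: made-up tensors standing in for the model's flattened logits/targets.
import torch
import torch.nn.functional as F

ignore_index = -100
level_loss_factors = [1.0, 0.5, 0.25]  # hypothetical per-level weights

# Pretend the batch flattens to three sequences of lengths 4, 2 and 3 over a 10-token vocab.
loss_targets = [torch.randint(0, 10, (n,)) for n in (4, 2, 3)]
loss_levels = [0, 1, 2]
loss_target = torch.cat(loss_targets)
loss_logit = torch.randn(loss_target.shape[0], 10)

nlls = F.cross_entropy(loss_logit, loss_target, reduction='none', ignore_index=ignore_index)

nll = 0
weights = 0
bsz = len(loss_targets)
it = 0
for seq, level in zip(loss_targets, loss_levels):
    start = it
    it += seq.shape[0]
    # per-sequence mean NLL, weighted by that sequence's level factor
    nll += nlls[start:it].mean() * level_loss_factors[level]
    weights += level_loss_factors[level]

# normalize by batch size, then re-scale by the mean weight; the resulting value equals
# the weighted sum of per-sequence losses divided by the summed weights
nll /= bsz
nll /= (weights / bsz)
print(nll)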