one more time one more time (this normalization isn't a spook)

This commit is contained in:
mrq 2025-03-07 19:32:42 -06:00
parent 93044829af
commit 5e9d1a5302
2 changed files with 11 additions and 2 deletions


@@ -610,7 +610,8 @@ class Engines(dict[str, Engine]):
 	_logger.warning(f'Loss scale ({loss_scale}) exceeds max_loss_scale ({cfg.trainer.deepspeed.max_loss_scale}), capping...')
 	engine.set_loss_scale(cfg.trainer.deepspeed.max_loss_scale)
-if grad_norm is not None:
+# scale the grad norm to normal, if not using ZeRO because ZeRO does this already
+if grad_norm is not None and not cfg.trainer.deepspeed.zero_optimization_level:
 	grad_norm /= loss_scale
 model_stats = dict(
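
A rough standalone sketch of what this hunk does (illustrative only, not the repo's actual Engines code): per the comment in the diff, ZeRO already reports an un-scaled gradient norm, so the reported norm is only divided back down by the current loss scale when ZeRO is disabled. The function name and example values below are made up.

from typing import Optional

def normalize_grad_norm(grad_norm: Optional[float], loss_scale: float, zero_stage: int) -> Optional[float]:
    # Mirrors `if grad_norm is not None and not cfg.trainer.deepspeed.zero_optimization_level`:
    # only un-scale the norm when ZeRO isn't already handling it.
    if grad_norm is not None and not zero_stage:
        grad_norm /= loss_scale
    return grad_norm

print(normalize_grad_norm(4096.0, loss_scale=1024.0, zero_stage=0))  # 4.0 (scaled norm divided back down)
print(normalize_grad_norm(4.0, loss_scale=1024.0, zero_stage=2))     # 4.0 (already un-scaled under ZeRO)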


@@ -1300,7 +1300,10 @@ class Base_V2(nn.Module):
 nll = 0
 nlls = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
 # not my best code
 it = 0
+weights = 0
+bsz = len( loss_targets )
 for seq, level in zip( loss_targets, loss_levels ):
 	seq_len = seq.shape[0]
 	start = it
@@ -1308,8 +1311,13 @@ class Base_V2(nn.Module):
 	end = it
 	nll += nlls[start:end].mean() * level_loss_factors[level]
+	weights += level_loss_factors[level]
-nll /= len( loss_targets )
+# normalize by batch
+nll /= bsz
+# re-scale by summed weights
+nll /= (weights / bsz)
+# no this isn't redundant I swear, it'll propagate properly
 if compute_acc:
 	n_vocab = loss_logit.shape[-1]
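
And a tiny numeric check of the normalization above (made-up per-sequence NLLs and level weights, not the model's real tensors), showing what value the two divisions land on: the weighted sum of per-sequence mean NLLs divided by bsz and then by (weights / bsz) comes out to the weighted average over the batch.

level_loss_factors = {0: 1.0, 1: 0.5}            # hypothetical per-level loss weights
per_seq_nll = [(2.0, 0), (1.0, 1), (4.0, 0)]     # (mean NLL over one sequence, its level)

nll = 0.0
weights = 0.0
bsz = len(per_seq_nll)
for seq_nll, level in per_seq_nll:
    nll += seq_nll * level_loss_factors[level]
    weights += level_loss_factors[level]

two_step = (nll / bsz) / (weights / bsz)   # normalize by batch, then re-scale by summed weights
weighted_mean = nll / weights              # the same quantity written in one step

print(two_step, weighted_mean)             # both ~2.6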