one more time one more time (this normalization isn't a spook)
parent 93044829af
commit 5e9d1a5302
@@ -610,7 +610,8 @@ class Engines(dict[str, Engine]):
 			_logger.warning(f'Loss scale ({loss_scale}) exceeds max_loss_scale ({cfg.trainer.deepspeed.max_loss_scale}), capping...')
 			engine.set_loss_scale(cfg.trainer.deepspeed.max_loss_scale)
 
-		if grad_norm is not None:
+		# scale the grad norm to normal, if not using ZeRO because ZeRO does this already
+		if grad_norm is not None and not cfg.trainer.deepspeed.zero_optimization_level:
 			grad_norm /= loss_scale
 
 		model_stats = dict(
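A note on the first hunk: under fp16 loss scaling the backward pass runs on loss * loss_scale, so the raw gradients, and therefore the measured gradient norm, are inflated by that factor; per the comment in the hunk, ZeRO already reports an unscaled norm, hence the extra guard on zero_optimization_level. A minimal self-contained check of the unscaling identity, in plain PyTorch and using nothing from the repo:

import torch

# Backward on (loss * loss_scale) multiplies every gradient by loss_scale,
# so dividing the measured norm by loss_scale recovers the true norm.
loss_scale = 1024.0
w = torch.nn.Parameter(torch.randn(8))

(w.square().sum() * loss_scale).backward()
scaled_norm = w.grad.norm()

w.grad = None
w.square().sum().backward()
true_norm = w.grad.norm()

assert torch.allclose(scaled_norm / loss_scale, true_norm)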
@@ -1300,7 +1300,10 @@ class Base_V2(nn.Module):
 			nll = 0
 			nlls = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
 
+			# not my best code
 			it = 0
+			weights = 0
+			bsz = len( loss_targets )
 			for seq, level in zip( loss_targets, loss_levels ):
 				seq_len = seq.shape[0]
 				start = it
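To make the bookkeeping in this hunk easier to follow: loss_logit and loss_target appear to be the concatenated per-token logits and targets for the whole batch, cross_entropy with reduction='none' yields one NLL per token, and the loop walks a running index to recover each sequence's slice and weight it by its quantizer level. A rough sketch under those assumptions (toy shapes and toy level_loss_factors, ignore_index handling omitted; only the names come from the diff):

import torch
import torch.nn.functional as F

V = 32                                              # toy vocab size
loss_targets = [torch.randint(0, V, (5,)), torch.randint(0, V, (9,))]
loss_levels = [0, 3]                                # quantizer level per sequence
level_loss_factors = [1.0, 0.5, 0.5, 0.25, 0.25, 0.25, 0.25, 0.25]

loss_target = torch.cat(loss_targets)               # [T_total]
loss_logit = torch.randn(loss_target.shape[0], V)   # [T_total, V]

nlls = F.cross_entropy(loss_logit, loss_target, reduction='none')

nll, weights, it = 0, 0, 0
for seq, level in zip(loss_targets, loss_levels):
    start, it = it, it + seq.shape[0]               # this sequence's slice of nlls
    nll += nlls[start:it].mean() * level_loss_factors[level]
    weights += level_loss_factors[level]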
@@ -1308,8 +1311,13 @@ class Base_V2(nn.Module):
 				end = it
 
 				nll += nlls[start:end].mean() * level_loss_factors[level]
+				weights += level_loss_factors[level]
 
-			nll /= len( loss_targets )
+			# normalize by batch
+			nll /= bsz
+			# re-scale by summed weights
+			nll /= (weights / bsz)
+			# no this isn't redundant I swear, it'll propagate properly
 
 			if compute_acc:
 				n_vocab = loss_logit.shape[-1]
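The two divisions at the end of the last hunk compose to a plain weighted average of the per-sequence mean NLLs, nll = sum_i(w_i * mean_i) / sum_i(w_i), which reduces to the old nll /= len( loss_targets ) whenever every weight is 1. A quick arithmetic check with made-up numbers that the intermediate division by bsz does not change the result:

# Toy values only: stand-ins for nlls[start:end].mean() and level_loss_factors[level].
per_seq_means = [2.0, 4.0, 6.0]
factors = [1.0, 0.5, 0.25]
bsz = len(per_seq_means)

nll = sum(m * f for m, f in zip(per_seq_means, factors))
weights = sum(factors)

two_step = nll / bsz / (weights / bsz)   # the diff's normalization
direct = nll / weights                   # plain weighted average
assert abs(two_step - direct) < 1e-12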