one more time one more time (this normalization isn't a spook)
parent 93044829af
commit 5e9d1a5302
@@ -610,7 +610,8 @@ class Engines(dict[str, Engine]):
                 _logger.warning(f'Loss scale ({loss_scale}) exceeds max_loss_scale ({cfg.trainer.deepspeed.max_loss_scale}), capping...')
                 engine.set_loss_scale(cfg.trainer.deepspeed.max_loss_scale)
 
-            if grad_norm is not None:
+            # scale the grad norm to normal, if not using ZeRO because ZeRO does this already
+            if grad_norm is not None and not cfg.trainer.deepspeed.zero_optimization_level:
                 grad_norm /= loss_scale
 
             model_stats = dict(
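Context for the hunk above: with fp16 dynamic loss scaling, gradients are computed on the scaled loss, so a raw gradient norm is inflated by loss_scale until something unscales it; per the diff's own comment, ZeRO already does that, so dividing again would under-report the norm. A minimal standalone sketch of the guard, under those assumptions (the helper function is hypothetical, not the engine's actual code):

    def report_grad_norm(grad_norm, loss_scale, zero_level):
        # hypothetical helper, not the actual engine code
        if grad_norm is None:
            return None
        # ZeRO already returns an unscaled norm; only divide when ZeRO is off
        if not zero_level:
            grad_norm /= loss_scale
        return grad_norm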
@@ -1300,7 +1300,10 @@ class Base_V2(nn.Module):
             nll = 0
             nlls = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
 
+            # not my best code
             it = 0
+            weights = 0
+            bsz = len( loss_targets )
             for seq, level in zip( loss_targets, loss_levels ):
                 seq_len = seq.shape[0]
                 start = it
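The bookkeeping here assumes loss_target is the flat concatenation of every sequence in the batch, so nlls is one 1-D tensor of per-token losses and (start, end) carve out each sequence's span; the lines elided between this hunk and the next presumably advance the cursor by seq_len. A toy illustration of that slicing, with invented shapes:

    import torch

    nlls = torch.arange(10.0)   # stand-in for per-token NLLs, 10 tokens total
    seq_lens = [4, 6]           # two concatenated sequences
    it = 0
    for seq_len in seq_lens:
        start, end = it, it + seq_len   # this sequence's slice of the flat tensor
        it = end
        print(nlls[start:end].mean())  # per-sequence mean NLL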
@@ -1308,8 +1311,13 @@ class Base_V2(nn.Module):
                 end = it
 
                 nll += nlls[start:end].mean() * level_loss_factors[level]
+                weights += level_loss_factors[level]
 
-            nll /= len( loss_targets )
+            # normalize by batch
+            nll /= bsz
+            # re-scale by summed weights
+            nll /= (weights / bsz)
+            # no this isn't redundant I swear, it'll propagate properly
 
             if compute_acc:
                 n_vocab = loss_logit.shape[-1]
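On the arithmetic of the final normalization: the accumulated nll is a sum over the batch of per-sequence mean NLL times its level weight, and the two divisions compose, since (nll / bsz) / (weights / bsz) == nll / weights. The result is a weighted mean of per-sequence mean NLLs. A self-contained sketch under those assumptions (tensor contents and the level_loss_factors mapping are made up for illustration):

    import torch

    def weighted_nll(nlls, seq_lens, levels, level_loss_factors):
        # nlls: flat 1-D tensor of per-token NLLs for the whole batch
        nll, weights, it = 0.0, 0.0, 0
        for seq_len, level in zip(seq_lens, levels):
            start, end = it, it + seq_len
            it = end
            nll = nll + nlls[start:end].mean() * level_loss_factors[level]
            weights += level_loss_factors[level]
        bsz = len(seq_lens)
        nll = nll / bsz              # normalize by batch, as in the diff
        nll = nll / (weights / bsz)  # then re-scale by summed weights
        return nll                   # algebraically nll / weights

    # e.g. with level weights 1.0 and 0.25, weights == 1.25 and the two
    # divisions give nll / 1.25 regardless of bsz
    print(weighted_nll(torch.arange(10.0), [4, 6], [0, 1], {0: 1.0, 1: 0.25}))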