more better-er loss calc I suppose
This commit is contained in:
parent
e8f182b634
commit
e3becec0e8
|
@ -1538,24 +1538,12 @@ class Base(nn.Module):
|
||||||
|
|
||||||
return input
|
return input
|
||||||
|
|
||||||
def _calc_loss( logit, sequence, factor = 1 ):
|
def _calc_loss( logit, sequence, causal = True ):
|
||||||
"""
|
|
||||||
if any(sequence >= logit.shape[-1]):
|
|
||||||
_logger.warning(f'Batch contains extraneous value: {sequence}')
|
|
||||||
return
|
|
||||||
"""
|
|
||||||
# filter tokens that exceed the vocab size
|
# filter tokens that exceed the vocab size
|
||||||
if any(sequence >= logit.shape[-1]):
|
sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
|
||||||
extraneous = []
|
# drop if all tokens are ignored
|
||||||
for i, t in enumerate( sequence ):
|
|
||||||
if t < logits[batch_index].shape[-1]:
|
|
||||||
continue
|
|
||||||
extraneous.append(t.item())
|
|
||||||
sequence[i] = self.ignore_index
|
|
||||||
_logger.warning(f'Batch contains extraneous value: {extraneous} >= {logit.shape[-1]}')
|
|
||||||
|
|
||||||
if all(sequence == self.ignore_index):
|
if all(sequence == self.ignore_index):
|
||||||
return
|
return None, None
|
||||||
|
|
||||||
# shift if causal
|
# shift if causal
|
||||||
if causal:
|
if causal:
|
||||||
|
@ -1563,11 +1551,10 @@ class Base(nn.Module):
|
||||||
logit = logit[..., :-l, :] # shift the target so that token n...
|
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||||
sequence = sequence[..., l:] # ...predicts token n + 1
|
sequence = sequence[..., l:] # ...predicts token n + 1
|
||||||
|
|
||||||
|
nll = None
|
||||||
|
metrics = None
|
||||||
if compute_hard_loss:
|
if compute_hard_loss:
|
||||||
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index ) * factor
|
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
|
||||||
if 'nll' not in loss:
|
|
||||||
loss['nll'] = []
|
|
||||||
loss["nll"].append( nll )
|
|
||||||
|
|
||||||
if compute_acc:
|
if compute_acc:
|
||||||
if self.metrics is not None and classifier_level in self.classifiers.names:
|
if self.metrics is not None and classifier_level in self.classifiers.names:
|
||||||
|
@ -1582,9 +1569,8 @@ class Base(nn.Module):
|
||||||
).to(logit.device)
|
).to(logit.device)
|
||||||
metrics = accuracy_metric( logit, sequence )
|
metrics = accuracy_metric( logit, sequence )
|
||||||
|
|
||||||
if 'acc' not in stats:
|
metrics
|
||||||
stats['acc'] = []
|
return nll, metrics
|
||||||
stats["acc"].append( metrics )
|
|
||||||
|
|
||||||
for batch_index, batch in enumerate(inputs):
|
for batch_index, batch in enumerate(inputs):
|
||||||
quant_level = quant_levels[batch_index]
|
quant_level = quant_levels[batch_index]
|
||||||
|
@ -1675,11 +1661,34 @@ class Base(nn.Module):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if logits[batch_index].dim() < 3:
|
if logits[batch_index].dim() < 3:
|
||||||
_calc_loss( logits[batch_index][start:end], token.long(), loss_factor )
|
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
|
||||||
else:
|
else:
|
||||||
|
nlls = []
|
||||||
|
accs = []
|
||||||
|
|
||||||
for level, logit in enumerate( logits[batch_index] ):
|
for level, logit in enumerate( logits[batch_index] ):
|
||||||
sequence = token if token.dim() <= 1 else token[:, level]
|
sequence = token if token.dim() <= 1 else token[:, level]
|
||||||
_calc_loss( logit[start:end], sequence.long(), loss_factor )
|
nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
|
||||||
|
|
||||||
|
if nll:
|
||||||
|
nlls.append( nll )
|
||||||
|
if metrics:
|
||||||
|
accs.append( metrics )
|
||||||
|
|
||||||
|
if nlls:
|
||||||
|
nll = sum(nlls) / len(nlls)
|
||||||
|
if accs:
|
||||||
|
accs = sum(accs) / len(accs)
|
||||||
|
|
||||||
|
if nll is not None:
|
||||||
|
if f'{name}.nll' not in loss:
|
||||||
|
loss[f'{name}.nll'] = []
|
||||||
|
loss[f"{name}.nll"].append( nll * loss_factor )
|
||||||
|
|
||||||
|
if metrics is not None:
|
||||||
|
if f'{name}.acc' not in stats:
|
||||||
|
stats[f'{name}.acc'] = []
|
||||||
|
stats[f"{name}.acc"].append( metrics )
|
||||||
# add to list
|
# add to list
|
||||||
else:
|
else:
|
||||||
target.append( token )
|
target.append( token )
|
||||||
|
@ -1688,12 +1697,35 @@ class Base(nn.Module):
|
||||||
if not self.config.loss_factors:
|
if not self.config.loss_factors:
|
||||||
if logits[batch_index].dim() < 3:
|
if logits[batch_index].dim() < 3:
|
||||||
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
|
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
|
||||||
_calc_loss( logits[batch_index], sequence )
|
nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
|
||||||
else:
|
else:
|
||||||
|
nlls = []
|
||||||
|
accs = []
|
||||||
|
|
||||||
for level, logit in enumerate( logits[batch_index] ):
|
for level, logit in enumerate( logits[batch_index] ):
|
||||||
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
||||||
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
||||||
_calc_loss( logit, sequence )
|
nll, metrics = _calc_loss( logit, sequence, causal )
|
||||||
|
|
||||||
|
if nll:
|
||||||
|
nlls.append( nll )
|
||||||
|
if metrics:
|
||||||
|
accs.append( metrics )
|
||||||
|
|
||||||
|
if nlls:
|
||||||
|
nll = sum(nlls) / len(nlls)
|
||||||
|
if accs:
|
||||||
|
accs = sum(accs) / len(accs)
|
||||||
|
|
||||||
|
if nll is not None:
|
||||||
|
if 'nll' not in loss:
|
||||||
|
loss['nll'] = []
|
||||||
|
loss["nll"].append( nll )
|
||||||
|
|
||||||
|
if metrics is not None:
|
||||||
|
if 'acc' not in stats:
|
||||||
|
stats['acc'] = []
|
||||||
|
stats["acc"].append( metrics )
|
||||||
|
|
||||||
# average
|
# average
|
||||||
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
|
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
|
||||||
|
|
Loading…
Reference in New Issue
Block a user