ugh
parent 2ea387c08a
commit cbd4d7d7f4
@@ -1473,54 +1473,10 @@ class Base(nn.Module):
 				if loss_factor == 0.0:
 					continue

-				if logits[batch_index].dim() < 3:
-					nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
-
-					if name == "resp":
-						name = f'{name}[{quant_level}]'
-				elif not self.resp_parallel_training:
-					# cringe way to deduce "requested" level
-					level = quant_level
-					for i in range( self.n_resp_levels ):
-						if classifier_level.endswith(f':{i}:{i}'):
-							level = i
-							break
-					if name == "resp":
-						name = f'{name}[{level}]'
-					sequence = token if token.dim() <= 1 else token[:, level]
-					nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
-				else:
-					nlls = []
-					accs = []
-
-					for level, logit in enumerate( logits[batch_index] ):
-						sequence = token if token.dim() <= 1 else token[:, level]
-						nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
-
-						if name == "resp":
-							if nll is not None:
-								if f'{name}[{level}].nll' not in loss:
-									loss[f'{name}[{level}].nll'] = []
-								loss[f"{name}[{level}].nll"].append( nll * loss_factor )
-
-							if metrics is not None:
-								if f'{name}[{level}].acc' not in stats:
-									stats[f'{name}[{level}].acc'] = []
-								stats[f"{name}[{level}].acc"].append( metrics )
-
-							nll = None
-							metrics = None
-						else:
-							if nll:
-								nlls.append( nll )
-							if metrics:
-								accs.append( metrics )
-					else:
-						if nlls:
-							nll = sum(nlls) / len(nlls)
-						if accs:
-							accs = sum(accs) / len(accs)
-
+				"""
+				if name == "resp":
+					name = f'{name}[{quant_level}]'
+				"""
 				if nll is not None:
 					if f'{name}.nll' not in loss:
 						loss[f'{name}.nll'] = []
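Note for reading these hunks: _calc_loss( logits, targets, causal ) is a helper defined elsewhere in base.py and its body is not part of this diff. A minimal sketch of what such a helper plausibly does, assuming cross-entropy with an ignore index, an optional one-token causal shift, and a token-accuracy metric (none of which this commit confirms):

import torch
import torch.nn.functional as F

def calc_loss_sketch( logits, targets, causal=True, ignore_index=-100 ):
	# hypothetical stand-in for base.py's _calc_loss; logits: [T, V], targets: [T]
	if causal:
		# shift so position t predicts token t+1 (an assumption, not confirmed by this diff)
		logits = logits[:-1]
		targets = targets[1:]

	nll = F.cross_entropy( logits, targets, ignore_index=ignore_index )

	# token accuracy over the positions that are not ignored
	mask = targets != ignore_index
	acc = ( logits.argmax(dim=-1)[mask] == targets[mask] ).float().mean()
	return nll, acc

nll, acc = calc_loss_sketch( torch.randn(10, 8), torch.randint(0, 8, (10,)) )

The real helper may differ in its details; the sketch is only meant to make the nll, metrics = _calc_loss(...) calls in the diff readable.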
@@ -1536,39 +1492,9 @@ class Base(nn.Module):

 			# perofrm loss calculation on the entire sequence
 			if not self.config.loss_factors:
-				if logits[batch_index].dim() < 3:
-					sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-					nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
-				elif not self.resp_parallel_training:
-					# cringe way to deduce "requested" level
-					level = 0
-					for i in range( self.n_resp_levels ):
-						if classifier_level.endswith(f':{i}:{i}'):
-							level = i
-							break
-
-					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
-					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-					nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
-				else:
-					nlls = []
-					accs = []
-
-					for level, logit in enumerate( logits[batch_index] ):
-						sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
-						sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-						nll, metrics = _calc_loss( logit, sequence, causal )
-
-						if nll:
-							nlls.append( nll )
-						if metrics:
-							accs.append( metrics )
-
-					if nlls:
-						nll = sum(nlls) / len(nlls)
-					if accs:
-						accs = sum(accs) / len(accs)
+				sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
+				nll, metrics = _calc_loss( logits[batch_index], sequence, causal )

 			if nll is not None:
 				if 'nll' not in loss:
 					loss['nll'] = []
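This hunk keeps only the joined-sequence path: the per-input targets are concatenated via _join( target, torch.tensor(self.ignore_index, ...) ) and scored with a single _calc_loss call over logits[batch_index]. The point of using self.ignore_index as the separator is that torch.nn.functional.cross_entropy drops any position whose target equals its ignore_index argument, so the glue positions add nothing to the loss. A self-contained illustration; the segment values, vocabulary size, and the -100 ignore index are made up for the example, and the inline concatenation is only a guess at what the repo's _join helper does:

import torch
import torch.nn.functional as F

IGNORE = -100  # example ignore index; the model's self.ignore_index may differ

# three target segments, joined with the ignore index as a separator token
segments = [ torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([4, 5]) ]
sep = torch.tensor( IGNORE )

joined = segments[0]
for seg in segments[1:]:
	joined = torch.cat([ joined, sep[None], seg ])
# joined is tensor([1, 2, -100, 3, -100, 4, 5])

logits = torch.randn( joined.shape[0], 8 )  # [T, V] dummy logits over an 8-token vocabulary
loss = F.cross_entropy( logits, joined, ignore_index=IGNORE )
# positions whose target is -100 are excluded from the mean, so the separators cost nothing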
@@ -898,11 +898,10 @@ class Base_V2(nn.Module):
 						nlls.append( nll )
 					if metrics:
 						accs.append( metrics )
-				else:
-					if nlls:
-						nll = sum(nlls) / len(nlls)
-					if accs:
-						accs = sum(accs) / len(accs)
+				if nlls:
+					nll = sum(nlls) / len(nlls)
+				if accs:
+					accs = sum(accs) / len(accs)

 			if nll is not None:
 				if f'{name}.nll' not in loss:
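The Base_V2 hunk above drops an else: that was attached to the for loop collecting nlls/accs. A for/else block in Python runs only when the loop finishes without break; since that loop never breaks, the averaging always ran anyway, but the construct reads as conditional. Dedenting it, as the hunk does, makes the always-run intent explicit. A minimal demonstration of the semantics:

for i in range(3):
	pass
else:
	print("no break, so the else block runs")

for i in range(3):
	if i == 1:
		break
else:
	print("never printed: the loop exited via break")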