accuracy k=1 and k=80, because I'm probably dumb for picking k=10 as the default since it does not represent any use case
parent 2fb2b732fc
commit 0d809561c6
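
For context, the change swaps the single top-10 accuracy for two torchmetrics MulticlassAccuracy instances: top-1 (agreement with greedy argmax decoding) and top-80 (whether the target lands anywhere in the 80 highest logits, closer to what top-k sampling sees). A minimal sketch of that computation follows; the shapes and vocab size are assumptions for illustration, not values from the repo.

# Sketch only: reproduces the two accuracy metrics this commit computes,
# using torchmetrics' MulticlassAccuracy. Shapes/vocab size are made up.
import torch
from torchmetrics.classification import MulticlassAccuracy

vocab_size = 1000                                 # assumed vocab size
logit = torch.randn(128, vocab_size)              # (timesteps, vocab)
sequence = torch.randint(0, vocab_size, (128,))   # target token ids
sequence[::4] = -100                              # masked positions are skipped via ignore_index

# top-1: "did the argmax token match" -- matches greedy decoding.
acc_k1_metric = MulticlassAccuracy(
	vocab_size,
	top_k=1,
	average="micro",
	multidim_average="global",
	ignore_index=-100,
).to(logit.device)

# top-80: "was the target inside the 80 highest logits" -- a looser
# measure closer to how top-k sampling actually draws tokens.
acc_k80_metric = MulticlassAccuracy(
	vocab_size,
	top_k=80,
	average="micro",
	multidim_average="global",
	ignore_index=-100,
).to(logit.device)

acc_k1 = acc_k1_metric(logit, sequence)
acc_k80 = acc_k80_metric(logit, sequence)   # acc_k80 >= acc_k1 by construction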
@@ -906,7 +906,7 @@ class Base_V2(nn.Module):
 		sequence = sequence.reshape(-1)
 
 		nll = None
-		metrics = None
+		acc_k1 = None
 
 		if compute_hard_loss:
 			reduction = 'mean' if not batched else 'none'
@@ -920,14 +920,23 @@ class Base_V2(nn.Module):
 		if compute_acc:
 			accuracy_metric = MulticlassAccuracy(
 				logit.shape[-1],
-				top_k = min(logit.shape[0], 10),
+				top_k = 1,
 				average="micro",
 				multidim_average="global",
 				ignore_index = -100
 			).to(logit.device)
-			metrics = accuracy_metric( logit, sequence )
+			acc_k1 = accuracy_metric( logit, sequence )
+
+			accuracy_metric = MulticlassAccuracy(
+				logit.shape[-1],
+				top_k = min(logit.shape[0], 80),
+				average="micro",
+				multidim_average="global",
+				ignore_index = -100
+			).to(logit.device)
+			acc_k80 = accuracy_metric( logit, sequence )
 
-		return nll, metrics
+		return nll, acc_k1, acc_k80
 
 	for batch_index, batch in enumerate(inputs):
 		quant_level = quant_levels[batch_index]
@@ -1013,7 +1022,7 @@ class Base_V2(nn.Module):
 				continue
 
 			if logits[batch_index].dim() < 3:
-				nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
+				nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][start:end], token.long(), causal )
 			elif not self.resp_parallel_training:
 				# cringe way to deduce "requested" level
 				level = quant_level
@@ -1026,25 +1035,31 @@ class Base_V2(nn.Module):
 					name = f'{name}[{level}]'
 
 				sequence = token if token.dim() <= 1 else token[:, level]
-				nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
+				nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
 			else:
 				sequence = token.t()
-				nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
+				nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
 
 			if nll is not None:
 				nll = nll.mean()
 
 			loss_key = f'{name}.nll'
-			acc_key = f'{name}.acc'
+			acc_k1_key = f'{name}.acc[k=1]'
+			acc_k80_key = f'{name}.acc[k=80]'
 			if nll is not None:
 				if loss_key not in loss:
 					loss[loss_key] = []
 				loss[loss_key].append( nll * loss_factor )
 
-			if metrics is not None:
-				if acc_key not in stats:
-					stats[acc_key] = []
-				stats[acc_key].append( metrics )
+			if acc_k1 is not None:
+				if acc_k1_key not in stats:
+					stats[acc_k1_key] = []
+				stats[acc_k1_key].append( acc_k1 )
+
+			if acc_k80 is not None:
+				if acc_k80_key not in stats:
+					stats[acc_k80_key] = []
+				stats[acc_k80_key].append( acc_k80 )
 			# add to list
 		else:
 			target.append( token )
@@ -1054,7 +1069,7 @@ class Base_V2(nn.Module):
 		if not self.config.loss_factors:
 			if logits[batch_index].dim() < 3:
 				sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-				nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
+				nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index], sequence, causal )
 			elif not self.resp_parallel_training:
 				# cringe way to deduce "requested" level
 				level = 0
@@ -1065,35 +1080,45 @@ class Base_V2(nn.Module):
 
 				sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 				sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-				nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
+				nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
 			else:
 				nlls = []
-				accs = []
+				acc_k1s = []
+				acc_k80s = []
 
 				for level, logit in enumerate( logits[batch_index] ):
 					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-					nll, metrics = _calc_loss( logit, sequence, causal, level )
+					nll, acc_k1, acc_k80 = _calc_loss( logit, sequence, causal, level )
 
 					if nll:
 						nlls.append( nll )
-					if metrics:
-						accs.append( metrics )
+					if acc_k1:
+						acc_k1s.append( acc_k1 )
+					if acc_k80:
+						acc_k80s.append( acc_k80 )
 
 				if nlls:
 					nll = sum(nlls) / len(nlls)
-				if accs:
-					metrics = sum(accs) / len(accs)
+				if acc_k1s:
+					acc_k1 = sum(acc_k1s) / len(acc_k1s)
+				if acc_k80s:
+					acc_k80 = sum(acc_k80s) / len(acc_k80s)
 
 			if nll is not None:
 				if 'nll' not in loss:
 					loss['nll'] = []
 				loss["nll"].append( nll )
 
-			if metrics is not None:
-				if 'acc' not in stats:
-					stats['acc'] = []
-				stats["acc"].append( metrics )
+			if acc_k1 is not None:
+				if 'acc[k=1]' not in stats:
+					stats['acc[k=1]'] = []
+				stats["acc[k=1]"].append( acc_k1 )
+
+			if acc_k80 is not None:
+				if 'acc[k=80]' not in stats:
+					stats['acc[k=80]'] = []
+				stats["acc[k=80]"].append( acc_k80 )
 
 	# average
 	loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
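
The tail of the last hunk collapses each accumulated list of per-segment values into a single scalar per key, so the new stats surface as 'acc[k=1]' and 'acc[k=80]'. A minimal sketch of that reduction, with made-up values:

# Sketch of the final reduction: each stat key holds a list of
# per-segment tensors, collapsed to its mean. Values are illustrative.
import torch

stats = {
	"acc[k=1]": [torch.tensor(0.42), torch.tensor(0.48)],
	"acc[k=80]": [torch.tensor(0.91), torch.tensor(0.95)],
}
stats = { name: sum(values) / len(values) for name, values in stats.items() }
# -> {'acc[k=1]': tensor(0.4500), 'acc[k=80]': tensor(0.9300)}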