From 0d809561c6b7f03a954718a33aac3a277d0b5092 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 5 Mar 2025 16:35:34 -0600
Subject: [PATCH] accuracy at k=1 and k=80, because I'm probably dumb for
 having k=10 as the default since it does not represent any use case

---
 vall_e/models/base_v2.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++------------------------
 1 file changed, 50 insertions(+), 24 deletions(-)

diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 1dd3bd4..76b55ca 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -906,7 +906,8 @@ class Base_V2(nn.Module):
 			sequence = sequence.reshape(-1)
 
 			nll = None
-			metrics = None
+			acc_k1 = None
+			acc_k80 = None
 
 			if compute_hard_loss:
 				reduction = 'mean' if not batched else 'none'
@@ -920,14 +921,23 @@
 			if compute_acc:
 				accuracy_metric = MulticlassAccuracy(
 					logit.shape[-1],
-					top_k = min(logit.shape[0], 10),
+					top_k = 1,
 					average="micro",
 					multidim_average="global",
 					ignore_index = -100
 				).to(logit.device)
-				metrics = accuracy_metric( logit, sequence )
+				acc_k1 = accuracy_metric( logit, sequence )
+
+				accuracy_metric = MulticlassAccuracy(
+					logit.shape[-1],
+					top_k = min(logit.shape[0], 80),
+					average="micro",
+					multidim_average="global",
+					ignore_index = -100
+				).to(logit.device)
+				acc_k80 = accuracy_metric( logit, sequence )
 
-			return nll, metrics
+			return nll, acc_k1, acc_k80
 
 		for batch_index, batch in enumerate(inputs):
 			quant_level = quant_levels[batch_index]
@@ -1013,7 +1023,7 @@
 					continue
 
 				if logits[batch_index].dim() < 3:
-					nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
+					nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][start:end], token.long(), causal )
 				elif not self.resp_parallel_training:
 					# cringe way to deduce "requested" level
 					level = quant_level
@@ -1026,25 +1036,31 @@
 						name = f'{name}[{level}]'
 
 					sequence = token if token.dim() <= 1 else token[:, level]
-					nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
+					nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
 				else:
 					sequence = token.t()
-					nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
+					nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
 
 					if nll is not None:
 						nll = nll.mean()
 
 				loss_key = f'{name}.nll'
-				acc_key = f'{name}.acc'
+				acc_k1_key = f'{name}.acc[k=1]'
+				acc_k80_key = f'{name}.acc[k=80]'
 				if nll is not None:
 					if loss_key not in loss:
 						loss[loss_key] = []
 					loss[loss_key].append( nll * loss_factor )
 
-				if metrics is not None:
-					if acc_key not in stats:
-						stats[acc_key] = []
-					stats[acc_key].append( metrics )
+				if acc_k1 is not None:
+					if acc_k1_key not in stats:
+						stats[acc_k1_key] = []
+					stats[acc_k1_key].append( acc_k1 )
+
+				if acc_k80 is not None:
+					if acc_k80_key not in stats:
+						stats[acc_k80_key] = []
+					stats[acc_k80_key].append( acc_k80 )
 			# add to list
 			else:
 				target.append( token )
@@ -1054,7 +1070,7 @@
 		if not self.config.loss_factors:
 			if logits[batch_index].dim() < 3:
 				sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-				nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
+				nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index], sequence, causal )
 			elif not self.resp_parallel_training:
 				# cringe way to deduce "requested" level
 				level = 0
@@ -1065,35 +1081,45 @@
 
 				sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 				sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-				nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
+				nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
 			else:
 				nlls = []
-				accs = []
+				acc_k1s = []
+				acc_k80s = []
 				for level, logit in enumerate( logits[batch_index] ):
 					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-					nll, metrics = _calc_loss( logit, sequence, causal, level )
+					nll, acc_k1, acc_k80 = _calc_loss( logit, sequence, causal, level )
 
 					if nll:
 						nlls.append( nll )
 
-					if metrics:
-						accs.append( metrics )
+					if acc_k1:
+						acc_k1s.append( acc_k1 )
+					if acc_k80:
+						acc_k80s.append( acc_k80 )
 
 				if nlls:
 					nll = sum(nlls) / len(nlls)
-				if accs:
-					metrics = sum(accs) / len(accs)
+				if acc_k1s:
+					acc_k1 = sum(acc_k1s) / len(acc_k1s)
+				if acc_k80s:
+					acc_k80 = sum(acc_k80s) / len(acc_k80s)
 
 			if nll is not None:
 				if 'nll' not in loss:
 					loss['nll'] = []
 				loss["nll"].append( nll )
 
-			if metrics is not None:
-				if 'acc' not in stats:
-					stats['acc'] = []
-				stats["acc"].append( metrics )
+			if acc_k1 is not None:
+				if 'acc[k=1]' not in stats:
+					stats['acc[k=1]'] = []
+				stats["acc[k=1]"].append( acc_k1 )
+
+			if acc_k80 is not None:
+				if 'acc[k=80]' not in stats:
+					stats['acc[k=80]'] = []
+				stats["acc[k=80]"].append( acc_k80 )
 
 		# average
 		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
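
For reference, a minimal standalone sketch of the top-k accuracy this patch computes inside _calc_loss, runnable outside the model. The vocab size, sequence length, and random tensors below are made up for illustration; only the MulticlassAccuracy arguments mirror the patch.

import torch
from torchmetrics.classification import MulticlassAccuracy

num_classes = 1024  # hypothetical vocab size, standing in for logit.shape[-1]
seq_len = 256       # hypothetical sequence length, standing in for logit.shape[0]

logits = torch.randn(seq_len, num_classes)          # (T, C) unnormalized logits
target = torch.randint(0, num_classes, (seq_len,))  # (T,) target token ids

# k=1 is plain argmax accuracy; k=80 checks whether the target token lands
# anywhere in the top-80 logits, mirroring min(logit.shape[0], 80) above.
for k in (1, min(seq_len, 80)):
    accuracy_metric = MulticlassAccuracy(
        num_classes,
        top_k=k,
        average="micro",
        multidim_average="global",
        ignore_index=-100,
    )
    print(f"acc[k={k}]:", accuracy_metric(logits, target).item())

The two k values bracket the interesting range: k=1 tracks what greedy decoding would emit, while k=80 is presumably meant to approximate how often the target token survives a realistic top-k sampling cutoff.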