diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index cea6759..15b5b75 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1538,24 +1538,12 @@ class Base(nn.Module):
 
 			return input
 
-		def _calc_loss( logit, sequence, factor = 1 ):
-			"""
-			if any(sequence >= logit.shape[-1]):
-				_logger.warning(f'Batch contains extraneous value: {sequence}')
-				return
-			"""
+		def _calc_loss( logit, sequence, causal = True ):
 			# filter tokens that exceed the vocab size
-			if any(sequence >= logit.shape[-1]):
-				extraneous = []
-				for i, t in enumerate( sequence ):
-					if t < logits[batch_index].shape[-1]:
-						continue
-					extraneous.append(t.item())
-					sequence[i] = self.ignore_index
-				_logger.warning(f'Batch contains extraneous value: {extraneous} >= {logit.shape[-1]}')
-
+			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
+			# drop if all tokens are ignored
 			if all(sequence == self.ignore_index):
-				return
+				return None, None
 
 			# shift if causal
 			if causal:
@@ -1563,11 +1551,10 @@ class Base(nn.Module):
 				logit = logit[..., :-l, :] # shift the target so that token n...
 				sequence = sequence[..., l:] # ...predicts token n + 1
 
+			nll = None
+			metrics = None
 			if compute_hard_loss:
-				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index ) * factor
-				if 'nll' not in loss:
-					loss['nll'] = []
-				loss["nll"].append( nll )
+				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
 
 			if compute_acc:
 				if self.metrics is not None and classifier_level in self.classifiers.names:
@@ -1582,9 +1569,8 @@ class Base(nn.Module):
 					).to(logit.device)
 					metrics = accuracy_metric( logit, sequence )
 
-				if 'acc' not in stats:
-					stats['acc'] = []
-				stats["acc"].append( metrics )
+
+			return nll, metrics
 
 		for batch_index, batch in enumerate(inputs):
 			quant_level = quant_levels[batch_index]
@@ -1675,11 +1661,34 @@ class Base(nn.Module):
 						continue
 
 					if logits[batch_index].dim() < 3:
-						_calc_loss( logits[batch_index][start:end], token.long(), loss_factor )
+						nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
 					else:
+						nlls = []
+						accs = []
+
 						for level, logit in enumerate( logits[batch_index] ):
 							sequence = token if token.dim() <= 1 else token[:, level]
-							_calc_loss( logit[start:end], sequence.long(), loss_factor )
+							nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
+
+							if nll is not None:
+								nlls.append( nll )
+							if metrics is not None:
+								accs.append( metrics )
+
+						if nlls:
+							nll = sum(nlls) / len(nlls)
+						if accs:
+							metrics = sum(accs) / len(accs)
+
+					if nll is not None:
+						if f'{name}.nll' not in loss:
+							loss[f'{name}.nll'] = []
+						loss[f"{name}.nll"].append( nll * loss_factor )
+
+					if metrics is not None:
+						if f'{name}.acc' not in stats:
+							stats[f'{name}.acc'] = []
+						stats[f"{name}.acc"].append( metrics )
 				# add to list
 				else:
 					target.append( token )
@@ -1688,12 +1697,35 @@ class Base(nn.Module):
 			if not self.config.loss_factors:
 				if logits[batch_index].dim() < 3:
 					sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-					_calc_loss( logits[batch_index], sequence )
+					nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
 				else:
+					nlls = []
+					accs = []
+
 					for level, logit in enumerate( logits[batch_index] ):
 						sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 						sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-						_calc_loss( logit, sequence )
+						nll, metrics = _calc_loss( logit, sequence, causal )
+
+						if nll is not None:
+							nlls.append( nll )
+						if metrics is not None:
+							accs.append( metrics )
+
+					if nlls:
+						nll = sum(nlls) / len(nlls)
+					if accs:
+						metrics = sum(accs) / len(accs)
+
+				if nll is not None:
+					if 'nll' not in loss:
+						loss['nll'] = []
+					loss["nll"].append( nll )
+
+				if metrics is not None:
+					if 'acc' not in stats:
+						stats['acc'] = []
+					stats["acc"].append( metrics )
 
 		# average
 		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
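For reference, the refactored helper boils down to the standalone sketch below. It is a minimal approximation under stated assumptions, not the repository's exact code: `IGNORE_INDEX = -100` is assumed, the causal shift is fixed at one token where the real code shifts by a configurable `l`, and a plain argmax accuracy stands in for torchmetrics' `MulticlassAccuracy` on the classifier's vocab size.

```python
# Sketch of the refactored _calc_loss, outside the class.
# Assumed, not from the repo: IGNORE_INDEX, shift of 1, argmax accuracy.
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # matches F.cross_entropy's default ignore_index

def calc_loss(logit: torch.Tensor, sequence: torch.Tensor, causal: bool = True):
    """Return (nll, acc), or (None, None) if every target token is ignored."""
    # mask out-of-vocab tokens rather than warning about them
    sequence = torch.where(sequence >= logit.shape[-1], IGNORE_INDEX, sequence)

    # drop the segment entirely if all tokens are ignored
    if bool((sequence == IGNORE_INDEX).all()):
        return None, None

    # shift for causal LM: the logit at position n predicts token n + 1
    if causal:
        logit = logit[..., :-1, :]
        sequence = sequence[..., 1:]

    nll = F.cross_entropy(logit, sequence, ignore_index=IGNORE_INDEX)

    # accuracy over non-ignored positions only
    mask = sequence != IGNORE_INDEX
    acc = (logit.argmax(dim=-1)[mask] == sequence[mask]).float().mean()
    return nll, acc

# usage: (seq_len, vocab) logits against (seq_len,) targets
logits = torch.randn(8, 32)
targets = torch.randint(0, 32, (8,))
targets[3] = 40  # out-of-vocab token is masked to IGNORE_INDEX, not warned about
nll, acc = calc_loss(logits, targets)
print(nll.item(), acc.item())
```

The design choice the diff makes is to have `_calc_loss` return `(nll, metrics)` instead of mutating the enclosing `loss` and `stats` dicts as a side effect; the caller then averages across RVQ levels and appends under either the per-name keys (`f'{name}.nll'` / `f'{name}.acc'`) or the plain `nll` / `acc` keys, which keeps the helper pure and the per-level aggregation explicit.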