diff --git a/vall_e/models/base.py b/vall_e/models/base.py index f85501d..6eb32f5 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1639,13 +1639,6 @@ class Base(nn.Module): if loss_factor == 0.0: continue - - # cringe way to deduce "requested" level - level = quant_level - for i in range( self.n_resp_levels ): - if classifier_level == f'NAR:{i}:{i}': - level = i - break if logits[batch_index].dim() < 3: nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal ) @@ -1653,6 +1646,12 @@ class Base(nn.Module): if name == "resp": name = f'{name}[{quant_level}]' elif not self.resp_parallel_training: + # cringe way to deduce "requested" level + level = quant_level + for i in range( self.n_resp_levels ): + if classifier_level.endswith(f':{i}:{i}'): + level = i + break if name == "resp": name = f'{name}[{level}]' sequence = token if token.dim() <= 1 else token[:, level] @@ -1711,9 +1710,10 @@ class Base(nn.Module): # cringe way to deduce "requested" level level = 0 for i in range( self.n_resp_levels ): - if classifier_level == f'NAR:{i}:{i}': + if classifier_level.endswith(f':{i}:{i}'): level = i break + sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ] sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) ) nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )