diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 69552f9..734373d 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -172,6 +172,10 @@ def unload_model():
 # to-do: clean up this mess
 @torch.inference_mode()
 def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
+	# dirty hack during model training: zero out any codes past the codebook size
+	max_token = 1000 if cfg.audio_backend == "nemo" else 1024
+	codes = torch.where( codes >= max_token, 0, codes )
+
 	# upcast so it won't whine
 	if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
 		codes = codes.to(torch.int32)
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index ca7be8e..5b15ba8 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -1019,7 +1019,7 @@ class Base_V2(nn.Module):
 					name = f'{name}[{level}]'
 					sequence = token if token.dim() <= 1 else token[:, level]
-					nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
+					nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
 				else:
 					sequence = token.t()
 					nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
@@ -1058,7 +1058,7 @@ class Base_V2(nn.Module):
 				sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 				sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-				nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
+				nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
 			else:
 				nlls = []
 				accs = []
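
Note on the qnt.py hunk: it clamps any residual code index at or above the codebook size down to token 0 before decoding, so a partially trained model that emits out-of-range ids no longer crashes the decoder. A minimal, self-contained sketch of the behavior; the 1000/1024 codebook sizes come from the diff, everything else here is illustrative rather than repo code:

    import torch

    # assumed: the "nemo" backend uses a 1000-entry codebook, others use 1024
    max_token = 1000
    codes = torch.tensor([[3, 999, 1001, 1023]])
    # any id >= max_token is replaced with 0 (a valid, if arbitrary, code)
    codes = torch.where(codes >= max_token, 0, codes)
    print(codes)  # tensor([[  3, 999,   0,   0]])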
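
Note on the base_v2.py hunks: both call sites now pass the RVQ `level` into `_calc_loss`. The diff does not include `_calc_loss` itself, so how the level is consumed is not shown; one plausible use is per-level loss weighting, sketched below. The signature, weighting scheme, and return shape are assumptions for illustration only, not the repo's actual implementation:

    import torch
    import torch.nn.functional as F

    def _calc_loss(logits, sequence, causal, level=0, ignore_index=-100):
        if causal:
            # shift for next-token prediction
            logits = logits[..., :-1, :]
            sequence = sequence[..., 1:]
        nll = F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            sequence.reshape(-1),
            ignore_index=ignore_index,
        )
        # assumed: de-emphasize deeper residual levels
        weight = 1.0 / (level + 1)
        return nll * weight, {"nll": nll.item()}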