reticulating splines

This commit is contained in:
mrq 2025-03-01 17:48:51 -06:00
parent 56f8be4d62
commit 17094b8002
2 changed files with 5 additions and 2 deletions

View File

@ -172,6 +172,9 @@ def unload_model():
# to-do: clean up this mess
@torch.inference_mode()
def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
# dirty hack during model training
codes = torch.where( codes >= (max_token = 1000 if cfg.audio_backend == "nemo" else 1024 ), 0, codes )
# upcast so it won't whine
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
codes = codes.to(torch.int32)

View File

@ -1019,7 +1019,7 @@ class Base_V2(nn.Module):
name = f'{name}[{level}]'
sequence = token if token.dim() <= 1 else token[:, level]
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
else:
sequence = token.t()
nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
@ -1058,7 +1058,7 @@ class Base_V2(nn.Module):
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
else:
nlls = []
accs = []