reticulating splines
This commit is contained in:
parent
56f8be4d62
commit
17094b8002
|
@ -172,6 +172,9 @@ def unload_model():
|
|||
# to-do: clean up this mess
|
||||
@torch.inference_mode()
|
||||
def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
|
||||
# dirty hack during model training
|
||||
codes = torch.where( codes >= (max_token = 1000 if cfg.audio_backend == "nemo" else 1024 ), 0, codes )
|
||||
|
||||
# upcast so it won't whine
|
||||
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
|
||||
codes = codes.to(torch.int32)
|
||||
|
|
|
@ -1019,7 +1019,7 @@ class Base_V2(nn.Module):
|
|||
name = f'{name}[{level}]'
|
||||
|
||||
sequence = token if token.dim() <= 1 else token[:, level]
|
||||
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
|
||||
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
|
||||
else:
|
||||
sequence = token.t()
|
||||
nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
|
||||
|
@ -1058,7 +1058,7 @@ class Base_V2(nn.Module):
|
|||
|
||||
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
||||
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
||||
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
|
||||
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
|
||||
else:
|
||||
nlls = []
|
||||
accs = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user