reticulating splines
This commit is contained in:
parent
56f8be4d62
commit
17094b8002
|
@ -172,6 +172,9 @@ def unload_model():
|
||||||
# to-do: clean up this mess
|
# to-do: clean up this mess
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
|
def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
|
||||||
|
# dirty hack during model training
|
||||||
|
codes = torch.where( codes >= (max_token = 1000 if cfg.audio_backend == "nemo" else 1024 ), 0, codes )
|
||||||
|
|
||||||
# upcast so it won't whine
|
# upcast so it won't whine
|
||||||
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
|
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
|
||||||
codes = codes.to(torch.int32)
|
codes = codes.to(torch.int32)
|
||||||
|
|
|
@ -1019,7 +1019,7 @@ class Base_V2(nn.Module):
|
||||||
name = f'{name}[{level}]'
|
name = f'{name}[{level}]'
|
||||||
|
|
||||||
sequence = token if token.dim() <= 1 else token[:, level]
|
sequence = token if token.dim() <= 1 else token[:, level]
|
||||||
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
|
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
|
||||||
else:
|
else:
|
||||||
sequence = token.t()
|
sequence = token.t()
|
||||||
nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
|
nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
|
||||||
|
@ -1058,7 +1058,7 @@ class Base_V2(nn.Module):
|
||||||
|
|
||||||
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
||||||
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
||||||
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
|
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
|
||||||
else:
|
else:
|
||||||
nlls = []
|
nlls = []
|
||||||
accs = []
|
accs = []
|
||||||
|
|
Loading…
Reference in New Issue
Block a user