oops
This commit is contained in:
parent
d19f93a2c0
commit
23f3b56fda
|
@ -257,7 +257,7 @@ class AudioClassifier(nn.Module):
|
|||
xi = [
|
||||
#x if l == 0 else
|
||||
x if x.shape[-1] == max_size else
|
||||
torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
|
||||
torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])], device=device, dtype=dtype) ] * (max_size - x.shape[-1]), dim=-1 )
|
||||
for x, l in zip(xi, levels)
|
||||
]
|
||||
return torch.stack( xi )
|
||||
|
@ -1098,7 +1098,7 @@ class Base(nn.Module):
|
|||
|
||||
delta = ids[batch_index].shape[0] - batch.shape[0]
|
||||
if delta > 0:
|
||||
batch = torch.cat( [ batch, torch.tensor([1] * delta) ] )
|
||||
batch = torch.cat( [ batch, torch.tensor([1] * delta, device=device, dtype=torch.int32) ] )
|
||||
|
||||
x_list.append( batch )
|
||||
|
||||
|
@ -1119,7 +1119,7 @@ class Base(nn.Module):
|
|||
# handles tasks where the prompt has task tokens injected in the middle
|
||||
def prompt_input_to_token( input, quant_level ):
|
||||
if isinstance(input, str):
|
||||
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
|
||||
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16)
|
||||
|
||||
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
||||
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
|
||||
|
|
Loading…
Reference in New Issue
Block a user