This commit is contained in:
mrq 2024-08-04 08:18:57 -05:00
parent d19f93a2c0
commit 23f3b56fda

View File

@@ -257,7 +257,7 @@ class AudioClassifier(nn.Module):
xi = [
#x if l == 0 else
x if x.shape[-1] == max_size else
-torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
+torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])], device=device, dtype=dtype) ] * (max_size - x.shape[-1]), dim=-1 )
for x, l in zip(xi, levels)
]
return torch.stack( xi )
@@ -1098,7 +1098,7 @@ class Base(nn.Module):
delta = ids[batch_index].shape[0] - batch.shape[0]
if delta > 0:
-batch = torch.cat( [ batch, torch.tensor([1] * delta) ] )
+batch = torch.cat( [ batch, torch.tensor([1] * delta, device=device, dtype=torch.int32) ] )
x_list.append( batch )
@@ -1119,7 +1119,7 @@ class Base(nn.Module):
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
if isinstance(input, str):
-return torch.tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
+return torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16)
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):