oops
This commit is contained in:
parent
d19f93a2c0
commit
23f3b56fda
|
@ -257,7 +257,7 @@ class AudioClassifier(nn.Module):
|
||||||
xi = [
|
xi = [
|
||||||
#x if l == 0 else
|
#x if l == 0 else
|
||||||
x if x.shape[-1] == max_size else
|
x if x.shape[-1] == max_size else
|
||||||
torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])] ).to(dtype=dtype, device=device) ] * (max_size - x.shape[-1]), dim=-1 )
|
torch.cat( [ x, torch.tensor( [[ -float("inf") ] for _ in range(x.shape[0])], device=device, dtype=dtype) ] * (max_size - x.shape[-1]), dim=-1 )
|
||||||
for x, l in zip(xi, levels)
|
for x, l in zip(xi, levels)
|
||||||
]
|
]
|
||||||
return torch.stack( xi )
|
return torch.stack( xi )
|
||||||
|
@ -1098,7 +1098,7 @@ class Base(nn.Module):
|
||||||
|
|
||||||
delta = ids[batch_index].shape[0] - batch.shape[0]
|
delta = ids[batch_index].shape[0] - batch.shape[0]
|
||||||
if delta > 0:
|
if delta > 0:
|
||||||
batch = torch.cat( [ batch, torch.tensor([1] * delta) ] )
|
batch = torch.cat( [ batch, torch.tensor([1] * delta, device=device, dtype=torch.int32) ] )
|
||||||
|
|
||||||
x_list.append( batch )
|
x_list.append( batch )
|
||||||
|
|
||||||
|
@ -1119,7 +1119,7 @@ class Base(nn.Module):
|
||||||
# handles tasks where the prompt has task tokens injected in the middle
|
# handles tasks where the prompt has task tokens injected in the middle
|
||||||
def prompt_input_to_token( input, quant_level ):
|
def prompt_input_to_token( input, quant_level ):
|
||||||
if isinstance(input, str):
|
if isinstance(input, str):
|
||||||
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ] ).to(dtype=torch.int16, device=device)
|
return torch.tensor( [ get_task_symmap()[f'<{input}>'] ], device=device, dtype=torch.int16)
|
||||||
|
|
||||||
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
||||||
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
|
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user