This commit is contained in:
mrq 2025-02-13 13:01:55 -06:00
parent e3becec0e8
commit cebacc2303

View File

@ -1530,7 +1530,7 @@ class Base(nn.Module):
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums):
return torch.full_like(input[..., 0], self.ignore_index)
if self.version < 7:
@ -1610,6 +1610,9 @@ class Base(nn.Module):
proms = [ input ] if isinstance(input, torch.Tensor) else input
# iterate over the list to inject their tokens
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
if logits[batch_index].dim() < 3 and token.dim() >= 2:
token = token[..., 0]
elif name == "resp":
# mask found, apply it
if self.version < 7: