ugh
This commit is contained in:
parent
e3becec0e8
commit
cebacc2303
|
@ -1530,7 +1530,7 @@ class Base(nn.Module):
|
|||
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
|
||||
|
||||
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
||||
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
|
||||
if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums):
|
||||
return torch.full_like(input[..., 0], self.ignore_index)
|
||||
|
||||
if self.version < 7:
|
||||
|
@ -1610,6 +1610,9 @@ class Base(nn.Module):
|
|||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||
# iterate over the list to inject their tokens
|
||||
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
|
||||
|
||||
if logits[batch_index].dim() < 3 and token.dim() >= 2:
|
||||
token = token[..., 0]
|
||||
elif name == "resp":
|
||||
# mask found, apply it
|
||||
if self.version < 7:
|
||||
|
|
Loading…
Reference in New Issue
Block a user