ugh
This commit is contained in:
parent
e3becec0e8
commit
cebacc2303
|
@ -1530,7 +1530,7 @@ class Base(nn.Module):
|
||||||
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
|
return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
|
||||||
|
|
||||||
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
|
||||||
if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
|
if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums):
|
||||||
return torch.full_like(input[..., 0], self.ignore_index)
|
return torch.full_like(input[..., 0], self.ignore_index)
|
||||||
|
|
||||||
if self.version < 7:
|
if self.version < 7:
|
||||||
|
@ -1610,6 +1610,9 @@ class Base(nn.Module):
|
||||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||||
# iterate over the list to inject their tokens
|
# iterate over the list to inject their tokens
|
||||||
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
|
token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
|
||||||
|
|
||||||
|
if logits[batch_index].dim() < 3 and token.dim() >= 2:
|
||||||
|
token = token[..., 0]
|
||||||
elif name == "resp":
|
elif name == "resp":
|
||||||
# mask found, apply it
|
# mask found, apply it
|
||||||
if self.version < 7:
|
if self.version < 7:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user