From cebacc23035ab9b2904bced6b9ee2bc81f1d41b5 Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 13 Feb 2025 13:01:55 -0600 Subject: [PATCH] restrict audio_embedding_sums prom masking to version < 7; squeeze prom tokens when logits are 2D --- vall_e/models/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 15b5b75..06b6940 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1530,7 +1530,7 @@ class Base(nn.Module): return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens - if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums): + if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums): return torch.full_like(input[..., 0], self.ignore_index) if self.version < 7: @@ -1610,6 +1610,9 @@ class Base(nn.Module): proms = [ input ] if isinstance(input, torch.Tensor) else input # iterate over the list to inject their tokens token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) + + if logits[batch_index].dim() < 3 and token.dim() >= 2: + token = token[..., 0] elif name == "resp": # mask found, apply it if self.version < 7: