diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 15b5b75..06b6940 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1530,7 +1530,7 @@ class Base(nn.Module): return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens - if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums): + if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums): return torch.full_like(input[..., 0], self.ignore_index) if self.version < 7: @@ -1610,6 +1610,9 @@ class Base(nn.Module): proms = [ input ] if isinstance(input, torch.Tensor) else input # iterate over the list to inject their tokens token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) + + if logits[batch_index].dim() < 3 and token.dim() >= 2: + token = token[..., 0] elif name == "resp": # mask found, apply it if self.version < 7: