diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 5be47a6..0464acf 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1445,15 +1445,30 @@ class Base(nn.Module):
 
 			batch_size = len( inputs )
 			for i, batch in enumerate( inputs ):
-				quant_level = quant_levels[i]
-				it = 0
-
+				quant_level = quant_levels[i]
 				task_name = None
+
+				causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
+				dropout_mask = None
 
 				for name, input in batch:
-					# do not use resp
+					if name == "dropout_mask":
+						dropout_mask = input
+
+				for name, input in batch:
+					# meta-input, no corresponding token at the moment
+					if name == "task":
+						task_name = input
+						if task_type in ["len", "stt"]:
+							causal = True
+						continue
+					# do not use resp as-is
 					if name == "resp":
-						if self.interleave:
+						if dropout_mask is not None:
+							# if mask use original token, else ignore
+							causal = False
+							target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
+						elif self.interleave:
 							input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
 						elif task_type in summed_embeddings_task:
 							input = torch.full_like(input[..., 0], self.ignore_index)
@@ -1463,17 +1478,11 @@
 					elif name == "prom":
 						proms = [ input ] if isinstance(input, torch.Tensor) else input
 						input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
-					# meta-input, no corresponding token at the moment
-					elif name == "task":
-						task_name = input
-						continue
 
 					seq_len = input.shape[0]
 
 					logit = logits[i][it:it+seq_len]
 					it += seq_len + 1 # +1 to incorporate the separator
-
-					causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])
 
 					# for the AR, shift sequence so that it predicts the next token
 					# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
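
For reviewers, a minimal sketch of what the new `dropout_mask` branch does to the loss targets. This is illustration only, not part of the patch; the toy tensor shapes and the choice of `-100` for `self.ignore_index` (the usual cross-entropy skip value) are assumptions made for the example.

```python
# Illustration only -- not part of the patch. Mimics the new branch:
#   target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
# The toy shapes and ignore_index = -100 are assumptions for this sketch.
import torch

ignore_index = -100  # positions with this value are skipped by cross-entropy

# resp: (seq_len, n_quant_levels) codebook tokens; dropout_mask: one bool per position
resp = torch.tensor([[5, 50],
                     [6, 60],
                     [7, 70],
                     [8, 80]])
dropout_mask = torch.tensor([True, False, True, False])

# where the mask is True, keep the original (level-0) token as the target;
# elsewhere substitute ignore_index so the loss ignores that position
tokens = resp if resp.dim() == 1 else resp[:, 0]
target = torch.where(dropout_mask, tokens, ignore_index)

print(target)  # tensor([   5, -100,    7, -100])
```

Because these targets are predicted in place rather than shifted by one, the branch also forces `causal = False`, treating the masked pass as NAR-style (matching the comment at the end of the second hunk).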