From be83ddabaa32870ff3ce1092d000fb12a13184aa Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 13 Nov 2024 10:17:52 -0600
Subject: [PATCH] better causal-ness for split loss calc, and also do masking
 for NAR-len for it

---
 vall_e/models/base.py | 31 ++++++++++++++++++++-----------
 1 file changed, 20 insertions(+), 11 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 5be47a6..0464acf 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1445,15 +1445,30 @@ class Base(nn.Module):
 		batch_size = len( inputs )
 		for i, batch in enumerate( inputs ):
-			quant_level = quant_levels[i]
-			it = 0
-
+			quant_level = quant_levels[i]
 			task_name = None
+
+			causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
+			dropout_mask = None
 			for name, input in batch:
-				# do not use resp
+				if name == "dropout_mask":
+					dropout_mask = input
+
+			for name, input in batch:
+				# meta-input, no corresponding token at the moment
+				if name == "task":
+					task_name = input
+					if task_type in ["len", "stt"]:
+						causal = True
+					continue
+				# do not use resp as-is
 				if name == "resp":
-					if self.interleave:
+					if dropout_mask is not None:
+						# if mask use original token, else ignore
+						causal = False
+						target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
+					elif self.interleave:
 						input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
 					elif task_type in summed_embeddings_task:
 						input = torch.full_like(input[..., 0], self.ignore_index)
@@ -1463,17 +1478,11 @@ class Base(nn.Module):
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
 					input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
-				# meta-input, no corresponding token at the moment
-				elif name == "task":
-					task_name = input
-					continue

 				seq_len = input.shape[0]

 				logit = logits[i][it:it+seq_len]
 				it += seq_len + 1 # +1 to incorporate the separator
-
-				causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])

 				# for the AR, shift sequence so that it predicts the next token
 				# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
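
For reference, the sketch below illustrates (outside the patch itself) what the new dropout_mask branch does to the loss targets: positions kept by the mask keep their original token as the target, every other position becomes ignore_index so cross-entropy skips it, and no next-token shift is applied since the NAR predicts tokens in place (hence causal = False in that branch). The tensor values, the codebook size, and the ignore_index value here are illustrative assumptions, not values taken from the repository.

# Minimal sketch of the dropout_mask target construction, assuming the usual
# PyTorch ignore_index semantics; not part of the patch.
import torch
import torch.nn.functional as F

ignore_index = -100                                  # assumed; PyTorch's default ignored label
resp = torch.tensor([5, 2, 7, 3, 9])                 # toy level-0 response tokens
dropout_mask = torch.tensor([True, False, True, True, False])

# "if mask use original token, else ignore" -- same torch.where pattern as the diff
target = torch.where(dropout_mask, resp, torch.full_like(resp, ignore_index))
# target == tensor([   5, -100,    7,    3, -100])

# toy logits over a hypothetical 1024-entry audio codebook; no causal shift is
# applied because the NAR-len targets are predicted in place
logits = torch.randn(resp.shape[0], 1024)
loss = F.cross_entropy(logits, target, ignore_index=ignore_index)

The same pattern is what the patched code applies per batch entry via target.append( torch.where( dropout_mask, ..., self.ignore_index ) ), using only the first codebook level when the response tensor has more than one dimension.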