better causal-ness for split loss calc, and also do masking for NAR-len for it

mrq 2024-11-13 10:17:52 -06:00
parent 6b76419123
commit be83ddabaa


@@ -1445,15 +1445,30 @@ class Base(nn.Module):
                 batch_size = len( inputs )
                 for i, batch in enumerate( inputs ):
+                    quant_level = quant_levels[i]
                     it = 0
-                    quant_level = quant_levels[i]
                     task_name = None
+                    causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
+
+                    dropout_mask = None
+                    for name, input in batch:
+                        if name == "dropout_mask":
+                            dropout_mask = input
+
                     for name, input in batch:
-                        # do not use resp
+                        # meta-input, no corresponding token at the moment
+                        if name == "task":
+                            task_name = input
+                            if task_type in ["len", "stt"]:
+                                causal = True
+                            continue
+
+                        # do not use resp as-is
                         if name == "resp":
-                            if self.interleave:
+                            if dropout_mask is not None:
+                                # if mask use original token, else ignore
+                                causal = False
+                                target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
+                            elif self.interleave:
                                 input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
                             elif task_type in summed_embeddings_task:
                                 input = torch.full_like(input[..., 0], self.ignore_index)
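The dropout_mask branch above is the new NAR-len masking: the loss target keeps the original token only at positions the input had masked out, and every other position becomes ignore_index so cross-entropy skips it. A minimal standalone sketch of that torch.where pattern, assuming PyTorch's default ignore index of -100 (token values invented for illustration):

import torch

ignore_index = -100                                      # targets with this value are skipped by cross-entropy
resp = torch.tensor([5, 9, 2, 7])                        # original response tokens (first codebook level)
dropout_mask = torch.tensor([True, False, True, False])  # True = token was masked out in the input

# masked positions keep the real token as the target; the rest are ignored,
# so the loss is only computed where the model had to infill
target = torch.where(dropout_mask, resp, torch.full_like(resp, ignore_index))
# target -> tensor([   5, -100,    2, -100])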
@@ -1463,18 +1478,12 @@ class Base(nn.Module):
                         elif name == "prom":
                             proms = [ input ] if isinstance(input, torch.Tensor) else input
                             input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
-                        # meta-input, no corresponding token at the moment
-                        elif name == "task":
-                            task_name = input
-                            continue
 
                         seq_len = input.shape[0]
                         logit = logits[i][it:it+seq_len]
                         it += seq_len + 1 # +1 to incorporate the separator
 
-                        causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])
-
                         # for the AR, shift sequence so that it predicts the next token
                         # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
                         if causal and seq_len > 1:
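With causal now hoisted to the top of the loop (and flipped by the task and dropout-mask branches), the shift guarded above pairs the logits at position t with the token at t+1 for the AR case, while the NAR keeps its in-place pairing. A hedged sketch of that shift for a single sequence piece (shapes and values invented, names chosen to mirror the diff):

import torch
import torch.nn.functional as F

vocab, seq_len, ignore_index = 32, 6, -100
logit = torch.randn(seq_len, vocab)           # logits for one piece of the split sequence
target = torch.randint(0, vocab, (seq_len,))  # tokens for the same piece

causal = True
if causal and seq_len > 1:
    logit = logit[:-1]                        # the last position has no next token to predict
    target = target[1:]                       # the first token is given, never predicted

loss = F.cross_entropy(logit, target, ignore_index=ignore_index)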