better causal-ness for split loss calc, and also do masking for NAR-len for it
parent 6b76419123
commit be83ddabaa
```diff
@@ -1445,15 +1445,30 @@ class Base(nn.Module):
 		batch_size = len( inputs )

 		for i, batch in enumerate( inputs ):
+			quant_level = quant_levels[i]
+
 			it = 0

-			quant_level = quant_levels[i]
 			task_name = None

+			causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
+
+			dropout_mask = None
+			for name, input in batch:
+				# do not use resp
+				if name == "dropout_mask":
+					dropout_mask = input
+
 			for name, input in batch:
+				# meta-input, no corresponding token at the moment
+				if name == "task":
+					task_name = input
+					if task_type in ["len", "stt"]:
+						causal = True
+					continue
 				# do not use resp as-is
 				if name == "resp":
-					if self.interleave:
+					if dropout_mask is not None:
+						# if mask use original token, else ignore
+						causal = False
+						target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
+					elif self.interleave:
 						input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
 					elif task_type in summed_embeddings_task:
 						input = torch.full_like(input[..., 0], self.ignore_index)
```
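The gist of this hunk: `causal` is now decided up front from the quantization level and model capabilities (with `len`/`stt` tasks forcing it on), and a `dropout_mask` is pulled out in a first pass over the batch before any tokens are processed. When a mask is present (the NAR-len case), the loss target keeps the original token only at masked positions and becomes `ignore_index` everywhere else, so cross-entropy is computed solely over the masked-out tokens. A minimal, self-contained sketch of that masking with toy shapes (not the repo's tensors or helpers):

```python
import torch
import torch.nn.functional as F

ignore_index = -100                          # assumed; matches F.cross_entropy's default
vocab, seq_len = 1024, 6                     # toy sizes for illustration

logits = torch.randn(seq_len, vocab)         # per-position predictions
resp = torch.randint(0, vocab, (seq_len,))   # original response tokens
dropout_mask = torch.tensor([True, False, True, True, False, False])

# keep the original token where the mask is set, drop the rest from the loss
target = torch.where(dropout_mask, resp, torch.full_like(resp, ignore_index))

# NAR-style: predictions are in place, so logits and targets align one-to-one
loss = F.cross_entropy(logits, target, ignore_index=ignore_index)
```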
```diff
@@ -1463,17 +1478,11 @@ class Base(nn.Module):
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
 					input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
-				# meta-input, no corresponding token at the moment
-				elif name == "task":
-					task_name = input
-					continue

 				seq_len = input.shape[0]

 				logit = logits[i][it:it+seq_len]
 				it += seq_len + 1 # +1 to incorporate the separator

-				causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])
-
 				# for the AR, shift sequence so that it predicts the next token
 				# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
```
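The trailing comments describe the split the commit title refers to: when `causal` is set, the segment is scored autoregressively, so logits are shifted one step against the targets; when it is not (the NAR), predictions are in place and no shift is needed. A sketch of that split under the same toy assumptions as above:

```python
import torch
import torch.nn.functional as F

ignore_index = -100
vocab, seq_len = 1024, 6
logit = torch.randn(seq_len, vocab)
target = torch.randint(0, vocab, (seq_len,))
causal = True  # as computed in the hunks above

if causal:
    # AR: position t is scored against token t+1, so drop the last logit
    # and the first target before computing the loss
    loss = F.cross_entropy(logit[:-1], target[1:], ignore_index=ignore_index)
else:
    # NAR: the model predicts each token in place, no shift required
    loss = F.cross_entropy(logit, target, ignore_index=ignore_index)
```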