better causal-ness for split loss calc, and also do masking for NAR-len for it
This commit is contained in:
parent
6b76419123
commit
be83ddabaa
|
@ -1445,15 +1445,30 @@ class Base(nn.Module):
|
||||||
batch_size = len( inputs )
|
batch_size = len( inputs )
|
||||||
|
|
||||||
for i, batch in enumerate( inputs ):
|
for i, batch in enumerate( inputs ):
|
||||||
quant_level = quant_levels[i]
|
|
||||||
|
|
||||||
it = 0
|
it = 0
|
||||||
|
quant_level = quant_levels[i]
|
||||||
task_name = None
|
task_name = None
|
||||||
|
|
||||||
|
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
|
||||||
|
dropout_mask = None
|
||||||
for name, input in batch:
|
for name, input in batch:
|
||||||
# do not use resp
|
if name == "dropout_mask":
|
||||||
|
dropout_mask = input
|
||||||
|
|
||||||
|
for name, input in batch:
|
||||||
|
# meta-input, no corresponding token at the moment
|
||||||
|
if name == "task":
|
||||||
|
task_name = input
|
||||||
|
if task_type in ["len", "stt"]:
|
||||||
|
causal = True
|
||||||
|
continue
|
||||||
|
# do not use resp as-is
|
||||||
if name == "resp":
|
if name == "resp":
|
||||||
if self.interleave:
|
if dropout_mask is not None:
|
||||||
|
# if mask use original token, else ignore
|
||||||
|
causal = False
|
||||||
|
target.append( torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) )
|
||||||
|
elif self.interleave:
|
||||||
input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
|
input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
|
||||||
elif task_type in summed_embeddings_task:
|
elif task_type in summed_embeddings_task:
|
||||||
input = torch.full_like(input[..., 0], self.ignore_index)
|
input = torch.full_like(input[..., 0], self.ignore_index)
|
||||||
|
@ -1463,18 +1478,12 @@ class Base(nn.Module):
|
||||||
elif name == "prom":
|
elif name == "prom":
|
||||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||||
input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
|
input = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms ] )
|
||||||
# meta-input, no corresponding token at the moment
|
|
||||||
elif name == "task":
|
|
||||||
task_name = input
|
|
||||||
continue
|
|
||||||
|
|
||||||
seq_len = input.shape[0]
|
seq_len = input.shape[0]
|
||||||
|
|
||||||
logit = logits[i][it:it+seq_len]
|
logit = logits[i][it:it+seq_len]
|
||||||
it += seq_len + 1 # +1 to incorporate the separator
|
it += seq_len + 1 # +1 to incorporate the separator
|
||||||
|
|
||||||
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_name in ["len", "stt"])
|
|
||||||
|
|
||||||
# for the AR, shift sequence so that it predicts the next token
|
# for the AR, shift sequence so that it predicts the next token
|
||||||
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
|
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
|
||||||
if causal and seq_len > 1:
|
if causal and seq_len > 1:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user