I think I made resp_parallel_training=True faster with loss factoring?

parent 06ef3daf3c
commit ceecac6ffe
@@ -629,7 +629,10 @@ class Engines(dict[str, Engine]):
 			if cfg.lora is not None:
 				key_name = cfg.lora.full_name
 
-			stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
+			if len(self) == 1:
+				stats.update(flatten_dict(model_stats))
+			else:
+				stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
 
 		self._update()
 
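The new branch only matters when several engines are registered: a lone engine's stats land at the top level instead of being namespaced under the model name. A minimal sketch of the behavior, assuming flatten_dict collapses nested dicts into dotted keys (the implementation below is an illustrative stand-in, as is the example key_name):

    from typing import Any

    def flatten_dict(d: dict, prefix: str = "") -> dict:
        # assumed behavior: collapse nested dicts into dotted keys
        out: dict[str, Any] = {}
        for k, v in d.items():
            key = f"{prefix}.{k}" if prefix else str(k)
            if isinstance(v, dict):
                out.update(flatten_dict(v, key))
            else:
                out[key] = v
        return out

    model_stats = {"loss": {"nll": 0.5}, "acc": 0.9}
    key_name = "model-lora"  # hypothetical cfg.lora.full_name

    # single engine: stats stay unprefixed
    print(flatten_dict(model_stats))
    # -> {'loss.nll': 0.5, 'acc': 0.9}

    # multiple engines: stats are namespaced by the model name
    print(flatten_dict({key_name.split("-")[0]: model_stats}))
    # -> {'model.loss.nll': 0.5, 'model.acc': 0.9}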
@@ -748,7 +748,7 @@ class Base_V2(nn.Module):
 			# filter tokens that exceed the vocab size
 			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
 			# drop if all tokens are ignored
-			if all(sequence == self.ignore_index):
+			if torch.all(sequence == self.ignore_index):
 				return None, None
 
 			# shift if causal
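The switch from Python's builtin all() to torch.all() matters once sequences can be 2-D, which the parallel-levels path below now allows: iterating a multi-dimensional tensor hands all() whole rows, and bool() on a multi-element tensor raises. A standalone check, with ignore_index assumed to be -100 (the F.cross_entropy default):

    import torch

    ignore_index = -100  # assumed value; matches F.cross_entropy's default

    seq_1d = torch.tensor([-100, -100, -100])
    assert all(seq_1d == ignore_index)        # works, but one __bool__ call per element
    assert torch.all(seq_1d == ignore_index)  # single vectorized reduction

    seq_2d = torch.full((4, 8), ignore_index)
    try:
        all(seq_2d == ignore_index)  # iterates rows; bool() on a row raises
    except RuntimeError as e:
        print("builtin all() fails on 2-D:", e)
    assert torch.all(seq_2d == ignore_index)  # works for any shape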
@@ -757,8 +757,14 @@ class Base_V2(nn.Module):
 				logit = logit[..., :-l, :] # shift the target so that token n...
 				sequence = sequence[..., l:] # ...predicts token n + 1
+
+			# flatten batch
+			if sequence.dim() > 1:
+				logit = logit.reshape(-1, logit.shape[-1])
+				sequence = sequence.reshape(-1)
+
 
 			nll = None
 			metrics = None
 
 			if compute_hard_loss:
 				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
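F.cross_entropy expects class logits in dim 1, i.e. (N, C) with targets (N,); with per-level logits stacked as (levels, time, vocab) the class dim sits last, so flattening both tensors is the cheap way to make one call cover every level at once. A sketch with hypothetical shapes:

    import torch
    import torch.nn.functional as F

    ignore_index = -100
    n_levels, seq_len, vocab = 8, 16, 1024  # hypothetical shapes

    logit = torch.randn(n_levels, seq_len, vocab)
    sequence = torch.randint(0, vocab, (n_levels, seq_len))

    # flatten batch: (levels*time, vocab) logits vs (levels*time,) targets
    if sequence.dim() > 1:
        logit = logit.reshape(-1, logit.shape[-1])
        sequence = sequence.reshape(-1)

    nll = F.cross_entropy(logit, sequence, ignore_index=ignore_index)
    print(nll.shape)  # scalar: one mean over all levels and timesteps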
@@ -868,41 +874,15 @@ class Base_V2(nn.Module):
 					if classifier_level.endswith(f':{i}:{i}'):
 						level = i
 						break
+				"""
 				if name == "resp":
 					name = f'{name}[{level}]'
+				"""
 				sequence = token if token.dim() <= 1 else token[:, level]
 				nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
 			else:
-				nlls = []
-				accs = []
-
-				for level, logit in enumerate( logits[batch_index] ):
-					sequence = token if token.dim() <= 1 else token[:, level]
-					nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
-
-					if name == "resp":
-						if nll is not None:
-							if f'{name}[{level}].nll' not in loss:
-								loss[f'{name}[{level}].nll'] = []
-							loss[f"{name}[{level}].nll"].append( nll * loss_factor )
-
-						if metrics is not None:
-							if f'{name}[{level}].acc' not in stats:
-								stats[f'{name}[{level}].acc'] = []
-							stats[f"{name}[{level}].acc"].append( metrics )
-
-						nll = None
-						metrics = None
-					else:
-						if nll:
-							nlls.append( nll )
-						if metrics:
-							accs.append( metrics )
-				if nlls:
-					nll = sum(nlls) / len(nlls)
-				if accs:
-					accs = sum(accs) / len(accs)
-
+				sequence = token.t()
+				nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
 			if nll is not None:
 				if f'{name}.nll' not in loss:
 					loss[f'{name}.nll'] = []
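The factoring itself: instead of calling _calc_loss once per RVQ level and averaging (each level formerly producing its own resp[{level}].nll entry, hence the now-commented-out naming block), the targets are transposed to (levels, time) and a single flattened cross-entropy averages over everything. A self-contained sketch of the equivalence, with hypothetical shapes; note the two means coincide exactly only when every level contributes the same number of non-ignored tokens:

    import torch
    import torch.nn.functional as F

    ignore_index = -100
    n_levels, seq_len, vocab = 8, 16, 1024  # hypothetical shapes
    logits = torch.randn(n_levels, seq_len, vocab)    # one logit stack per level
    token = torch.randint(0, vocab, (seq_len, n_levels))  # targets as (time, levels)

    # before: one cross-entropy per level, averaged afterwards
    nlls = [
        F.cross_entropy(logits[level], token[:, level], ignore_index=ignore_index)
        for level in range(n_levels)
    ]
    nll_loop = sum(nlls) / len(nlls)

    # after: transpose to (levels, time) and let one flattened call do the mean
    sequence = token.t()
    nll_factored = F.cross_entropy(
        logits.reshape(-1, vocab),
        sequence.reshape(-1),
        ignore_index=ignore_index,
    )
    print(torch.allclose(nll_loop, nll_factored, atol=1e-5))  # True here

One vectorized call also avoids launching a separate kernel per level, which is where the claimed speedup for resp_parallel_training=True would come from.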