slightly faster pathway for resp_parallel_training=True with loss factoring
This commit is contained in:
parent
06ef3daf3c
commit
0c224090d7
|
@ -629,7 +629,10 @@ class Engines(dict[str, Engine]):
|
|||
if cfg.lora is not None:
|
||||
key_name = cfg.lora.full_name
|
||||
|
||||
stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
|
||||
if len(self) == 1:
|
||||
stats.update(flatten_dict(model_stats))
|
||||
else:
|
||||
stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
|
||||
|
||||
self._update()
|
||||
|
||||
|
|
|
@ -868,41 +868,18 @@ class Base_V2(nn.Module):
|
|||
if classifier_level.endswith(f':{i}:{i}'):
|
||||
level = i
|
||||
break
|
||||
"""
|
||||
if name == "resp":
|
||||
name = f'{name}[{level}]'
|
||||
"""
|
||||
sequence = token if token.dim() <= 1 else token[:, level]
|
||||
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
|
||||
else:
|
||||
nlls = []
|
||||
accs = []
|
||||
vocab_size = logits[batch_index].shape[-1]
|
||||
|
||||
for level, logit in enumerate( logits[batch_index] ):
|
||||
sequence = token if token.dim() <= 1 else token[:, level]
|
||||
nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
|
||||
|
||||
if name == "resp":
|
||||
if nll is not None:
|
||||
if f'{name}[{level}].nll' not in loss:
|
||||
loss[f'{name}[{level}].nll'] = []
|
||||
loss[f"{name}[{level}].nll"].append( nll * loss_factor )
|
||||
|
||||
if metrics is not None:
|
||||
if f'{name}[{level}].acc' not in stats:
|
||||
stats[f'{name}[{level}].acc'] = []
|
||||
stats[f"{name}[{level}].acc"].append( metrics )
|
||||
|
||||
nll = None
|
||||
metrics = None
|
||||
else:
|
||||
if nll:
|
||||
nlls.append( nll )
|
||||
if metrics:
|
||||
accs.append( metrics )
|
||||
if nlls:
|
||||
nll = sum(nlls) / len(nlls)
|
||||
if accs:
|
||||
accs = sum(accs) / len(accs)
|
||||
|
||||
logit = logits[batch_index][:, start:end].reshape(-1, vocab_size)
|
||||
sequence = token.reshape(-1).long()
|
||||
nll, metrics = _calc_loss( logit, sequence, causal )
|
||||
if nll is not None:
|
||||
if f'{name}.nll' not in loss:
|
||||
loss[f'{name}.nll'] = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user