slightly faster pathway for resp_parallel_training=True with loss factoring

This commit is contained in:
mrq 2025-02-26 22:47:52 -06:00
parent 06ef3daf3c
commit 0c224090d7
2 changed files with 10 additions and 30 deletions

View File

@ -629,7 +629,10 @@ class Engines(dict[str, Engine]):
if cfg.lora is not None: if cfg.lora is not None:
key_name = cfg.lora.full_name key_name = cfg.lora.full_name
stats.update(flatten_dict({key_name.split("-")[0]: model_stats})) if len(self) == 1:
stats.update(flatten_dict(model_stats))
else:
stats.update(flatten_dict({key_name.split("-")[0]: model_stats}))
self._update() self._update()

View File

@ -868,41 +868,18 @@ class Base_V2(nn.Module):
if classifier_level.endswith(f':{i}:{i}'): if classifier_level.endswith(f':{i}:{i}'):
level = i level = i
break break
"""
if name == "resp": if name == "resp":
name = f'{name}[{level}]' name = f'{name}[{level}]'
"""
sequence = token if token.dim() <= 1 else token[:, level] sequence = token if token.dim() <= 1 else token[:, level]
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal ) nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
else: else:
nlls = [] vocab_size = logits[batch_index].shape[-1]
accs = []
for level, logit in enumerate( logits[batch_index] ): logit = logits[batch_index][:, start:end].reshape(-1, vocab_size)
sequence = token if token.dim() <= 1 else token[:, level] sequence = token.reshape(-1).long()
nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal ) nll, metrics = _calc_loss( logit, sequence, causal )
if name == "resp":
if nll is not None:
if f'{name}[{level}].nll' not in loss:
loss[f'{name}[{level}].nll'] = []
loss[f"{name}[{level}].nll"].append( nll * loss_factor )
if metrics is not None:
if f'{name}[{level}].acc' not in stats:
stats[f'{name}[{level}].acc'] = []
stats[f"{name}[{level}].acc"].append( metrics )
nll = None
metrics = None
else:
if nll:
nlls.append( nll )
if metrics:
accs.append( metrics )
if nlls:
nll = sum(nlls) / len(nlls)
if accs:
accs = sum(accs) / len(accs)
if nll is not None: if nll is not None:
if f'{name}.nll' not in loss: if f'{name}.nll' not in loss:
loss[f'{name}.nll'] = [] loss[f'{name}.nll'] = []