do not like that

mrq 2025-02-27 23:56:02 -06:00
parent f4f435d7f5
commit fc25a9a7dc
2 changed files with 4 additions and 8 deletions


@@ -1037,6 +1037,7 @@ def example_usage():
 			texts, proms, resps, tasks = sample_data()
 			stats = {"step": i}
-			stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
+			with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
+				stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}


@@ -944,13 +944,8 @@ class Base_V2(nn.Module):
 				sequence = token.t()
 				nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
-				for level in enumerate(self.n_resp_levels):
-					loss_key = f'{name}[{level}].nll'
-					if loss_key not in loss:
-						loss[loss_key] = []
-					loss[loss_key].append( nll[level] * loss_factor )
-					nll = None
+				if nll is not None:
+					nll = nll.sum()
 				loss_key = f'{name}.nll'
 				acc_key = f'{name}.acc'
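For reference (not part of this commit): the deleted block (introduced by the parent commit) kept one loss entry per response level, keyed f'{name}[{level}].nll', while the replacement collapses the per-level NLL back into a single scalar stored under f'{name}.nll'. A rough standalone sketch of the two bookkeeping styles, with illustrative names and shapes rather than the repo's:

    import torch

    n_resp_levels = 8
    name = "resp"                      # stand-in for the loss name used in the model
    nll = torch.rand(n_resp_levels)    # hypothetical per-level negative log-likelihoods
    loss = {}

    # Parent commit's approach (reverted here): one running list per level.
    for level in range(n_resp_levels):
        loss.setdefault(f"{name}[{level}].nll", []).append(nll[level])

    # This commit's approach: sum across levels into a single scalar under one key.
    loss = {f"{name}.nll": [nll.sum()]} if nll is not None else {}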