do not like that
This commit is contained in:
parent
f4f435d7f5
commit
396163d40d
|
@ -1037,6 +1037,7 @@ def example_usage():
|
||||||
texts, proms, resps, tasks = sample_data()
|
texts, proms, resps, tasks = sample_data()
|
||||||
|
|
||||||
stats = {"step": i}
|
stats = {"step": i}
|
||||||
|
with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
|
||||||
stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
|
stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
|
||||||
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||||
|
|
||||||
|
|
|
@ -944,13 +944,8 @@ class Base_V2(nn.Module):
|
||||||
sequence = token.t()
|
sequence = token.t()
|
||||||
nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
|
nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
|
||||||
|
|
||||||
for level in enumerate(self.n_resp_levels):
|
if nll is not None:
|
||||||
loss_key = f'{name}[{level}].nll'
|
nll = nll.mean()
|
||||||
if loss_key not in loss:
|
|
||||||
loss[loss_key] = []
|
|
||||||
loss[loss_key].append( nll[level] * loss_factor )
|
|
||||||
|
|
||||||
nll = None
|
|
||||||
|
|
||||||
loss_key = f'{name}.nll'
|
loss_key = f'{name}.nll'
|
||||||
acc_key = f'{name}.acc'
|
acc_key = f'{name}.acc'
|
||||||
|
|
Loading…
Reference in New Issue
Block a user