do not like that
This commit is contained in:
parent
f4f435d7f5
commit
93feb5660f
|
@ -1037,7 +1037,8 @@ def example_usage():
|
|||
texts, proms, resps, tasks = sample_data()
|
||||
|
||||
stats = {"step": i}
|
||||
stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
|
||||
with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
|
||||
stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
|
||||
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||
|
||||
tqdm.write(f"{stats}")
|
||||
|
|
|
@ -939,18 +939,13 @@ class Base_V2(nn.Module):
|
|||
name = f'{name}[{level}]'
|
||||
|
||||
sequence = token if token.dim() <= 1 else token[:, level]
|
||||
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
|
||||
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
|
||||
else:
|
||||
sequence = token.t()
|
||||
nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
|
||||
|
||||
for level in enumerate(self.n_resp_levels):
|
||||
loss_key = f'{name}[{level}].nll'
|
||||
if loss_key not in loss:
|
||||
loss[loss_key] = []
|
||||
loss[loss_key].append( nll[level] * loss_factor )
|
||||
|
||||
nll = None
|
||||
if nll is not None:
|
||||
nll = nll.sum()
|
||||
|
||||
loss_key = f'{name}.nll'
|
||||
acc_key = f'{name}.acc'
|
||||
|
@ -982,7 +977,7 @@ class Base_V2(nn.Module):
|
|||
|
||||
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
||||
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
||||
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
|
||||
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
|
||||
else:
|
||||
nlls = []
|
||||
accs = []
|
||||
|
@ -991,16 +986,11 @@ class Base_V2(nn.Module):
|
|||
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
||||
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
||||
nll, metrics = _calc_loss( logit, sequence, causal, level )
|
||||
|
||||
if nll:
|
||||
nlls.append( nll )
|
||||
if metrics:
|
||||
accs.append( metrics )
|
||||
|
||||
if nlls:
|
||||
nll = sum(nlls) / len(nlls)
|
||||
if accs:
|
||||
accs = sum(accs) / len(accs)
|
||||
|
||||
nlls.append( nll )
|
||||
|
||||
nll = sum(nlls)
|
||||
accs = sum(accs) / len(accs)
|
||||
|
||||
if nll is not None:
|
||||
if 'nll' not in loss:
|
||||
|
|
Loading…
Reference in New Issue
Block a user