From 93feb5660f585309af98b8e2b762a2c28fc6b5fa Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 27 Feb 2025 23:59:56 -0600
Subject: [PATCH] do not like that

---
 vall_e/models/ar_nar_v2.py |  3 ++-
 vall_e/models/base_v2.py   | 28 +++++++++-------------------
 2 files changed, 11 insertions(+), 20 deletions(-)

diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index ad1eab7..7f42f42 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -1037,7 +1037,8 @@ def example_usage():
 			texts, proms, resps, tasks = sample_data()
 
 			stats = {"step": i}
-			stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
+			with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
+				stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 			tqdm.write(f"{stats}")
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index cbba2e9..77e0545 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -939,18 +939,13 @@ class Base_V2(nn.Module):
 						name = f'{name}[{level}]'
 						sequence = token if token.dim() <= 1 else token[:, level]
-						nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
+						nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
 					else:
 						sequence = token.t()
 						nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
 
-					for level in enumerate(self.n_resp_levels):
-						loss_key = f'{name}[{level}].nll'
-						if loss_key not in loss:
-							loss[loss_key] = []
-						loss[loss_key].append( nll[level] * loss_factor )
-
-					nll = None
+					if nll is not None:
+						nll = nll.sum()
 
 					loss_key = f'{name}.nll'
 					acc_key = f'{name}.acc'
 
@@ -982,7 +977,7 @@ class Base_V2(nn.Module):
 							sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 							sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-							nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
+							nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
 						else:
 							nlls = []
 							accs = []
 
@@ -991,16 +986,11 @@ class Base_V2(nn.Module):
 								sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
 								sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
 								nll, metrics = _calc_loss( logit, sequence, causal, level )
-
-								if nll:
-									nlls.append( nll )
-								if metrics:
-									accs.append( metrics )
-
-							if nlls:
-								nll = sum(nlls) / len(nlls)
-							if accs:
-								accs = sum(accs) / len(accs)
+
+								nlls.append( nll )
+
+							nll = sum(nlls)
+							accs = sum(accs) / len(accs)
 
 							if nll is not None:
 								if 'nll' not in loss: