From 396163d40d717fc309dad196e79da756fefbd806 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 27 Feb 2025 23:55:06 -0600
Subject: [PATCH] do not like that

---
 vall_e/models/ar_nar_v2.py | 3 ++-
 vall_e/models/base_v2.py   | 9 ++-------
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index ad1eab7..7f42f42 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -1037,7 +1037,8 @@ def example_usage():
 			texts, proms, resps, tasks = sample_data()
 
 			stats = {"step": i}
-			stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
+			with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
+				stats |= engine.traverse(phns_list=texts, proms_list=proms, resps_list=resps, task_list=tasks, training=True)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 			tqdm.write(f"{stats}")
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index cbba2e9..33faaff 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -944,13 +944,8 @@ class Base_V2(nn.Module):
 					sequence = token.t()
 					nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
 
-					for level in enumerate(self.n_resp_levels):
-						loss_key = f'{name}[{level}].nll'
-						if loss_key not in loss:
-							loss[loss_key] = []
-						loss[loss_key].append( nll[level] * loss_factor )
-
-					nll = None
+					if nll is not None:
+						nll = nll.mean()
 
 				loss_key = f'{name}.nll'
 				acc_key = f'{name}.acc'
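
For context on the first hunk: `torch.autograd.set_detect_anomaly` doubles as a context manager, so anomaly checking (which traces any NaN/inf produced during backward back to the forward op that created it) is scoped to the training step and toggled by the config flag. A minimal standalone sketch of the pattern, with `detect_grad_anomaly` as a hypothetical stand-in for `cfg.trainer.detect_grad_anomaly`:

import torch

# Hypothetical stand-in for cfg.trainer.detect_grad_anomaly from the patch above.
detect_grad_anomaly = True

model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)

# set_detect_anomaly(False) is a no-op context, so the same code path is used
# whether or not anomaly detection is enabled -- the shape of the hunk above.
with torch.autograd.set_detect_anomaly(detect_grad_anomaly):
    loss = model(x).pow(2).mean()
    loss.backward()  # raises, naming the offending forward op, if backward produces NaN/inf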
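
The second hunk drops the per-level loss bookkeeping (the removed `for level in enumerate(self.n_resp_levels)` would raise a TypeError if `n_resp_levels` is the integer level count, and the trailing `nll = None` discarded the loss either way) in favor of reducing the per-level NLL tensor to a single scalar. A minimal sketch of the surviving logic, assuming `nll` arrives as one value per RVQ level:

import torch

# Hypothetical per-level negative log-likelihoods, one per RVQ level,
# standing in for the `nll` returned by _calc_loss in the patch above.
nll = torch.tensor([2.31, 1.87, 1.52, 1.40])

# The replacement logic: guard against a missing loss, then collapse the
# per-level values into the single scalar stored under f'{name}.nll'.
if nll is not None:
    nll = nll.mean()

print(nll)  # tensor(1.7750)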